diff --git a/.typos.toml b/.typos.toml index 66edbcbaf..253ed3385 100644 --- a/.typos.toml +++ b/.typos.toml @@ -1,4 +1,5 @@ [default.extend-words] numer = "numer" Commun = "Commun" -nd = "nd" \ No newline at end of file +nd = "nd" +assum = "assum" diff --git a/Project.toml b/Project.toml index b4829c368..fcfe9ab1c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Symbolics" uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7" authors = ["Shashi Gowda "] -version = "6.11.0" +version = "6.13.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -93,7 +93,7 @@ StaticArraysCore = "1.4" SymPy = "2.2" SymbolicIndexingInterface = "0.3.14" SymbolicLimits = "0.2.2" -SymbolicUtils = "2, 3" +SymbolicUtils = "3.7" TermInterface = "2" julia = "1.10" diff --git a/ext/SymbolicsGroebnerExt.jl b/ext/SymbolicsGroebnerExt.jl index ebf5174a5..66d069060 100644 --- a/ext/SymbolicsGroebnerExt.jl +++ b/ext/SymbolicsGroebnerExt.jl @@ -320,13 +320,9 @@ end # Helps with precompilation time PrecompileTools.@setup_workload begin @variables a b c x y z - equation1 = a*log(x)^b + c ~ 0 - equation_actually_polynomial = sin(x^2 +1)^2 + sin(x^2 + 1) + 3 simple_linear_equations = [x - y, y + 2z] equations_intersect_sphere_line = [x^2 + y^2 + z^2 - 9, x - 2y + 3, y - z] PrecompileTools.@compile_workload begin - symbolic_solve(equation1, x) - symbolic_solve(equation_actually_polynomial) symbolic_solve(simple_linear_equations, [x, y], warns=false) symbolic_solve(equations_intersect_sphere_line, [x, y, z], warns=false) end diff --git a/ext/SymbolicsNemoExt.jl b/ext/SymbolicsNemoExt.jl index 16fe2414e..9c9f31db1 100644 --- a/ext/SymbolicsNemoExt.jl +++ b/ext/SymbolicsNemoExt.jl @@ -61,7 +61,13 @@ end PrecompileTools.@setup_workload begin @variables a b c x y z expr_with_params = expand((x + b)*(x^2 + 2x + 1)*(x^2 - a)) + equation1 = a*log(x)^b + c ~ 0 + equation_polynomial = 9^x + 3^x + 2 + exp_eq = 5*2^(x+1) + 7^(x+3) PrecompileTools.@compile_workload begin + symbolic_solve(equation1, x) + symbolic_solve(equation_polynomial, x) + symbolic_solve(exp_eq) symbolic_solve(expr_with_params, x, dropmultiplicity=false) symbolic_solve(x^10 - a^10, x, dropmultiplicity=false) end diff --git a/src/Symbolics.jl b/src/Symbolics.jl index 008cc8ec2..5659e42ea 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -92,6 +92,7 @@ export Inequality, ≲, ≳ include("inequality.jl") import Bijections, DynamicPolynomials +export tosymbol include("utils.jl") using ConstructionBase diff --git a/src/arrays.jl b/src/arrays.jl index 6bad21582..558dd754c 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -63,6 +63,14 @@ end ConstructionBase.constructorof(s::Type{<:ArrayOp{T}}) where {T} = ArrayOp{T} function SymbolicUtils.maketerm(::Type{<:ArrayOp}, f, args, m) + args = map(args) do arg + if iscall(arg) && operation(arg) == Ref && symbolic_type(only(arguments(arg))) == NotSymbolic() + return Ref(only(arguments(arg))) + else + return arg + end + end + t = f(args...) t isa Symbolic && !isnothing(m) ? metadata(t, m) : t @@ -968,7 +976,7 @@ end ### Codegen function SymbolicUtils.Code.toexpr(x::ArrayOp, st) - haskey(st.symbolify, x) && return st.symbolify[x] + haskey(st.rewrites, x) && return st.rewrites[x] if iscall(x.term) toexpr(x.term, st) diff --git a/src/solver/attract.jl b/src/solver/attract.jl index 027f85a99..6d03778e8 100644 --- a/src/solver/attract.jl +++ b/src/solver/attract.jl @@ -197,10 +197,8 @@ function attract_trig(lhs, var) r_trig = [@acrule(sin(~x::(contains_var))^2 + cos(~x::(contains_var))^2=>one(~x)) @acrule(sin(~x::(contains_var))^2 + -1=>-1 * cos(~x)^2) @acrule(cos(~x::(contains_var))^2 + -1=>-1 * sin(~x)^2) - @acrule(cos(~x::(contains_var))^2 + -1 * sin(~x::(contains_var))^2=>cos(2 * - ~x)) - @acrule(sin(~x::(contains_var))^2 + -1 * cos(~x::(contains_var))^2=>-cos(2 * - ~x)) + @acrule(cos(~x::(contains_var))^2 + -1 * sin(~x::(contains_var))^2=>cos(2*~x)) + @acrule(sin(~x::(contains_var))^2 + -1 * cos(~x::(contains_var))^2=>-cos(2*~x)) @acrule(cos(~x::(contains_var)) * sin(~x::(contains_var))=>sin(2 * ~x) / 2) @acrule(tan(~x::(contains_var))^2 + -1 * sec(~x::(contains_var))^2=>one(~x)) @acrule(-1 * tan(~x::(contains_var))^2 + sec(~x::(contains_var))^2=>one(~x)) diff --git a/src/solver/ia_main.jl b/src/solver/ia_main.jl index c1f998a4f..8f83ca02b 100644 --- a/src/solver/ia_main.jl +++ b/src/solver/ia_main.jl @@ -123,7 +123,7 @@ function isolate(lhs, var; warns=true, conditions=[]) new_var = (@variables $new_var)[1] rhs = map( sol -> term(rev_oper[oper], sol) + - term(*, Base.MathConstants.pi, 2 * new_var), + term(*, Base.MathConstants.pi, new_var), rhs) @info string(new_var) * " ϵ" * " Ζ" diff --git a/src/solver/main.jl b/src/solver/main.jl index e9d3f3aeb..7254ea152 100644 --- a/src/solver/main.jl +++ b/src/solver/main.jl @@ -195,7 +195,7 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where for e in expr for var in x if !check_poly_inunivar(e, var) - warns && @warn("This system can not be currently solved by solve.") + warns && @warn("This system can not be currently solved by `symbolic_solve`.") return nothing end end @@ -276,7 +276,7 @@ function solve_univar(expression, x; dropmultiplicity=true) end end - subs, filtered_expr = filter_poly(expression, x) + subs, filtered_expr, assumptions = filter_poly(expression, x, assumptions=true) coeffs, constant = polynomial_coeffs(filtered_expr, [x]) degree = sdegree(coeffs, x) @@ -296,18 +296,28 @@ function solve_univar(expression, x; dropmultiplicity=true) append!(arr_roots, og_arr_roots) end end - - return arr_roots end if length(factors) != 1 - for factor in factors_subbed - roots = solve_univar(factor, x, dropmultiplicity = dropmultiplicity) + for i in eachindex(factors_subbed) + if !any(isequal(x, var) for var in get_variables(factors[i])) + continue + end + roots = solve_univar(factors_subbed[i], x, dropmultiplicity = dropmultiplicity) append!(arr_roots, roots) end end + for i in reverse(eachindex(arr_roots)) + for j in eachindex(assumptions) + if isequal(substitute(assumptions[j], Dict(x=>arr_roots[i])), 0) + deleteat!(arr_roots, i) + end + end + end + if isequal(arr_roots, []) + @assert check_polynomial(expression) "This expression could not be solved by `symbolic_solve`." return [RootsOf(wrap(expression), wrap(x))] end diff --git a/src/solver/postprocess.jl b/src/solver/postprocess.jl index ff72fdf3f..4764690aa 100644 --- a/src/solver/postprocess.jl +++ b/src/solver/postprocess.jl @@ -1,4 +1,3 @@ - # Alex: make sure `Num`s are not processed here as they'd break it. _postprocess_root(x) = x @@ -32,12 +31,12 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) !iscall(x) && return x x = Symbolics.term(operation(x), map(_postprocess_root, arguments(x))...) + oper = operation(x) # sqrt(0), cbrt(0) => 0 # sqrt(1), cbrt(1) => 1 - if iscall(x) && - (operation(x) === sqrt || operation(x) === cbrt || operation(x) === ssqrt || - operation(x) === scbrt) + if (oper === sqrt || oper === cbrt || oper === ssqrt || + oper === scbrt) arg = arguments(x)[1] if isequal(arg, 0) || isequal(arg, 1) return arg @@ -45,17 +44,17 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end # (X)^0 => 1 - if iscall(x) && operation(x) === (^) && isequal(arguments(x)[2], 0) + if oper === (^) && isequal(arguments(x)[2], 0) return 1 end # (X)^1 => X - if iscall(x) && operation(x) === (^) && isequal(arguments(x)[2], 1) + if oper === (^) && isequal(arguments(x)[2], 1) return arguments(x)[1] end # sqrt((N / D)^2 * M) => N / D * sqrt(M) - if iscall(x) && (operation(x) === sqrt || operation(x) === ssqrt) + if (oper === sqrt || oper === ssqrt) function squarefree_decomp(x::Integer) square, squarefree = big(1), big(1) for (p, d) in collect(Primes.factor(abs(x))) @@ -90,7 +89,7 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end # (sqrt(N))^M => N^div(M, 2)*sqrt(N)^(mod(M, 2)) - if iscall(x) && operation(x) === (^) + if oper === (^) arg1, arg2 = arguments(x) if iscall(arg1) && (operation(arg1) === sqrt || operation(arg1) === ssqrt) if arg2 isa Integer @@ -105,6 +104,19 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end end + x = convert_consts(x) + + if oper === (+) + args = arguments(x) + for arg in args + if isequal(arg, 0) + after_removing = setdiff(args, arg) + isone(length(after_removing)) && return after_removing[1] + return Symbolics.term(+, after_removing) + end + end + end + return x end @@ -122,3 +134,54 @@ function postprocess_root(x) end x # unreachable end + + +inv_exacts = [0, Symbolics.term(*, pi), + Symbolics.term(/,pi,3), + Symbolics.term(/, pi, 2), + Symbolics.term(/, Symbolics.term(*, 2, pi), 3), + Symbolics.term(/, pi, 6), + Symbolics.term(/, Symbolics.term(*, 5, pi), 6), + Symbolics.term(/, pi, 4) +] +inv_evald = Symbolics.symbolic_to_float.(inv_exacts) + +const inv_pairs = collect(zip(inv_exacts, inv_evald)) +""" + function convert_consts(x) +This function takes BasicSymbolic terms as input (x) and attempts +to simplify these basic symbolic terms using known values. +Currently, this function only supports inverse trigonometric functions. + +## Examples +```jldoctest +julia> Symbolics.convert_consts(Symbolics.term(acos, 0)) +π / 2 + +julia> Symbolics.convert_consts(Symbolics.term(atan, 0)) +0 + +julia> Symbolics.convert_consts(Symbolics.term(atan, 1)) +π / 4 +``` +""" +function convert_consts(x) + !iscall(x) && return x + + oper = operation(x) + inv_opers = [asin, acos, atan] + + if any(isequal(oper, o) for o in inv_opers) && isempty(Symbolics.get_variables(x)) + val = Symbolics.symbolic_to_float(x) + for (exact, evald) in inv_pairs + if isapprox(evald, val) + return exact + elseif isapprox(-evald, val) + return -exact + end + end + end + + # add [sin, cos, tan] simplifications in the future? + return x +end diff --git a/src/solver/preprocess.jl b/src/solver/preprocess.jl index cd586edb3..8b0dc5f81 100644 --- a/src/solver/preprocess.jl +++ b/src/solver/preprocess.jl @@ -40,16 +40,19 @@ function clean_f(filtered_expr, var, subs) unwrapped_f = unwrap(filtered_expr) !iscall(unwrapped_f) && return filtered_expr oper = operation(unwrapped_f) + assumptions = [] if oper === (/) args = arguments(unwrapped_f) if any(isequal(var, x) for x in get_variables(args[2])) - return filtered_expr + filtered_expr = expand(args[1] * args[2]) + push!(assumptions, substitute(args[2], subs, fold=false)) + return filtered_expr, assumptions end filtered_expr = args[1] @info "Assuming $(substitute(args[2], subs, fold=false) != 0)" end - return filtered_expr + return filtered_expr, assumptions end """ @@ -238,15 +241,17 @@ julia> filter_poly((x+1)*term(log, 3), x) (Dict{Any, Any}(var"##247" => log(3)), var"##247"*(1 + x)) ``` """ -function filter_poly(og_expr, var) +function filter_poly(og_expr, var; assumptions=false) expr = deepcopy(og_expr) expr = unwrap(expr) vars = get_variables(expr) # handle edge cases if !isequal(vars, []) && isequal(vars[1], expr) + assumptions && return Dict{Any, Any}(), expr, [] return (Dict{Any, Any}(), expr) elseif isequal(vars, []) + assumptions && return filter_stuff(expr), [] return filter_stuff(expr) end @@ -256,14 +261,16 @@ function filter_poly(og_expr, var) # reassemble expr to avoid variables remembering original values issue and clean args = arguments(expr) oper = operation(expr) - new_expr = clean_f(term(oper, args...), var, subs) + new_expr, assum_array = clean_f(term(oper, args...), var, subs) + assumptions && return subs, new_expr, assum_array return subs, new_expr end -function filter_poly(og_expr) + +function filter_poly(og_expr; assumptions=false) new_var = gensym() new_var = (@variables $(new_var))[1] - return filter_poly(og_expr, new_var) + return filter_poly(og_expr, new_var; assumptions=assumptions) end diff --git a/src/solver/solve_helpers.jl b/src/solver/solve_helpers.jl index f2420969d..7496f65f6 100644 --- a/src/solver/solve_helpers.jl +++ b/src/solver/solve_helpers.jl @@ -78,7 +78,7 @@ function check_expr_validity(expr) valid_type = false if type_expr <: Number || type_expr == Num || type_expr == SymbolicUtils.BasicSymbolic{Real} || - type_expr == Complex{Num} || type_expr == ComplexTerm{Real} + type_expr == Complex{Num} || type_expr == ComplexTerm{Real} || type_expr == SymbolicUtils.BasicSymbolic{Complex{Real}} valid_type = true end iscall(unwrap(expr)) && @assert !hasderiv(unwrap(expr)) "Differential equations are not currently supported" diff --git a/src/utils.jl b/src/utils.jl index e0a2aaaf8..ce5c88fbf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -131,8 +131,8 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi string(nameof(arguments(oldop)[1])) elseif oldop == getindex args = arguments(O) - opname = string(tosymbol(args[1]), "[", map(tosymbol, args[2:end])..., "]") - return _Sym(symtype(O), Symbol(opname, d_separator, ds)) + opname = string(tosymbol(args[1])) + return metadata(_Sym(symtype(args[1]), Symbol(opname, d_separator, ds)), metadata(args[1]))[args[2:end]...] elseif oldop isa Function return nothing else diff --git a/src/variable.jl b/src/variable.jl index 30421c285..528619e67 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -101,6 +101,11 @@ function _parse_vars(macroname, type, x, transform=identity) # x = 1, [connect = flow; unit = u"m^3/s"] if Meta.isexpr(v, :(=)) v, val = v.args + # defaults with metadata for function variables + if Meta.isexpr(val, :block) + Base.remove_linenums!(val) + val = only(val.args) + end if Meta.isexpr(val, :tuple) && length(val.args) == 2 && isoption(val.args[2]) options = val.args[2].args val = val.args[1] @@ -124,7 +129,7 @@ function _parse_vars(macroname, type, x, transform=identity) isruntime, v = unwrap_runtime_var(v) iscall = Meta.isexpr(v, :call) isarray = Meta.isexpr(v, :ref) - if iscall && Meta.isexpr(v.args[1], :ref) + if iscall && Meta.isexpr(v.args[1], :ref) && !call_args_are_function(map(last∘unwrap_runtime_var, @view v.args[2:end])) @warn("The variable syntax $v is deprecated. Use $(Expr(:ref, Expr(:call, v.args[1].args[1], v.args[2]), v.args[1].args[2:end]...)) instead. The former creates an array of functions, while the latter creates an array valued function. The deprecated syntax will cause an error in the next major release of Symbolics. @@ -155,35 +160,61 @@ function _parse_vars(macroname, type, x, transform=identity) return ex end +call_args_are_function(_) = false +function call_args_are_function(call_args::AbstractArray) + !isempty(call_args) && (call_args[end] == :(..) || all(Base.Fix2(Meta.isexpr, :(::)), call_args)) +end + function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, prop, transform, isruntime) ndim = :($length(($(indices...),))) - vname = !isruntime ? Meta.quot(lhs) : lhs - if call_args[1] == :.. - ex = :($CallWithMetadata($_Sym($FnType{Tuple, Array{$type, $ndim}}, $vname))) + if call_args_are_function(call_args) + vname, fntype = function_name_and_type(lhs) + # name was already unwrapped before calling this function and is of the form $x + if isruntime + _vname = vname + else + # either no ::fnType or $x::fnType + vname, fntype = function_name_and_type(lhs) + isruntime, vname = unwrap_runtime_var(vname) + if isruntime + _vname = vname + else + _vname = Meta.quot(vname) + end + end + argtypes = arg_types_from_call_args(call_args) + ex = :($CallWithMetadata($_Sym($FnType{$argtypes, Array{$type, $ndim}, $(fntype...)}, $_vname))) else - ex = :($_Sym($FnType{Tuple, Array{$type, $ndim}}, $vname)(map($unwrap, ($(call_args...),))...)) + vname = lhs + if isruntime + _vname = vname + else + _vname = Meta.quot(vname) + end + ex = :($_Sym($FnType{Tuple, Array{$type, $ndim}}, $_vname)(map($unwrap, ($(call_args...),))...)) end ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) if val !== nothing ex = :($setdefaultval($ex, $val)) end - ex = setprops_expr(ex, prop, macroname, Meta.quot(lhs)) + ex = setprops_expr(ex, prop, macroname, Meta.quot(vname)) #ex = :($scalarize_getindex($ex)) ex = :($wrap($ex)) ex = :($transform($ex)) if isruntime - lhs = gensym(lhs) + vname = gensym(Symbol(vname)) end - lhs, :($lhs = $ex) + vname, :($vname = $ex) end function construct_vars(macroname, v, type, call_args, val, prop, transform, isruntime) issym = v isa Symbol - isarray = isa(v, Expr) && v.head == :ref + isarray = !isruntime && Meta.isexpr(v, :ref) if isarray + # this can't be an array of functions, since that was handled by `construct_dep_array_vars` var_name = v.args[1] if Meta.isexpr(var_name, :(::)) var_name, type′ = var_name.args @@ -192,6 +223,22 @@ function construct_vars(macroname, v, type, call_args, val, prop, transform, isr isruntime, var_name = unwrap_runtime_var(var_name) indices = v.args[2:end] expr = _construct_array_vars(macroname, isruntime ? var_name : Meta.quot(var_name), type, call_args, val, prop, indices...) + elseif call_args_are_function(call_args) + var_name, fntype = function_name_and_type(v) + # name was already unwrapped before calling this function and is of the form $x + if isruntime + vname = var_name + else + # either no ::fnType or $x::fnType + var_name, fntype = function_name_and_type(v) + isruntime, var_name = unwrap_runtime_var(var_name) + if isruntime + vname = var_name + else + vname = Meta.quot(var_name) + end + end + expr = construct_var(macroname, fntype == () ? vname : Expr(:(::), vname, fntype[1]), type, call_args, val, prop) else var_name = v if Meta.isexpr(v, :(::)) @@ -200,7 +247,7 @@ function construct_vars(macroname, v, type, call_args, val, prop, transform, isr end expr = construct_var(macroname, isruntime ? var_name : Meta.quot(var_name), type, call_args, val, prop) end - lhs = isruntime ? gensym(var_name) : var_name + lhs = isruntime ? gensym(Symbol(var_name)) : var_name rhs = :($transform($expr)) lhs, :($lhs = $rhs) end @@ -249,15 +296,60 @@ function Base.show(io::IO, c::CallWithMetadata) print(io, "⋆") end +struct CallWithParent end + function (f::CallWithMetadata)(args...) - metadata(unwrap(f.f(map(unwrap, args)...)), metadata(f)) + setmetadata(metadata(unwrap(f.f(map(unwrap, args)...)), metadata(f)), CallWithParent, f) +end + +Base.isequal(a::CallWithMetadata, b::CallWithMetadata) = isequal(a.f, b.f) + +function arg_types_from_call_args(call_args) + if length(call_args) == 1 && only(call_args) == :.. + return Tuple + end + Ts = map(call_args) do arg + if arg == :.. + Vararg + elseif arg isa Expr && arg.head == :(::) + if length(arg.args) == 1 + arg.args[1] + elseif arg.args[1] == :.. + :(Vararg{$(arg.args[2])}) + else + arg.args[2] + end + else + error("Invalid call argument $arg") + end + end + return :(Tuple{$(Ts...)}) +end + +function function_name_and_type(var_name) + if var_name isa Expr && Meta.isexpr(var_name, :(::), 2) + var_name.args[1], (var_name.args[2],) + else + var_name, () + end end function construct_var(macroname, var_name, type, call_args, val, prop) expr = if call_args === nothing +<<<<<<< HEAD :($_Sym($type, $var_name)) elseif !isempty(call_args) && call_args[end] == :.. :($CallWithMetadata($_Sym($FnType{Tuple, $type}, $var_name))) +======= + :($Sym{$type}($var_name)) + elseif call_args_are_function(call_args) + # function syntax is (x::TFunc)(.. or ::TArg1, ::TArg2)::TRet + # .. is Vararg + # (..)::ArgT is Vararg{ArgT} + var_name, fntype = function_name_and_type(var_name) + argtypes = arg_types_from_call_args(call_args) + :($CallWithMetadata($Sym{$FnType{$argtypes, $type, $(fntype...)}}($var_name))) +>>>>>>> origin/master else :($_Sym($FnType{NTuple{$(length(call_args)), Any}, $type}, $var_name)($(map(x->:($value($x)), call_args)...))) end @@ -283,9 +375,15 @@ function _construct_array_vars(macroname, var_name, type, call_args, val, prop, expr = if call_args === nothing ex = :($_Sym(Array{$type, $ndim}, $var_name)) :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) - elseif !isempty(call_args) && call_args[end] == :.. + elseif call_args_are_function(call_args) need_scalarize = true +<<<<<<< HEAD ex = :($_Sym(Array{$FnType{Tuple, $type}, $ndim}, $var_name)) +======= + var_name, fntype = function_name_and_type(var_name) + argtypes = arg_types_from_call_args(call_args) + ex = :($Sym{Array{$FnType{$argtypes, $type, $(fntype...)}, $ndim}}($var_name)) +>>>>>>> origin/master ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) :($map($CallWithMetadata, $ex)) else diff --git a/test/arrays.jl b/test/arrays.jl index aeed3ba90..bf164cb36 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -36,7 +36,9 @@ end @testset "getname" begin @variables t x(t)[1:4] v = Symbolics.lower_varname(unwrap(x[2]), unwrap(t), 2) - @test getname(v) == Symbol("x(t)[2]ˍtt") + @test operation(v) == getindex + @test arguments(v)[2] == 2 + @test getname(v) == getname(arguments(v)[1]) == Symbol("x(t)ˍtt") end @testset "getindex" begin @@ -80,6 +82,8 @@ end @test isequal(T, Symbolics.maketerm(typeof(T), operation(T), arguments(T), nothing)) T2 = unwrap(3B) @test isequal(T2, Symbolics.maketerm(typeof(T), operation(T), [*, 3, unwrap(B)], nothing)) + T3 = unwrap(A .^ 2) + @test isequal(T3, Symbolics.maketerm(typeof(T3), operation(T3), arguments(T3), nothing)) end getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) diff --git a/test/build_function.jl b/test/build_function.jl index 18e0a574b..5e7266402 100644 --- a/test/build_function.jl +++ b/test/build_function.jl @@ -1,7 +1,7 @@ using Symbolics, SparseArrays, LinearAlgebra, Test using ReferenceTests using Symbolics: value -using SymbolicUtils.Code: DestructuredArgs, Func +using SymbolicUtils.Code: DestructuredArgs, Func, NameState @variables a b c1 c2 c3 d e g oop, iip = Symbolics.build_function([sqrt(a), sin(b)], [a, b], nanmath = true) @test all(isnan, eval(oop)([-1, Inf])) @@ -275,3 +275,9 @@ let #658 k = eval(build_function(a * X1 + X2, X1, X2, a)[2]) @test k(ones(3), ones(3), 1.5) == [2.5, 2.5, 2.5] end + +@testset "ArrayOp codegen" begin + @variables x[1:2] + T = value(x .^ 2) + @test_nowarn toexpr(T, NameState()) +end diff --git a/test/diff.jl b/test/diff.jl index eb500cb9d..0431933b6 100644 --- a/test/diff.jl +++ b/test/diff.jl @@ -13,7 +13,7 @@ Dx = Differential(x) @test Symbol(D(D(uu))) === Symbol("uuˍtt(t)") @test Symbol(D(uuˍt)) === Symbol(D(D(uu))) -@test Symbol(D(v[2])) === Symbol("v(t)[2]ˍt") +@test Symbol(D(v[2])) === Symbol("getindex(var\"v(t)ˍt\", 2)") test_equal(a, b) = @test isequal(simplify(a), simplify(b)) diff --git a/test/macro.jl b/test/macro.jl index e9e64478f..be21ce72c 100644 --- a/test/macro.jl +++ b/test/macro.jl @@ -1,5 +1,5 @@ using Symbolics -import Symbolics: getsource, getdefaultval, wrap, unwrap, getname +import Symbolics: CallWithMetadata, getsource, getdefaultval, wrap, unwrap, getname import SymbolicUtils: symtype, FnType, BasicSymbolic, promote_symtype, _Term using LinearAlgebra using Test @@ -238,3 +238,136 @@ spam(x) = 2x sym = spam([a, 2a]) @test sym isa Num @test unwrap(sym) isa BasicSymbolic{Real} + +fn_defaults = [print, min, max, identity, (+), (-), max, sum, vcat, (*)] +fn_names = [Symbol(:f, i) for i in 1:10] + +struct VariableFoo end +Symbolics.option_to_metadata_type(::Val{:foo}) = VariableFoo + +function test_all_functions(fns) + f1, f2, f3, f4, f5, f6, f7, f8, f9, f10 = fns + @variables x y::Int z::Function w[1:3, 1:3] v[1:3, 1:3]::String + @test f1 isa CallWithMetadata{FnType{Tuple, Real}} + @test all(x -> symtype(x) <: Real, [f1(), f1(1), f1(x), f1(x, y), f1(x, y, x+y)]) + @test f2 isa CallWithMetadata{FnType{Tuple{Any, Vararg}, Int}} + @test all(x -> symtype(x) <: Int, [f2(1), f2(z), f2(x), f2(x, y), f2(x, y, x+y)]) + @test_throws ErrorException f2() + @test f3 isa CallWithMetadata{FnType{Tuple, Real, typeof(max)}} + @test all(x -> symtype(x) <: Real, [f3(), f3(1), f3(x), f3(x, y), f3(x, y, x+y)]) + @test f4 isa CallWithMetadata{FnType{Tuple{Int}, Real}} + @test all(x -> symtype(x) <: Real, [f4(1), f4(y), f4(2y)]) + @test_throws ErrorException f4(x) + @test f5 isa CallWithMetadata{FnType{Tuple{Int, Vararg{Int}}, Real}} + @test all(x -> symtype(x) <: Real, [f5(1), f5(y), f5(y, y), f5(2, 3)]) + @test_throws ErrorException f5(x) + @test f6 isa CallWithMetadata{FnType{Tuple{Int, Int}, Int}} + @test all(x -> symtype(x) <: Int, [f6(1, 1), f6(y, y), f6(1, y), f6(y, 1)]) + @test_throws ErrorException f6() + @test_throws ErrorException f6(1) + @test_throws ErrorException f6(x, y) + @test_throws ErrorException f6(y) + @test f7 isa CallWithMetadata{FnType{Tuple{Int, Int}, Int, typeof(max)}} + # call behavior tested by f6 + @test f8 isa CallWithMetadata{FnType{Tuple{Function, Vararg}, Real, typeof(sum)}} + @test all(x -> symtype(x) <: Real, [f8(z), f8(z, x), f8(identity), f8(identity, x)]) + @test_throws ErrorException f8(x) + @test_throws ErrorException f8(1) + @test f9 isa CallWithMetadata{FnType{Tuple, Vector{Real}}} + @test all(x -> symtype(unwrap(x)) <: Vector{Real} && size(x) == (3,), [f9(), f9(1), f9(x), f9(x + y), f9(z), f9(1, x)]) + @test f10 isa CallWithMetadata{FnType{Tuple{Matrix{<:Real}, Matrix{<:Real}}, Matrix{Real}, typeof(*)}} + @test all(x -> symtype(unwrap(x)) <: Matrix{Real} && size(x) == (3, 3), [f10(w, w), f10(w, ones(3, 3)), f10(ones(3, 3), ones(3, 3)), f10(w + w, w)]) + @test_throws ErrorException f10(w, v) +end + +function test_functions_defaults(fns) + for (fn, def) in zip(fns, fn_defaults) + @test Symbolics.getdefaultval(fn, nothing) == def + end +end + +function test_functions_metadata(fns) + for (i, fn) in enumerate(fns) + @test Symbolics.getmetadata(fn, VariableFoo, nothing) == i + end +end + +fns = @test_nowarn @variables begin + f1(..) + f2(::Any, ..)::Int + (f3::typeof(max))(..) + f4(::Int) + f5(::Int, (..)::Int) + f6(::Int, ::Int)::Int + (f7::typeof(max))(::Int, ::Int)::Int + (f8::typeof(sum))(::Function, ..) + f9(..)[1:3] + (f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3] + # f11[1:3](::Int)::Int +end + +test_all_functions(fns) + +fns = @test_nowarn @variables begin + f1(..) = fn_defaults[1] + f2(::Any, ..)::Int = fn_defaults[2] + (f3::typeof(max))(..) = fn_defaults[3] + f4(::Int) = fn_defaults[4] + f5(::Int, (..)::Int) = fn_defaults[5] + f6(::Int, ::Int)::Int = fn_defaults[6] + (f7::typeof(max))(::Int, ::Int)::Int = fn_defaults[7] + (f8::typeof(sum))(::Function, ..) = fn_defaults[8] + f9(..)[1:3] = fn_defaults[9] + (f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3] = fn_defaults[10] +end + +test_all_functions(fns) +test_functions_defaults(fns) + +fns = @variables begin + f1(..) = fn_defaults[1], [foo = 1] + f2(::Any, ..)::Int = fn_defaults[2], [foo = 2;] + (f3::typeof(max))(..) = fn_defaults[3], [foo = 3;] + f4(::Int) = fn_defaults[4], [foo = 4;] + f5(::Int, (..)::Int) = fn_defaults[5], [foo = 5;] + f6(::Int, ::Int)::Int = fn_defaults[6], [foo = 6;] + (f7::typeof(max))(::Int, ::Int)::Int = fn_defaults[7], [foo = 7;] + (f8::typeof(sum))(::Function, ..) = fn_defaults[8], [foo = 8;] + f9(..)[1:3] = fn_defaults[9], [foo = 9;] + (f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3] = fn_defaults[10], [foo = 10;] +end + +test_all_functions(fns) +test_functions_defaults(fns) +test_functions_metadata(fns) + +fns = @test_nowarn @variables begin + f1(..), [foo = 1,] + f2(::Any, ..)::Int, [foo = 2,] + (f3::typeof(max))(..), [foo = 3,] + f4(::Int), [foo = 4,] + f5(::Int, (..)::Int), [foo = 5,] + f6(::Int, ::Int)::Int, [foo = 6,] + (f7::typeof(max))(::Int, ::Int)::Int, [foo = 7,] + (f8::typeof(sum))(::Function, ..), [foo = 8,] + f9(..)[1:3], [foo = 9,] + (f10::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3], [foo = 10,] +end + +test_all_functions(fns) +test_functions_metadata(fns) + +fns = @test_nowarn @variables begin + $(fn_names[1])(..) + $(fn_names[2])(::Any, ..)::Int + ($(fn_names[3])::typeof(max))(..) + $(fn_names[4])(::Int) + $(fn_names[5])(::Int, (..)::Int) + $(fn_names[6])(::Int, ::Int)::Int + ($(fn_names[7])::typeof(max))(::Int, ::Int)::Int + ($(fn_names[8])::typeof(sum))(::Function, ..) + $(fn_names[9])(..)[1:3] + ($(fn_names[10])::typeof(*))(::Matrix{<:Real}, ::Matrix{<:Real})[1:3, 1:3] +end + +test_all_functions(fns) diff --git a/test/semipoly.jl b/test/semipoly.jl index 5a13113ec..364f4c6c3 100644 --- a/test/semipoly.jl +++ b/test/semipoly.jl @@ -426,6 +426,10 @@ end const components = [2, a, b, c, x, y, z, (1+x), (1+y)^2, z*y, z*x] +function verify(t::Symbolics.BasicSymbolic{Number}, d, wrt, nl) + verify(Num(t), d, wrt, nl) +end + function verify(t, d, wrt, nl) try iszero(t - (isempty(d) ? nl : sum(k*v for (k, v) in d) + nl)) @@ -505,3 +509,7 @@ for i=1:20 trial() end end + +@testset "Extracted from fuzz testing" begin + @test verify(2.25(2.0 + 2c)*(c^2), Dict{Any, Any}(c^3 => 4.5, c^2 => 4.5), Num[c, y, z], 0) +end diff --git a/test/solver.jl b/test/solver.jl index 7a39c8924..58e5feb26 100644 --- a/test/solver.jl +++ b/test/solver.jl @@ -57,9 +57,14 @@ end @variables x y z a b c d e @testset "Invalid input" begin - @test_throws AssertionError Symbolics.get_roots(x, x^2) - @test_throws AssertionError Symbolics.get_roots(x^3 + sin(x), x) - @test_throws AssertionError Symbolics.get_roots(1/x, x) + @test_throws AssertionError symbolic_solve(x, x^2) + @test_throws AssertionError symbolic_solve(1/x, x) +end + +@testset "Nice univar cases" begin + found_roots = symbolic_solve(1/x^2 ~ 1/y^2 - 2/x^3 * (x-y), x) + known_roots = Symbolics.unwrap.([y, -2y]) + @test isequal(found_roots, known_roots) end @testset "Deg 1 univar" begin