Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into migrate-to-Expronicon
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenszhu committed Sep 27, 2024
2 parents 79f524d + b3ce057 commit da3259e
Show file tree
Hide file tree
Showing 20 changed files with 399 additions and 55 deletions.
3 changes: 2 additions & 1 deletion .typos.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[default.extend-words]
numer = "numer"
Commun = "Commun"
nd = "nd"
nd = "nd"
assum = "assum"
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Symbolics"
uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7"
authors = ["Shashi Gowda <gowda@mit.edu>"]
version = "6.11.0"
version = "6.13.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -93,7 +93,7 @@ StaticArraysCore = "1.4"
SymPy = "2.2"
SymbolicIndexingInterface = "0.3.14"
SymbolicLimits = "0.2.2"
SymbolicUtils = "2, 3"
SymbolicUtils = "3.7"
TermInterface = "2"
julia = "1.10"

Expand Down
4 changes: 0 additions & 4 deletions ext/SymbolicsGroebnerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,13 +320,9 @@ end
# Helps with precompilation time
PrecompileTools.@setup_workload begin
@variables a b c x y z
equation1 = a*log(x)^b + c ~ 0
equation_actually_polynomial = sin(x^2 +1)^2 + sin(x^2 + 1) + 3
simple_linear_equations = [x - y, y + 2z]
equations_intersect_sphere_line = [x^2 + y^2 + z^2 - 9, x - 2y + 3, y - z]
PrecompileTools.@compile_workload begin
symbolic_solve(equation1, x)
symbolic_solve(equation_actually_polynomial)
symbolic_solve(simple_linear_equations, [x, y], warns=false)
symbolic_solve(equations_intersect_sphere_line, [x, y, z], warns=false)
end
Expand Down
6 changes: 6 additions & 0 deletions ext/SymbolicsNemoExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,13 @@ end
PrecompileTools.@setup_workload begin
@variables a b c x y z
expr_with_params = expand((x + b)*(x^2 + 2x + 1)*(x^2 - a))
equation1 = a*log(x)^b + c ~ 0
equation_polynomial = 9^x + 3^x + 2
exp_eq = 5*2^(x+1) + 7^(x+3)
PrecompileTools.@compile_workload begin
symbolic_solve(equation1, x)
symbolic_solve(equation_polynomial, x)
symbolic_solve(exp_eq)
symbolic_solve(expr_with_params, x, dropmultiplicity=false)
symbolic_solve(x^10 - a^10, x, dropmultiplicity=false)
end
Expand Down
1 change: 1 addition & 0 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ export Inequality, ≲, ≳
include("inequality.jl")

import Bijections, DynamicPolynomials
export tosymbol
include("utils.jl")

using ConstructionBase
Expand Down
10 changes: 9 additions & 1 deletion src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ end
ConstructionBase.constructorof(s::Type{<:ArrayOp{T}}) where {T} = ArrayOp{T}

function SymbolicUtils.maketerm(::Type{<:ArrayOp}, f, args, m)
args = map(args) do arg
if iscall(arg) && operation(arg) == Ref && symbolic_type(only(arguments(arg))) == NotSymbolic()
return Ref(only(arguments(arg)))
else
return arg
end
end

t = f(args...)
t isa Symbolic && !isnothing(m) ?
metadata(t, m) : t
Expand Down Expand Up @@ -968,7 +976,7 @@ end
### Codegen

function SymbolicUtils.Code.toexpr(x::ArrayOp, st)
haskey(st.symbolify, x) && return st.symbolify[x]
haskey(st.rewrites, x) && return st.rewrites[x]

if iscall(x.term)
toexpr(x.term, st)
Expand Down
6 changes: 2 additions & 4 deletions src/solver/attract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,8 @@ function attract_trig(lhs, var)
r_trig = [@acrule(sin(~x::(contains_var))^2 + cos(~x::(contains_var))^2=>one(~x))
@acrule(sin(~x::(contains_var))^2 + -1=>-1 * cos(~x)^2)
@acrule(cos(~x::(contains_var))^2 + -1=>-1 * sin(~x)^2)
@acrule(cos(~x::(contains_var))^2 + -1 * sin(~x::(contains_var))^2=>cos(2 *
~x))
@acrule(sin(~x::(contains_var))^2 + -1 * cos(~x::(contains_var))^2=>-cos(2 *
~x))
@acrule(cos(~x::(contains_var))^2 + -1 * sin(~x::(contains_var))^2=>cos(2*~x))
@acrule(sin(~x::(contains_var))^2 + -1 * cos(~x::(contains_var))^2=>-cos(2*~x))
@acrule(cos(~x::(contains_var)) * sin(~x::(contains_var))=>sin(2 * ~x) / 2)
@acrule(tan(~x::(contains_var))^2 + -1 * sec(~x::(contains_var))^2=>one(~x))
@acrule(-1 * tan(~x::(contains_var))^2 + sec(~x::(contains_var))^2=>one(~x))
Expand Down
2 changes: 1 addition & 1 deletion src/solver/ia_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ function isolate(lhs, var; warns=true, conditions=[])
new_var = (@variables $new_var)[1]
rhs = map(
sol -> term(rev_oper[oper], sol) +
term(*, Base.MathConstants.pi, 2 * new_var),
term(*, Base.MathConstants.pi, new_var),
rhs)
@info string(new_var) * " ϵ" * " Ζ"

Expand Down
22 changes: 16 additions & 6 deletions src/solver/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
for e in expr
for var in x
if !check_poly_inunivar(e, var)
warns && @warn("This system can not be currently solved by solve.")
warns && @warn("This system can not be currently solved by `symbolic_solve`.")
return nothing
end
end
Expand Down Expand Up @@ -276,7 +276,7 @@ function solve_univar(expression, x; dropmultiplicity=true)
end
end

subs, filtered_expr = filter_poly(expression, x)
subs, filtered_expr, assumptions = filter_poly(expression, x, assumptions=true)
coeffs, constant = polynomial_coeffs(filtered_expr, [x])
degree = sdegree(coeffs, x)

Expand All @@ -296,18 +296,28 @@ function solve_univar(expression, x; dropmultiplicity=true)
append!(arr_roots, og_arr_roots)
end
end

return arr_roots
end

if length(factors) != 1
for factor in factors_subbed
roots = solve_univar(factor, x, dropmultiplicity = dropmultiplicity)
for i in eachindex(factors_subbed)
if !any(isequal(x, var) for var in get_variables(factors[i]))
continue
end
roots = solve_univar(factors_subbed[i], x, dropmultiplicity = dropmultiplicity)
append!(arr_roots, roots)
end
end

for i in reverse(eachindex(arr_roots))
for j in eachindex(assumptions)
if isequal(substitute(assumptions[j], Dict(x=>arr_roots[i])), 0)
deleteat!(arr_roots, i)
end
end
end

if isequal(arr_roots, [])
@assert check_polynomial(expression) "This expression could not be solved by `symbolic_solve`."
return [RootsOf(wrap(expression), wrap(x))]
end

Expand Down
79 changes: 71 additions & 8 deletions src/solver/postprocess.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# Alex: make sure `Num`s are not processed here as they'd break it.
_postprocess_root(x) = x

Expand Down Expand Up @@ -32,30 +31,30 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
!iscall(x) && return x

x = Symbolics.term(operation(x), map(_postprocess_root, arguments(x))...)
oper = operation(x)

# sqrt(0), cbrt(0) => 0
# sqrt(1), cbrt(1) => 1
if iscall(x) &&
(operation(x) === sqrt || operation(x) === cbrt || operation(x) === ssqrt ||
operation(x) === scbrt)
if (oper === sqrt || oper === cbrt || oper === ssqrt ||
oper === scbrt)
arg = arguments(x)[1]
if isequal(arg, 0) || isequal(arg, 1)
return arg
end
end

# (X)^0 => 1
if iscall(x) && operation(x) === (^) && isequal(arguments(x)[2], 0)
if oper === (^) && isequal(arguments(x)[2], 0)
return 1
end

# (X)^1 => X
if iscall(x) && operation(x) === (^) && isequal(arguments(x)[2], 1)
if oper === (^) && isequal(arguments(x)[2], 1)
return arguments(x)[1]
end

# sqrt((N / D)^2 * M) => N / D * sqrt(M)
if iscall(x) && (operation(x) === sqrt || operation(x) === ssqrt)
if (oper === sqrt || oper === ssqrt)
function squarefree_decomp(x::Integer)
square, squarefree = big(1), big(1)
for (p, d) in collect(Primes.factor(abs(x)))
Expand Down Expand Up @@ -90,7 +89,7 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
end

# (sqrt(N))^M => N^div(M, 2)*sqrt(N)^(mod(M, 2))
if iscall(x) && operation(x) === (^)
if oper === (^)
arg1, arg2 = arguments(x)
if iscall(arg1) && (operation(arg1) === sqrt || operation(arg1) === ssqrt)
if arg2 isa Integer
Expand All @@ -105,6 +104,19 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
end
end

x = convert_consts(x)

if oper === (+)
args = arguments(x)
for arg in args
if isequal(arg, 0)
after_removing = setdiff(args, arg)
isone(length(after_removing)) && return after_removing[1]
return Symbolics.term(+, after_removing)
end
end
end

return x
end

Expand All @@ -122,3 +134,54 @@ function postprocess_root(x)
end
x # unreachable
end


inv_exacts = [0, Symbolics.term(*, pi),
Symbolics.term(/,pi,3),
Symbolics.term(/, pi, 2),
Symbolics.term(/, Symbolics.term(*, 2, pi), 3),
Symbolics.term(/, pi, 6),
Symbolics.term(/, Symbolics.term(*, 5, pi), 6),
Symbolics.term(/, pi, 4)
]
inv_evald = Symbolics.symbolic_to_float.(inv_exacts)

const inv_pairs = collect(zip(inv_exacts, inv_evald))
"""
function convert_consts(x)
This function takes BasicSymbolic terms as input (x) and attempts
to simplify these basic symbolic terms using known values.
Currently, this function only supports inverse trigonometric functions.
## Examples
```jldoctest
julia> Symbolics.convert_consts(Symbolics.term(acos, 0))
π / 2
julia> Symbolics.convert_consts(Symbolics.term(atan, 0))
0
julia> Symbolics.convert_consts(Symbolics.term(atan, 1))
π / 4
```
"""
function convert_consts(x)
!iscall(x) && return x

oper = operation(x)
inv_opers = [asin, acos, atan]

if any(isequal(oper, o) for o in inv_opers) && isempty(Symbolics.get_variables(x))
val = Symbolics.symbolic_to_float(x)
for (exact, evald) in inv_pairs
if isapprox(evald, val)
return exact
elseif isapprox(-evald, val)
return -exact
end
end
end

# add [sin, cos, tan] simplifications in the future?
return x
end
19 changes: 13 additions & 6 deletions src/solver/preprocess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,19 @@ function clean_f(filtered_expr, var, subs)
unwrapped_f = unwrap(filtered_expr)
!iscall(unwrapped_f) && return filtered_expr
oper = operation(unwrapped_f)
assumptions = []

if oper === (/)
args = arguments(unwrapped_f)
if any(isequal(var, x) for x in get_variables(args[2]))
return filtered_expr
filtered_expr = expand(args[1] * args[2])
push!(assumptions, substitute(args[2], subs, fold=false))
return filtered_expr, assumptions
end
filtered_expr = args[1]
@info "Assuming $(substitute(args[2], subs, fold=false) != 0)"
end
return filtered_expr
return filtered_expr, assumptions
end

"""
Expand Down Expand Up @@ -238,15 +241,17 @@ julia> filter_poly((x+1)*term(log, 3), x)
(Dict{Any, Any}(var"##247" => log(3)), var"##247"*(1 + x))
```
"""
function filter_poly(og_expr, var)
function filter_poly(og_expr, var; assumptions=false)
expr = deepcopy(og_expr)
expr = unwrap(expr)
vars = get_variables(expr)

# handle edge cases
if !isequal(vars, []) && isequal(vars[1], expr)
assumptions && return Dict{Any, Any}(), expr, []
return (Dict{Any, Any}(), expr)
elseif isequal(vars, [])
assumptions && return filter_stuff(expr), []
return filter_stuff(expr)
end

Expand All @@ -256,14 +261,16 @@ function filter_poly(og_expr, var)
# reassemble expr to avoid variables remembering original values issue and clean
args = arguments(expr)
oper = operation(expr)
new_expr = clean_f(term(oper, args...), var, subs)
new_expr, assum_array = clean_f(term(oper, args...), var, subs)

assumptions && return subs, new_expr, assum_array
return subs, new_expr
end
function filter_poly(og_expr)

function filter_poly(og_expr; assumptions=false)
new_var = gensym()
new_var = (@variables $(new_var))[1]
return filter_poly(og_expr, new_var)
return filter_poly(og_expr, new_var; assumptions=assumptions)
end


Expand Down
2 changes: 1 addition & 1 deletion src/solver/solve_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function check_expr_validity(expr)
valid_type = false

if type_expr <: Number || type_expr == Num || type_expr == SymbolicUtils.BasicSymbolic{Real} ||
type_expr == Complex{Num} || type_expr == ComplexTerm{Real}
type_expr == Complex{Num} || type_expr == ComplexTerm{Real} || type_expr == SymbolicUtils.BasicSymbolic{Complex{Real}}
valid_type = true
end
iscall(unwrap(expr)) && @assert !hasderiv(unwrap(expr)) "Differential equations are not currently supported"
Expand Down
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi
string(nameof(arguments(oldop)[1]))
elseif oldop == getindex
args = arguments(O)
opname = string(tosymbol(args[1]), "[", map(tosymbol, args[2:end])..., "]")
return _Sym(symtype(O), Symbol(opname, d_separator, ds))
opname = string(tosymbol(args[1]))
return metadata(_Sym(symtype(args[1]), Symbol(opname, d_separator, ds)), metadata(args[1]))[args[2:end]...]
elseif oldop isa Function
return nothing
else
Expand Down
Loading

0 comments on commit da3259e

Please sign in to comment.