Skip to content

Commit

Permalink
Merge branch 'master' into new_docs
Browse files Browse the repository at this point in the history
  • Loading branch information
n0rbed authored Sep 8, 2024
2 parents a40fa9b + 81d894b commit f827c89
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 115 deletions.
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Symbolics"
uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7"
authors = ["Shashi Gowda <gowda@mit.edu>"]
version = "6.6.0"
version = "6.11.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -44,15 +44,15 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Nemo = "2edaba10-b0f1-5616-af89-8c11ac63239a"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"

[extensions]
SymbolicsForwardDiffExt = "ForwardDiff"
SymbolicsGroebnerExt = "Groebner"
SymbolicsLuxCoreExt = "LuxCore"
SymbolicsLuxExt = "Lux"
SymbolicsNemoExt = "Nemo"
SymbolicsPreallocationToolsExt = ["PreallocationTools", "ForwardDiff"]
SymbolicsSymPyExt = "SymPy"
Expand All @@ -76,7 +76,7 @@ LaTeXStrings = "1.3"
LambertW = "0.4.5"
Latexify = "0.16"
LogExpFunctions = "0.3"
LuxCore = "0.1.11"
Lux = "1"
MacroTools = "0.5"
NaNMath = "1"
Nemo = "0.45, 0.46"
Expand Down
22 changes: 18 additions & 4 deletions ext/SymbolicsGroebnerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ end
# Given a GB in k[params][vars] produces a GB in k(params)[vars]
function demote(gb, vars::Vector{Num}, params::Vector{Num})
isequal(gb, [1]) && return gb

gb = Symbolics.wrap.(SymbolicUtils.toterm.(gb))
Symbolics.check_polynomial.(gb)

Expand All @@ -126,7 +127,7 @@ function demote(gb, vars::Vector{Num}, params::Vector{Num})
ring_param, params_demoted = Nemo.polynomial_ring(Nemo.base_ring(ring_flat), map(string, nemo_params))
ring_demoted, vars_demoted = Nemo.polynomial_ring(Nemo.fraction_field(ring_param), map(string, nemo_vars), internal_ordering=:lex)
varmap = Dict((nemo_vars .=> vars_demoted)..., (nemo_params .=> params_demoted)...)
gb_demoted = map(f -> nemo_crude_evaluate(f, varmap), nemo_gb)
gb_demoted = map(f -> ring_demoted(nemo_crude_evaluate(f, varmap)), nemo_gb)
result = empty(gb_demoted)
while true
gb_demoted = map(f -> Nemo.map_coefficients(c -> c // Nemo.leading_coefficient(f), f), gb_demoted)
Expand Down Expand Up @@ -176,6 +177,7 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa
# Use a new variable to separate the input polynomials (Reference above)
new_var = gen_separating_var(vars)
old_len = length(vars)
old_vars = deepcopy(vars)
vars = vcat(vars, new_var)

new_eqs = []
Expand Down Expand Up @@ -204,6 +206,13 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa
return []
end

for i in reverse(eachindex(new_eqs))
all_present = Symbolics.get_variables(new_eqs[i])
if length(intersect(all_present, vars)) < 1
deleteat!(new_eqs, i)
end
end

new_eqs = demote(new_eqs, vars, params)
new_eqs = map(Symbolics.unwrap, new_eqs)

Expand Down Expand Up @@ -233,7 +242,10 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa
end

# non-cyclic case
n_iterations > 10 && return []
if n_iterations > 10
warns && @warn("symbolic_solve can not currently solve this system of polynomials.")
return nothing
end

n_iterations += 1
end
Expand Down Expand Up @@ -295,11 +307,13 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici
isempty(tr_basis) && return nothing
vars_gen = setdiff(vars, tr_basis)
sol = solve_zerodim(eqs, vars_gen; dropmultiplicity=dropmultiplicity, warns=warns)

for roots in sol
for x in tr_basis
roots[x] = x
end
end

sol
end

Expand All @@ -313,8 +327,8 @@ PrecompileTools.@setup_workload begin
PrecompileTools.@compile_workload begin
symbolic_solve(equation1, x)
symbolic_solve(equation_actually_polynomial)
symbolic_solve(simple_linear_equations, [x, y])
symbolic_solve(equations_intersect_sphere_line, [x, y, z])
symbolic_solve(simple_linear_equations, [x, y], warns=false)
symbolic_solve(equations_intersect_sphere_line, [x, y, z], warns=false)
end
end

Expand Down
11 changes: 0 additions & 11 deletions ext/SymbolicsLuxCoreExt.jl

This file was deleted.

18 changes: 18 additions & 0 deletions ext/SymbolicsLuxExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module SymbolicsLuxExt

using Lux
using Symbolics
using Lux.LuxCore
using Symbolics.SymbolicUtils

function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic{<:Vector{<:Real}})
Lux.NilSizePropagation.recursively_nillify(Symbolics.wrap(x))
end

@register_array_symbolic LuxCore.stateless_apply(
model::LuxCore.AbstractLuxLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin
size = LuxCore.outputsize(model, x, LuxCore.Random.default_rng())
eltype = Real
end

end
19 changes: 0 additions & 19 deletions ext/SymbolicsNemoExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,32 +57,13 @@ function Symbolics.factor_use_nemo(poly::Num)
return sym_unit, sym_factors
end

# gcd(x^2 - y^2, x^3 - y^3) -> x - y
function Symbolics.gcd_use_nemo(poly1::Num, poly2::Num)
Symbolics.check_polynomial(poly1)
Symbolics.check_polynomial(poly2)
vars1 = Symbolics.get_variables(poly1)
vars2 = Symbolics.get_variables(poly2)
vars = vcat(vars1, vars2)
nemo_ring, nemo_vars = Nemo.polynomial_ring(Nemo.QQ, map(string, vars))
sym_to_nemo = Dict(vars .=> nemo_vars)
nemo_to_sym = Dict(v => k for (k, v) in sym_to_nemo)
nemo_poly1 = Symbolics.substitute(poly1, sym_to_nemo)
nemo_poly2 = Symbolics.substitute(poly2, sym_to_nemo)
nemo_gcd = Nemo.gcd(nemo_poly1, nemo_poly2)
sym_gcd = Symbolics.wrap(nemo_crude_evaluate(nemo_gcd, nemo_to_sym))
return sym_gcd
end


# Helps with precompilation time
PrecompileTools.@setup_workload begin
@variables a b c x y z
expr_with_params = expand((x + b)*(x^2 + 2x + 1)*(x^2 - a))
PrecompileTools.@compile_workload begin
symbolic_solve(expr_with_params, x, dropmultiplicity=false)
symbolic_solve(x^10 - a^10, x, dropmultiplicity=false)
symbolic_solve([x^2 - a^2, x + a], x)
end
end

Expand Down
2 changes: 2 additions & 0 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ SymbolicUtils.sorted_arguments(s::ArrayOp) = sorted_arguments(s.term)

shape(aop::ArrayOp) = aop.shape

SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.ArrayOp}) = ArraySymbolic()

const show_arrayop = Ref{Bool}(false)
function Base.show(io::IO, aop::ArrayOp)
if iscall(aop.term) && !show_arrayop[]
Expand Down
10 changes: 10 additions & 0 deletions src/extra_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ end
@register_symbolic Base.rand(x)
@register_symbolic Base.randn(x)

@register_symbolic Base.clamp(x, y, z)

function derivative(::typeof(Base.clamp), args::NTuple{3, Any}, ::Val{1})
x, l, h = args
T = promote_type(symtype(x), symtype(l), symtype(h))
z = zero(T)
o = one(T)
ifelse(x<l, z, ifelse(x>h, z, o))
end

@register_symbolic Distributions.pdf(dist,x)
@register_symbolic Distributions.logpdf(dist,x)
@register_symbolic Distributions.cdf(dist,x)
Expand Down
78 changes: 18 additions & 60 deletions src/solver/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,6 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
for var in x
check_x(var)
end
if length(x) == 1
x = x[1]
x_univar = true
end
end

if !(expr isa Vector)
Expand All @@ -181,31 +177,21 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
expr = [expr]
expr_univar = false
end
if !expr_univar && x_univar
x = [x]
x_univar = false
end

if x_univar
sols = []
if expr_univar
sols = check_poly_inunivar(expr, x) ?
solve_univar(expr, x, dropmultiplicity = dropmultiplicity) :
ia_solve(expr, x, warns = warns)
isequal(sols, nothing) && return nothing
else
for i in eachindex(expr)
if !check_poly_inunivar(expr[i], x)
warns && @warn("Solve can not solve this input currently")
return nothing
end
end
sols = solve_multipoly(
expr, x, dropmultiplicity = dropmultiplicity, warns = warns)
isequal(sols, nothing) && return nothing
end

sols = check_poly_inunivar(expr, x) ?
solve_univar(expr, x, dropmultiplicity = dropmultiplicity) :
ia_solve(expr, x, warns = warns)
isequal(sols, nothing) && return nothing
sols = map(postprocess_root, sols)
return sols
end

if !expr_univar && !x_univar
if !x_univar
for e in expr
for var in x
if !check_poly_inunivar(e, var)
Expand All @@ -215,11 +201,15 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
end
end

sols = solve_multivar(expr, x, dropmultiplicity = dropmultiplicity)
sols = solve_multivar(expr, x, dropmultiplicity=dropmultiplicity, warns=warns)
isequal(sols, nothing) && return nothing
for sol in sols
sols = convert(Vector{Any}, sols)
for i in eachindex(sols)
for var in x
sol[var] = postprocess_root(sol[var])
sols[i][var] = postprocess_root(sols[i][var])
end
if length(collect(keys(sols[i]))) == 1
sols[i] = collect(values(sols[i]))[1]
end
end

Expand All @@ -243,6 +233,7 @@ function symbolic_solve(expr; x...)
vars = wrap.(vars)
@assert all(v isa Num for v in vars) "All variables should be Nums or BasicSymbolics"

vars = isone(length(vars)) ? vars[1] : vars
return symbolic_solve(expr, vars; x...)
end

Expand All @@ -268,7 +259,7 @@ implemented in the function `get_roots` and its children.
# Examples
"""
function solve_univar(expression, x; dropmultiplicity = true)
function solve_univar(expression, x; dropmultiplicity=true)
args = []
mult_n = 1
expression = unwrap(expression)
Expand Down Expand Up @@ -323,39 +314,6 @@ function solve_univar(expression, x; dropmultiplicity = true)
return arr_roots
end

# You can compute the GCD between a system of polynomials by doing the following:
# Get the GCD between the first two polys,
# and get the GCD between this result and the following index,
# say: solve([x^2 - 1, x - 1, (x-1)^20], x)
# the GCD between the first two terms is obviously x-1,
# now we call gcd_use_nemo() on this term, and the following,
# gcd_use_nemo(x - 1, (x-1)^20), which is again x-1.
# now we just need to solve(x-1, x) to get the common root in this
# system of equations.
function solve_multipoly(polys::Vector, x::Num; dropmultiplicity = true, warns = true)
polys = unique(polys)

if length(polys) < 1
warns && @warn("No expressions entered")
return nothing
end
if length(polys) == 1
return solve_univar(polys[1], x, dropmultiplicity = dropmultiplicity)
end

gcd = gcd_use_nemo(polys[1], polys[2])

for i in eachindex(polys)[3:end]
gcd = gcd_use_nemo(gcd, polys[i])
end

if isequal(gcd, 1)
return []
end

return solve_univar(gcd, x, dropmultiplicity = dropmultiplicity)
end

function solve_multivar(eqs::Any, vars::Any; dropmultiplicity = true, warns = true)
throw("Groebner bases engine is required. Execute `using Groebner` to enable this functionality.")
end
4 changes: 0 additions & 4 deletions src/solver/nemo_stuff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,3 @@ function factor_use_nemo(poly::Any)
throw("Nemo is required. Execute `using Nemo` to enable this functionality.")
end

# gcd(x^2 - y^2, x^3 - y^3) -> x - y
function gcd_use_nemo(poly1::Any, poly2::Any)
throw("Nemo is required. Execute `using Nemo` to enable this functionality.")
end
2 changes: 2 additions & 0 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ SymbolicUtils.Code.toexpr(x::CallWithMetadata, st) = SymbolicUtils.Code.toexpr(x

CallWithMetadata(f) = CallWithMetadata(f, nothing)

SymbolicIndexingInterface.symbolic_type(::Type{<:CallWithMetadata}) = ScalarSymbolic()

function Base.show(io::IO, c::CallWithMetadata)
show(io, c.f)
print(io, "")
Expand Down
3 changes: 3 additions & 0 deletions test/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ x = Num.(randn(10))
@test norm(x, 1) == norm(Symbolics.value.(x), 1)
@test norm(x, 1.2) == norm(Symbolics.value.(x), 1.2)

@test clamp.(x, 0, 1) == clamp.(Symbolics.value.(x), 0, 1)
@test isequal(Symbolics.derivative(clamp(a, 0, 1), a), ifelse(a < 0, 0, ifelse(a>1, 0, 1)))

@variables x[1:2]
@test isequal(scalarize(norm(x)), sqrt(abs2(x[1]) + abs2(x[2])))
@test isequal(scalarize(norm(x, Inf)), max(abs(x[1]), abs(x[2])))
Expand Down
Loading

0 comments on commit f827c89

Please sign in to comment.