Skip to content

Commit

Permalink
Merge pull request #268 from jverzani/symbolic-utils-extension_fix
Browse files Browse the repository at this point in the history
Symbolic utils extension fix
  • Loading branch information
isuruf authored Oct 7, 2023
2 parents 71c1e71 + a34c325 commit 90e42f3
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 26 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Compat = "0.63.0, 1, 2, 3, 4"
RecipesBase = "0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 1.0"
SpecialFunctions = "0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1, 2"
SymEngine_jll = "0.9, 0.10"
SymbolicUtils = "1.4"
julia = "1.6"

[extras]
Expand Down
3 changes: 3 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8"

[compat]
Documenter = "1"
8 changes: 4 additions & 4 deletions docs/src/basicUsage.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Vectors can be defined through list comprehension and string interpolation.
julia> using SymEngine
julia> [symbols("α_$i") for i in 1:3]
3-element Array{Basic,1}:
3-element Vector{Basic}:
α_1
α_2
α_3
Expand All @@ -56,7 +56,7 @@ In an analogous manner, matrices are declared with a combination of string inter
julia> using SymEngine
julia> W = [symbols("W_$i$j") for i in 1:3, j in 1:4]
3×4 Array{Basic,2}:
3×4 Matrix{Basic}:
W_11 W_12 W_13 W_14
W_21 W_22 W_23 W_24
W_31 W_32 W_33 W_34
Expand All @@ -70,13 +70,13 @@ Consider the canonical example of **matrix vector multiplication**.
julia> using SymEngine
julia> W = [symbols("W_$i$j") for i in 1:3, j in 1:4]
3×4 Array{Basic,2}:
3×4 Matrix{Basic}:
W_11 W_12 W_13 W_14
W_21 W_22 W_23 W_24
W_31 W_32 W_33 W_34
julia> W*[1.0, 2.0, 3.0, 4.0]
3-element Array{Basic,1}:
3-element Vector{Basic}:
1.0*W_11 + 2.0*W_12 + 3.0*W_13 + 4.0*W_14
1.0*W_21 + 2.0*W_22 + 3.0*W_23 + 4.0*W_24
1.0*W_31 + 2.0*W_32 + 3.0*W_33 + 4.0*W_34
Expand Down
32 changes: 10 additions & 22 deletions ext/SymEngineSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Check if x represents an expression tree. If returns true, it will be assumed th
function SymbolicUtils.istree(x::SymEngine.SymbolicType)
cls = SymEngine.get_symengine_class(x)
cls == :Symbol && return false
cls == :Constant && return false
any(==(cls), SymEngine.number_types) && return false
return true
end
Expand Down Expand Up @@ -72,32 +73,19 @@ end

# Needed for some simplification routines
# a total order <ₑ
import SymbolicUtils: <ₑ, isterm, isadd, ismul, issym, cmp_mul_adds, cmp_term_term
import SymbolicUtils: <ₑ, isterm, isadd, ismul, issym, get_degrees, monomial_lt, _arglen
function SymbolicUtils.:<(a::SymEngine.Basic, b::SymEngine.Basic)
if isterm(a) && !isterm(b)
return false
elseif isterm(b) && !isterm(a)
return true
elseif (isadd(a) || ismul(a)) && (isadd(b) || ismul(b))
return cmp_mul_adds(a, b)
elseif issym(a) && issym(b)
nameof(a) < nameof(b)
elseif !istree(a) && !istree(b)
T = typeof(a)
S = typeof(b)
if T == S
is_number(a) && is_number(b) && return N(a) < N(b)
return hash(a) < hash(b)
da, db = get_degrees(a), get_degrees(b)
fw = monomial_lt(da, db)
bw = monomial_lt(db, da)
if fw === bw && !isequal(a, b)
if _arglen(a) == _arglen(b)
return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,)
else
return name(T) < nameof(S)
return _arglen(a) < _arglen(b)
end
#return T===S ? (T <: Number ? isless(a, b) : hash(a) < hash(b)) : nameof(T) < nameof(S)
elseif istree(b) && !istree(a)
return true
elseif istree(a) && istree(b)
return cmp_term_term(a,b)
else
return !(b <ₑ a)
return fw
end
end

Expand Down

0 comments on commit 90e42f3

Please sign in to comment.