Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jverzani committed Sep 13, 2023
1 parent b177158 commit c9a219c
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[targets]
test = ["Test"]
test = ["SymbolicUtils", "Test"]
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Test
using Serialization
import Base: MathConstants.γ, MathConstants.e, MathConstants.φ, MathConstants.catalan

VERSION >= v"1.9" && include("test-SymbolicUtils.jl")
include("test-dense-matrix.jl")

x = symbols("x")
Expand Down
54 changes: 54 additions & 0 deletions test/test-SymbolicUtils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using Test
using SymEngine
import SymbolicUtils: simplify, @rule, @acrule, Chain, Fixpoint


@testset "SymbolicUtils" begin
# from SymbolicUtils.jl docs
# https://symbolicutils.juliasymbolics.org/rewrite/#rule-based_rewriting
@vars w x y z
@vars α β
@vars a b c d

@test simplify(cos(x)^2 + sin(x)^2) == 1

r1 = @rule sin(2(~x)) => 2sin(~x)*cos(~x)
@test r1(sin(2z)) == 2*cos(z)*sin(z)
@test r1(sin(3z)) === nothing
@test r1(sin(2*(w-z))) == 2cos(w - z)*sin(w - z)
@test r1(sin(2*(w+z)*+β))) === nothing

r2 = @rule sin(~x + ~y) => sin(~x)*cos(~y) + cos(~x)*sin(~y);
@test r2(sin+β)) == sin(α)*cos(β) + cos(α)*sin(β)

xs = @rule(+(~~xs) => ~~xs)(x + y + z) # segment variable
@test Set(xs) == Set([x,y,z])

r3 = @rule ~x * +(~~ys) => sum(map(y-> ~x * y, ~~ys));
@test r3(2 * (w+w+α+β)) == 4w + 2α + 2β

r4 = @rule ~x + ~~y::(ys->iseven(length(ys))) => "odd terms"; # Predicates for matching

@test r4(a + b + c + d) == nothing
@test r4(b + c + d) == "odd terms"
@test r4(b + c + b) == nothing
@test r4(a + b) == nothing

sqexpand = @rule (~x + ~y)^2 => (~x)^2 + (~y)^2 + 2 * ~x * ~y
@test sqexpand((cos(x) + sin(x))^2) == cos(x)^2 + sin(x)^2 + 2cos(x)*sin(x)

pyid = @rule sin(~x)^2 + cos(~x)^2 => 1
@test_broken pyid(cos(x)^2 + sin(x)^2) === nothing # order should matter, but this works

acpyid = @acrule sin(~x)^2 + cos(~x)^2 => 1 # acrule is commutative
@test acpyid(cos(x)^2 + sin(x)^2 + 2cos(x)*sin(x)) == 1 + 2cos(x)*sin(x)

csa = Chain([sqexpand, acpyid]) # chain composes rules
@test csa((cos(x) + sin(x))^2) == 1 + 2cos(x)*sin(x)

cas = Chain([acpyid, sqexpand]) # order matters
@test cas((cos(x) + sin(x))^2) == cos(x)^2 + sin(x)^2 + 2cos(x)*sin(x)

@test Fixpoint(cas)((cos(x) + sin(x))^2) == 1 + 2cos(x)*sin(x)

end

0 comments on commit c9a219c

Please sign in to comment.