diff --git a/Project.toml b/Project.toml index 1f50aa38e..3b20c2408 100644 --- a/Project.toml +++ b/Project.toml @@ -50,8 +50,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["BenchmarkTools", "Documenter", "Pkg", "PkgBenchmark", "Random", "Test", "Zygote"] +test = ["BenchmarkTools", "Documenter", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "Test", "Zygote"] diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 2655263b2..1fb7d777a 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -27,6 +27,8 @@ include("methods.jl") # LinkedList, simplification utilities include("utils.jl") +# Tree inspection +include("inspect.jl") export Rewriters # A library for composing together expr -> expr functions diff --git a/src/inspect.jl b/src/inspect.jl new file mode 100644 index 000000000..f62551893 --- /dev/null +++ b/src/inspect.jl @@ -0,0 +1,72 @@ +import AbstractTrees + +const inspect_metadata = Ref{Bool}(false) +function AbstractTrees.nodevalue(x::Symbolic) + istree(x) ? operation(x) : x +end + +function AbstractTrees.nodevalue(x::BasicSymbolic) + str = if !istree(x) + string(exprtype(x), "(", x, ")") + elseif isadd(x) + string(exprtype(x), + (scalar=x.coeff, coeffs=Tuple(k=>v for (k,v) in x.dict))) + elseif ismul(x) + string(exprtype(x), + (scalar=x.coeff, powers=Tuple(k=>v for (k,v) in x.dict))) + elseif isdiv(x) || ispow(x) + string(exprtype(x)) + else + string(exprtype(x),"{", operation(x), "}") + end + + if inspect_metadata[] && !isnothing(metadata(x)) + str *= string(" metadata=", Tuple(k=>v for (k, v) in metadata(x))) + end + Text(str) +end + +function AbstractTrees.children(x::Symbolic) + istree(x) ? arguments(x) : () +end + +""" + inspect([io::IO=stdout], expr; hint=true, metadata=false) + +Inspect an expression tree `expr`. Uses AbstractTrees to print out an expression. + +BasicSymbolic expressions will print the Unityper type (ADD, MUL, DIV, POW, SYM, TERM) and the relevant internals as the head, and the children in the subsequent lines as accessed by `arguments`. Other types will get printed as subtrees. Set `metadata=true` to print any metadata carried by the nodes. + +Line numbers will be shown, use `pluck(expr, line_number)` to get the sub expression or leafnode starting at line_number. +""" +function inspect end + +function inspect(io::IO, x::Symbolic; + hint=true, + metadata=inspect_metadata[]) + + prev_state = inspect_metadata[] + inspect_metadata[] = metadata + lines = readlines(IOBuffer(sprint(io->AbstractTrees.print_tree(io, x)))) + inspect_metadata[] = prev_state + digits = ceil(Int, log10(length(lines))) + line_numbers = lpad.(string.(1:length(lines)), digits) + print(io, join(string.(line_numbers, " ", lines), "\n")) + hint && print(io, "\n\nHint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number") +end + +function inspect(x; hint=true, metadata=inspect_metadata[]) + inspect(stdout, x; hint=hint, metadata=metadata) +end + +inspect(io::IO, x; kw...) = println(io, "Not Symbolic: $x") + +""" + pluck(expr, n) + +Pluck the `n`th subexpression from `expr` as given by pre-order DFS. +This is the same as the node numbering in `inspect`. +""" +function pluck(x, item) + collect(Iterators.take(AbstractTrees.PreOrderDFS(x), item))[end] +end diff --git a/src/types.jl b/src/types.jl index e7f3e3ccd..81c161d12 100644 --- a/src/types.jl +++ b/src/types.jl @@ -559,54 +559,6 @@ function basic_similarterm(t, f, args, stype; metadata=nothing) end end -### -### Tree print -### - -import AbstractTrees - -struct TreePrint - op - x -end - -function AbstractTrees.children(x::BasicSymbolic) - if isterm(x) || ispow(x) - return arguments(x) - elseif isadd(x) || ismul(x) - children = Any[x.coeff] - for (key, coeff) in pairs(x.dict) - if coeff == 1 - push!(children, key) - else - push!(children, TreePrint(isadd(x) ? (:*) : (:^), (key, coeff))) - end - end - return children - end -end - -AbstractTrees.children(x::TreePrint) = [x.x[1], x.x[2]] - -print_tree(x; show_type=false, maxdepth=Inf, kw...) = print_tree(stdout, x; show_type=show_type, maxdepth=maxdepth, kw...) - -function print_tree(_io::IO, x::BasicSymbolic; show_type=false, kw...) - if isterm(x) || isadd(x) || ismul(x) || ispow(x) || isdiv(x) - AbstractTrees.print_tree(_io, x; withinds=true, kw...) do io, y, inds - if istree(y) - print(io, operation(y)) - elseif y isa TreePrint - print(io, "(", y.op, ")") - else - print(io, y) - end - if !(y isa TreePrint) && show_type - print(io, " [", typeof(y), "]") - end - end - end -end - ### ### Metadata ### diff --git a/test/basics.jl b/test/basics.jl index ad11ee90c..bd943396a 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -188,6 +188,18 @@ end @test repr((-1)^a) == "(-1)^a" end +@testset "inspect" begin + @syms x y z + y = SymbolicUtils.setmetadata(y, Integer, 42) # Set some metadata + ex = z*(2x + 3y + 1)^2/(z+2x) + @test_reference "inspect_output/ex.txt" sprint(io->SymbolicUtils.inspect(io, ex)) + @test_reference "inspect_output/ex-md.txt" sprint(io->SymbolicUtils.inspect(io, ex, metadata=true)) + @test_reference "inspect_output/ex-nohint.txt" sprint(io->SymbolicUtils.inspect(io, ex, hint=false)) + @test SymbolicUtils.pluck(ex, 8) == 2 + @test_reference "inspect_output/sub10.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 10))) + @test_reference "inspect_output/sub14.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 14))) +end + @testset "similarterm" begin @syms a b c @test isequal(SymbolicUtils.similarterm((b + c), +, [a, (b+c)]).dict, Dict(a=>1,b=>1,c=>1)) diff --git a/test/fuzz.jl b/test/fuzz.jl index 8303fd165..76edefca1 100644 --- a/test/fuzz.jl +++ b/test/fuzz.jl @@ -2,8 +2,7 @@ include("fuzzlib.jl") using Random: seed! -seed!(6174) -@testset "Fuzz test" begin +seed!(8258) @time @testset "expand fuzz" begin for i=1:500 i % 100 == 0 && @info "expand fuzz" iter=i @@ -45,4 +44,3 @@ seed!(6174) fuzz_addmulpow(4) end end -end diff --git a/test/fuzzlib.jl b/test/fuzzlib.jl index 80133ad78..adac553cf 100644 --- a/test/fuzzlib.jl +++ b/test/fuzzlib.jl @@ -43,7 +43,7 @@ const num_spec = let ()->rand([a b c d e f])] binops = SymbolicUtils.diadic - nopow = filter(x->x!==(^), binops) + nopow = setdiff(binops, [(^), besselj0, besselj1, bessely0, bessely1, besselj, bessely, besseli, besselk]) twoargfns = vcat(nopow, (x,y)->x isa Union{Int, Rational, Complex{<:Rational}} ? x * y : x^y) fns = vcat(1 .=> vcat(SymbolicUtils.monadic, [one, zero]), 2 .=> vcat(twoargfns, fill(+, 5), [-,-], fill(*, 5), fill(/, 40)), diff --git a/test/inspect_output/ex-md.txt b/test/inspect_output/ex-md.txt new file mode 100644 index 000000000..4d7a330f3 --- /dev/null +++ b/test/inspect_output/ex-md.txt @@ -0,0 +1,20 @@ + 1 DIV + 2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2)) + 3 │ ├─ SYM(z) + 4 │ └─ POW + 5 │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2)) + 6 │ │ ├─ 1 + 7 │ │ ├─ MUL(scalar = 2, powers = (x => 1,)) + 8 │ │ │ ├─ 2 + 9 │ │ │ └─ SYM(x) +10 │ │ └─ MUL(scalar = 3, powers = (y => 1,)) +11 │ │ ├─ 3 +12 │ │ └─ SYM(y) metadata=(Integer => 42,) +13 │ └─ 2 +14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2)) +15 ├─ SYM(z) +16 └─ MUL(scalar = 2, powers = (x => 1,)) +17 ├─ 2 +18 └─ SYM(x) + +Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number \ No newline at end of file diff --git a/test/inspect_output/ex-nohint.txt b/test/inspect_output/ex-nohint.txt new file mode 100644 index 000000000..43da94e62 --- /dev/null +++ b/test/inspect_output/ex-nohint.txt @@ -0,0 +1,18 @@ + 1 DIV + 2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2)) + 3 │ ├─ SYM(z) + 4 │ └─ POW + 5 │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2)) + 6 │ │ ├─ 1 + 7 │ │ ├─ MUL(scalar = 2, powers = (x => 1,)) + 8 │ │ │ ├─ 2 + 9 │ │ │ └─ SYM(x) +10 │ │ └─ MUL(scalar = 3, powers = (y => 1,)) +11 │ │ ├─ 3 +12 │ │ └─ SYM(y) +13 │ └─ 2 +14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2)) +15 ├─ SYM(z) +16 └─ MUL(scalar = 2, powers = (x => 1,)) +17 ├─ 2 +18 └─ SYM(x) \ No newline at end of file diff --git a/test/inspect_output/ex.txt b/test/inspect_output/ex.txt new file mode 100644 index 000000000..9a5c5c4fa --- /dev/null +++ b/test/inspect_output/ex.txt @@ -0,0 +1,20 @@ + 1 DIV + 2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2)) + 3 │ ├─ SYM(z) + 4 │ └─ POW + 5 │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2)) + 6 │ │ ├─ 1 + 7 │ │ ├─ MUL(scalar = 2, powers = (x => 1,)) + 8 │ │ │ ├─ 2 + 9 │ │ │ └─ SYM(x) +10 │ │ └─ MUL(scalar = 3, powers = (y => 1,)) +11 │ │ ├─ 3 +12 │ │ └─ SYM(y) +13 │ └─ 2 +14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2)) +15 ├─ SYM(z) +16 └─ MUL(scalar = 2, powers = (x => 1,)) +17 ├─ 2 +18 └─ SYM(x) + +Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number \ No newline at end of file diff --git a/test/inspect_output/sub10.txt b/test/inspect_output/sub10.txt new file mode 100644 index 000000000..f651d4312 --- /dev/null +++ b/test/inspect_output/sub10.txt @@ -0,0 +1,5 @@ +1 MUL(scalar = 3, powers = (y => 1,)) +2 ├─ 3 +3 └─ SYM(y) + +Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number \ No newline at end of file diff --git a/test/inspect_output/sub14.txt b/test/inspect_output/sub14.txt new file mode 100644 index 000000000..e905378e9 --- /dev/null +++ b/test/inspect_output/sub14.txt @@ -0,0 +1,7 @@ +1 ADD(scalar = 0, coeffs = (z => 1, x => 2)) +2 ├─ SYM(z) +3 └─ MUL(scalar = 2, powers = (x => 1,)) +4 ├─ 2 +5 └─ SYM(x) + +Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 0c9ee4792..004b26b7d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Documenter using Pkg using Test using SymbolicUtils +using ReferenceTests import IfElse: ifelse DocMeta.setdocmeta!(