Skip to content

Commit

Permalink
Merge pull request #538 from JuliaSymbolics/s/inspect
Browse files Browse the repository at this point in the history
inspect & pluck
  • Loading branch information
shashi authored Sep 1, 2023
2 parents 8f9f5f1 + 33278b6 commit a94caad
Show file tree
Hide file tree
Showing 13 changed files with 161 additions and 53 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["BenchmarkTools", "Documenter", "Pkg", "PkgBenchmark", "Random", "Test", "Zygote"]
test = ["BenchmarkTools", "Documenter", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "Test", "Zygote"]
2 changes: 2 additions & 0 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ include("methods.jl")
# LinkedList, simplification utilities
include("utils.jl")

# Tree inspection
include("inspect.jl")
export Rewriters

# A library for composing together expr -> expr functions
Expand Down
72 changes: 72 additions & 0 deletions src/inspect.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import AbstractTrees

const inspect_metadata = Ref{Bool}(false)
function AbstractTrees.nodevalue(x::Symbolic)
istree(x) ? operation(x) : x
end

function AbstractTrees.nodevalue(x::BasicSymbolic)
str = if !istree(x)
string(exprtype(x), "(", x, ")")
elseif isadd(x)
string(exprtype(x),
(scalar=x.coeff, coeffs=Tuple(k=>v for (k,v) in x.dict)))
elseif ismul(x)
string(exprtype(x),
(scalar=x.coeff, powers=Tuple(k=>v for (k,v) in x.dict)))
elseif isdiv(x) || ispow(x)
string(exprtype(x))
else
string(exprtype(x),"{", operation(x), "}")
end

if inspect_metadata[] && !isnothing(metadata(x))
str *= string(" metadata=", Tuple(k=>v for (k, v) in metadata(x)))
end
Text(str)
end

function AbstractTrees.children(x::Symbolic)
istree(x) ? arguments(x) : ()
end

"""
inspect([io::IO=stdout], expr; hint=true, metadata=false)
Inspect an expression tree `expr`. Uses AbstractTrees to print out an expression.
BasicSymbolic expressions will print the Unityper type (ADD, MUL, DIV, POW, SYM, TERM) and the relevant internals as the head, and the children in the subsequent lines as accessed by `arguments`. Other types will get printed as subtrees. Set `metadata=true` to print any metadata carried by the nodes.
Line numbers will be shown, use `pluck(expr, line_number)` to get the sub expression or leafnode starting at line_number.
"""
function inspect end

function inspect(io::IO, x::Symbolic;
hint=true,
metadata=inspect_metadata[])

prev_state = inspect_metadata[]
inspect_metadata[] = metadata
lines = readlines(IOBuffer(sprint(io->AbstractTrees.print_tree(io, x))))
inspect_metadata[] = prev_state
digits = ceil(Int, log10(length(lines)))
line_numbers = lpad.(string.(1:length(lines)), digits)
print(io, join(string.(line_numbers, " ", lines), "\n"))
hint && print(io, "\n\nHint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number")
end

function inspect(x; hint=true, metadata=inspect_metadata[])
inspect(stdout, x; hint=hint, metadata=metadata)
end

inspect(io::IO, x; kw...) = println(io, "Not Symbolic: $x")

"""
pluck(expr, n)
Pluck the `n`th subexpression from `expr` as given by pre-order DFS.
This is the same as the node numbering in `inspect`.
"""
function pluck(x, item)
collect(Iterators.take(AbstractTrees.PreOrderDFS(x), item))[end]
end
48 changes: 0 additions & 48 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -559,54 +559,6 @@ function basic_similarterm(t, f, args, stype; metadata=nothing)
end
end

###
### Tree print
###

import AbstractTrees

struct TreePrint
op
x
end

function AbstractTrees.children(x::BasicSymbolic)
if isterm(x) || ispow(x)
return arguments(x)
elseif isadd(x) || ismul(x)
children = Any[x.coeff]
for (key, coeff) in pairs(x.dict)
if coeff == 1
push!(children, key)
else
push!(children, TreePrint(isadd(x) ? (:*) : (:^), (key, coeff)))
end
end
return children
end
end

AbstractTrees.children(x::TreePrint) = [x.x[1], x.x[2]]

print_tree(x; show_type=false, maxdepth=Inf, kw...) = print_tree(stdout, x; show_type=show_type, maxdepth=maxdepth, kw...)

function print_tree(_io::IO, x::BasicSymbolic; show_type=false, kw...)
if isterm(x) || isadd(x) || ismul(x) || ispow(x) || isdiv(x)
AbstractTrees.print_tree(_io, x; withinds=true, kw...) do io, y, inds
if istree(y)
print(io, operation(y))
elseif y isa TreePrint
print(io, "(", y.op, ")")
else
print(io, y)
end
if !(y isa TreePrint) && show_type
print(io, " [", typeof(y), "]")
end
end
end
end

###
### Metadata
###
Expand Down
12 changes: 12 additions & 0 deletions test/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,18 @@ end
@test repr((-1)^a) == "(-1)^a"
end

@testset "inspect" begin
@syms x y z
y = SymbolicUtils.setmetadata(y, Integer, 42) # Set some metadata
ex = z*(2x + 3y + 1)^2/(z+2x)
@test_reference "inspect_output/ex.txt" sprint(io->SymbolicUtils.inspect(io, ex))
@test_reference "inspect_output/ex-md.txt" sprint(io->SymbolicUtils.inspect(io, ex, metadata=true))
@test_reference "inspect_output/ex-nohint.txt" sprint(io->SymbolicUtils.inspect(io, ex, hint=false))
@test SymbolicUtils.pluck(ex, 8) == 2
@test_reference "inspect_output/sub10.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 10)))
@test_reference "inspect_output/sub14.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 14)))
end

@testset "similarterm" begin
@syms a b c
@test isequal(SymbolicUtils.similarterm((b + c), +, [a, (b+c)]).dict, Dict(a=>1,b=>1,c=>1))
Expand Down
4 changes: 1 addition & 3 deletions test/fuzz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ include("fuzzlib.jl")

using Random: seed!

seed!(6174)
@testset "Fuzz test" begin
seed!(8258)
@time @testset "expand fuzz" begin
for i=1:500
i % 100 == 0 && @info "expand fuzz" iter=i
Expand Down Expand Up @@ -45,4 +44,3 @@ seed!(6174)
fuzz_addmulpow(4)
end
end
end
2 changes: 1 addition & 1 deletion test/fuzzlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ const num_spec = let
()->rand([a b c d e f])]

binops = SymbolicUtils.diadic
nopow = filter(x->x!==(^), binops)
nopow = setdiff(binops, [(^), besselj0, besselj1, bessely0, bessely1, besselj, bessely, besseli, besselk])
twoargfns = vcat(nopow, (x,y)->x isa Union{Int, Rational, Complex{<:Rational}} ? x * y : x^y)
fns = vcat(1 .=> vcat(SymbolicUtils.monadic, [one, zero]),
2 .=> vcat(twoargfns, fill(+, 5), [-,-], fill(*, 5), fill(/, 40)),
Expand Down
20 changes: 20 additions & 0 deletions test/inspect_output/ex-md.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
1 DIV
2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2))
3 │ ├─ SYM(z)
4 │ └─ POW
5 │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2))
6 │ │ ├─ 1
7 │ │ ├─ MUL(scalar = 2, powers = (x => 1,))
8 │ │ │ ├─ 2
9 │ │ │ └─ SYM(x)
10 │ │ └─ MUL(scalar = 3, powers = (y => 1,))
11 │ │ ├─ 3
12 │ │ └─ SYM(y) metadata=(Integer => 42,)
13 │ └─ 2
14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2))
15 ├─ SYM(z)
16 └─ MUL(scalar = 2, powers = (x => 1,))
17 ├─ 2
18 └─ SYM(x)

Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number
18 changes: 18 additions & 0 deletions test/inspect_output/ex-nohint.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
1 DIV
2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2))
3 │ ├─ SYM(z)
4 │ └─ POW
5 │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2))
6 │ │ ├─ 1
7 │ │ ├─ MUL(scalar = 2, powers = (x => 1,))
8 │ │ │ ├─ 2
9 │ │ │ └─ SYM(x)
10 │ │ └─ MUL(scalar = 3, powers = (y => 1,))
11 │ │ ├─ 3
12 │ │ └─ SYM(y)
13 │ └─ 2
14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2))
15 ├─ SYM(z)
16 └─ MUL(scalar = 2, powers = (x => 1,))
17 ├─ 2
18 └─ SYM(x)
20 changes: 20 additions & 0 deletions test/inspect_output/ex.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
1 DIV
2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2))
3 │ ├─ SYM(z)
4 │ └─ POW
5 │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2))
6 │ │ ├─ 1
7 │ │ ├─ MUL(scalar = 2, powers = (x => 1,))
8 │ │ │ ├─ 2
9 │ │ │ └─ SYM(x)
10 │ │ └─ MUL(scalar = 3, powers = (y => 1,))
11 │ │ ├─ 3
12 │ │ └─ SYM(y)
13 │ └─ 2
14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2))
15 ├─ SYM(z)
16 └─ MUL(scalar = 2, powers = (x => 1,))
17 ├─ 2
18 └─ SYM(x)

Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number
5 changes: 5 additions & 0 deletions test/inspect_output/sub10.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
1 MUL(scalar = 3, powers = (y => 1,))
2 ├─ 3
3 └─ SYM(y)

Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number
7 changes: 7 additions & 0 deletions test/inspect_output/sub14.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
1 ADD(scalar = 0, coeffs = (z => 1, x => 2))
2 ├─ SYM(z)
3 └─ MUL(scalar = 2, powers = (x => 1,))
4 ├─ 2
5 └─ SYM(x)

Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Documenter
using Pkg
using Test
using SymbolicUtils
using ReferenceTests
import IfElse: ifelse

DocMeta.setdocmeta!(
Expand Down

0 comments on commit a94caad

Please sign in to comment.