Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to Expronicon.jl in SymbolicUtils.jl #1272

Draft
wants to merge 16 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ import DomainSets: Domain
using TermInterface
import TermInterface: maketerm, iscall, operation, arguments, metadata

import SymbolicUtils: Term, Add, Mul, Pow, Sym, Div, BasicSymbolic,
FnType, @rule, Rewriters, substitute, symtype,
promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv
import SymbolicUtils: BasicSymbolic, FnType, @rule, Rewriters, substitute, symtype,
promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv, _Sym,
_Term, get_dict, isconst, get_val

using SymbolicUtils.Code

Expand Down
29 changes: 18 additions & 11 deletions src/array-lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ end

function Base.getindex(x::SymArray, idx...)
idx = unwrap.(idx)
idx = map(idx) do i
if isconst(i)
get_val(i)
else
i
end
end
meta = metadata(unwrap(x))
if iscall(x) && (op = operation(x)) isa Operator
args = arguments(x)
Expand All @@ -32,7 +39,7 @@ function Base.getindex(x::SymArray, idx...)
throw(BoundsError(x, idx))
end
end
res = Term{eltype(symtype(x))}(getindex, [x, Tuple(ii)...]; metadata = meta)
res = _Term(eltype(symtype(x)), getindex, [x, Tuple(ii)...]; metadata = meta)
elseif all(i -> symtype(i) <: Integer, idx)
shape(x) !== Unknown() && @boundscheck begin
if length(idx) > 1
Expand All @@ -43,7 +50,7 @@ function Base.getindex(x::SymArray, idx...)
end
end
end
res = Term{eltype(symtype(x))}(getindex, [x, idx...]; metadata = meta)
res = _Term(eltype(symtype(x)), getindex, [x, idx...]; metadata = meta)
elseif length(idx) == 1 && symtype(first(idx)) <: CartesianIndex
i = first(idx)
ii = i isa CartesianIndex ? Tuple(i) : arguments(i)
Expand All @@ -70,7 +77,7 @@ function Base.getindex(x::SymArray, idx...)
end
end

term = Term{Any}(getindex, [x, idx...]; metadata = meta)
term = _Term(Any, getindex, [x, idx...]; metadata = meta)
T = eltype(symtype(x))
N = ndims(x) - count(i -> symtype(i) <: Integer, idx)
res = ArrayOp(atype(symtype(x)){T,N},
Expand Down Expand Up @@ -197,12 +204,12 @@ function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast})
# then you get pairs, and index matcher cannot
# recurse into pairs
Atype = propagate_atype(broadcast, bc.f, args...)
args = map(x -> x isa Base.RefValue ? Term{Any}(Ref, [x[]]) : x, args)
args = map(x -> x isa Base.RefValue ? _Term(Any, Ref, [x[]]) : x, args)
ArrayOp(Atype{symtype(expr),ndim},
(subscripts...,),
expr,
+,
Term{Any}(broadcast, [bc.f, args...]))
_Term(Any, broadcast, [bc.f, args...]))
end

# On wrapper:
Expand Down Expand Up @@ -270,15 +277,15 @@ function symeltype(A)
end
# TODO: add more such methods
function getindex(A::AbstractArray, i::Symbolic{<:Integer}, ii::Symbolic{<:Integer}...)
Term{symeltype(A)}(getindex, [A, i, ii...])
_Term(symeltype(A), getindex, [A, i, ii...])
end

function getindex(A::AbstractArray, i::Int, j::Symbolic{<:Integer})
Term{symeltype(A)}(getindex, [A, i, j])
_Term(symeltype(A), getindex, [A, i, j])
end

function getindex(A::AbstractArray, j::Symbolic{<:Integer}, i::Int)
Term{symeltype(A)}(getindex, [A, j, i])
_Term(symeltype(A), getindex, [A, j, i])
end

function getindex(A::Arr, i::Int, j::Symbolic{<:Integer})
Expand Down Expand Up @@ -341,7 +348,7 @@ function _map(f, x, xs...)
(idx...,),
expr,
+,
Term{Any}(map, [f, x, xs...]))
_Term(Any, map, [f, x, xs...]))
end

@inline _mapreduce(f, g, x, dims, kw) = mapreduce(f, g, x; dims=dims, kw...)
Expand All @@ -359,15 +366,15 @@ end
expr = f(x[idx...])
T = symtype(g(expr, expr))
if dims === (:)
return Term{T}(_mapreduce, [f, g, x, dims, (kw...,)])
return _Term(T, _mapreduce, [f, g, x, dims, (kw...,)])
end

Atype = propagate_atype(_mapreduce, f, g, x, dims, (kw...,))
ArrayOp(Atype{T,ndims(x)},
(out_idx...,),
expr,
g,
Term{Any}(_mapreduce, [f, g, x, dims, (kw...,)]))
_Term(Any, _mapreduce, [f, g, x, dims, (kw...,)]))
end false

for (ff, opts) in [sum => (identity, +, false),
Expand Down
16 changes: 8 additions & 8 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ macro arrayop(output_idx, expr, options...)
end |> esc
end

const SymArray = Union{ArrayOp, Symbolic{<:AbstractArray}}
const SymMat = Union{ArrayOp{<:AbstractMatrix}, Symbolic{<:AbstractMatrix}}
const SymVec = Union{ArrayOp{<:AbstractVector}, Symbolic{<:AbstractVector}}
const SymArray = Union{ArrayOp, BasicSymbolic{<:AbstractArray}}
const SymMat = Union{ArrayOp{<:AbstractMatrix}, BasicSymbolic{<:AbstractMatrix}}
const SymVec = Union{ArrayOp{<:AbstractVector}, BasicSymbolic{<:AbstractVector}}

### Propagate ###
#
Expand Down Expand Up @@ -415,7 +415,7 @@ function array_term(f, args...;
end
end
S = container_type{eltype, ndims}
setmetadata(Term{S}(f, Any[args...]), ArrayShapeCtx, shape)
setmetadata(_Term(S, f, Any[args...]), ArrayShapeCtx, shape)
end

"""
Expand Down Expand Up @@ -504,7 +504,7 @@ const ArrayLike{T,N} = Union{
ArrayOp{AbstractArray{T,N}},
Symbolic{AbstractArray{T,N}},
Arr{T,N},
SymbolicUtils.Term{AbstractArray{T, N}}
SymbolicUtils.BasicSymbolic{AbstractArray{T, N}}
} # Like SymArray but includes Arr and Term{Arr}

unwrap(x::Arr) = x.value
Expand Down Expand Up @@ -688,7 +688,7 @@ function scalarize_op(f::typeof(_det), arr)
end

@wrapped function LinearAlgebra.det(x::AbstractMatrix; laplace=true)
Term{eltype(x)}(_det, [x, laplace])
_Term(eltype(x), _det, [x, laplace])
end false


Expand Down Expand Up @@ -1055,7 +1055,7 @@ function get_inputs(x::ArrayOp)
end

function similar_arrayvar(ex, name)
Sym{symtype(ex)}(name) #TODO: shape?
_Sym(symtype(ex), name) #TODO: shape?
end

function reset_to_one(range)
Expand All @@ -1064,7 +1064,7 @@ function reset_to_one(range)
end

function reset_sym(i)
Sym{Int}(Symbol(nameof(i), "′"))
_Sym(Int, Symbol(nameof(i), "′"))
end

function inplace_expr(x::ArrayOp, outsym = :_out, intermediates = nothing)
Expand Down
6 changes: 3 additions & 3 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
oop_expr = wrap_code[1](oop_expr)
end

out = Sym{Any}(:ˍ₋out)
out = _Sym(Any, :ˍ₋out)
ip_body = if iip
postprocess_fbody(set_array(parallel,
dargs,
Expand Down Expand Up @@ -553,13 +553,13 @@ _set_array(out, outputidxs, rhs, checkbounds, skipzeros, cse) = rhs
function vars_to_pairs(name,vs::Union{Tuple, AbstractArray}, symsdict=Dict())
vs_names = tosymbol.(vs)
for (v,k) in zip(vs_names, vs)
symsdict[k] = Sym{symtype(k)}(v)
symsdict[k] = _Sym(symtype(k), v)
end
exs = [:($name[$i]) for (i, u) ∈ enumerate(vs)]
vs_names,exs
end
function vars_to_pairs(name,vs, symsdict)
symsdict[vs] = Sym{symtype(vs)}(tosymbol(vs))
symsdict[vs] = _Sym(symtype(vs), tosymbol(vs))
[tosymbol(vs)], [name]
end

Expand Down
2 changes: 1 addition & 1 deletion src/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function Base.show(io::IO, a::Complex{Num})
return print(io, arguments(rr)[1])
end

i = Sym{Real}(:im)
i = _Sym(Real, :im)
show(io, real(a) + i * imag(a))
end

Expand Down
8 changes: 4 additions & 4 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ function occursin_info(x, expr, fail = true)
if all(_isfalse, args)
return false
end
Term{Real}(true, args)
_Term(Real, true, args)
end
end

function occursin_info(x, expr::Sym, fail)
function occursin_info(x, expr::BasicSymbolic, fail)
if symtype(expr) <: AbstractArray && fail
error("Differentiation of expressions involving arrays and array variables is not yet supported.")
end
Expand Down Expand Up @@ -139,7 +139,7 @@ function recursive_hasoperator(op, O)
return true
else
if isadd(O) || ismul(O)
any(recursive_hasoperator(op), keys(O.dict))
any(recursive_hasoperator(op), keys(get_dict(O)))
elseif ispow(O)
recursive_hasoperator(op)(O.base) || recursive_hasoperator(op)(O.exp)
elseif isdiv(O)
Expand Down Expand Up @@ -636,7 +636,7 @@ end

isidx(x) = x isa TermCombination

basic_mkterm(t, g, args, m) = metadata(Term{Any}(g, args), m)
basic_mkterm(t, g, args, m) = metadata(_Term(Any, g, args), m)

let
# we do this in a let block so that Revise works on the list of rules
Expand Down
2 changes: 1 addition & 1 deletion src/difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct Difference <: Operator
update::Bool
Difference(t; dt, update=false) = new(value(t), dt, update)
end
(D::Difference)(t) = Term{symtype(t)}(D, [t])
(D::Difference)(t) = _Term(symtype(t), D, [t])
(D::Difference)(t::Num) = Num(D(value(t)))
SymbolicUtils.promote_symtype(::Difference, t) = t
"""
Expand Down
2 changes: 1 addition & 1 deletion src/extra_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function _binomial(nothing, n, k)
end), unwrapped_args))
Base.binomial(unwrapped_args...)
else
SymbolicUtils.Term{Int}(Base.binomial, unwrapped_args)
_Term(Int, Base.binomial, unwrapped_args)
end
if typeof.(args) == typeof.(unwrapped_args)
return res
Expand Down
2 changes: 1 addition & 1 deletion src/integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ struct Integral{T <: Symbolics.VarDomainPairing} <: Function
Integral(domain) = new{typeof(domain)}(domain)
end

(I::Integral)(x) = Term{SymbolicUtils.symtype(x)}(I, [x])
(I::Integral)(x) = _Term(SymbolicUtils.symtype(x), I, [x])
(I::Integral)(x::Num) = Num(I(Symbolics.value(x)))
SymbolicUtils.promote_symtype(::Integral, x) = x

Expand Down
4 changes: 2 additions & 2 deletions src/latexify_recipes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ function _toexpr(O)
denom = Any[]

# We need to iterate over each term in m, ignoring the numeric coefficient.
# This iteration needs to be stable, so we can't iterate over m.dict.
# This iteration needs to be stable, so we can't iterate over get_dict(m).
for term in Iterators.drop(sorted_arguments(m), isone(m.coeff) ? 0 : 1)
if !ispow(term)
push!(numer, _toexpr(term))
Expand Down Expand Up @@ -260,7 +260,7 @@ function diffdenom(e)
elseif ismul(e)
LaTeXString(prod(
"\\mathrm{d}$(k)$(isone(v) ? "" : "^{$v}")"
for (k, v) in e.dict
for (k, v) in get_dict(e)
))
else
e
Expand Down
2 changes: 1 addition & 1 deletion src/num.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Base.show(io::IO, n::Num) = show_numwrap[] ? print(io, :(Num($(value(n))))) : Ba
Base.promote_rule(::Type{<:Number}, ::Type{<:Num}) = Num
Base.promote_rule(::Type{BigFloat}, ::Type{<:Num}) = Num
Base.promote_rule(::Type{<:Symbolic{<:Number}}, ::Type{<:Num}) = Num
function Base.getproperty(t::Union{Add, Mul, Pow, Term}, f::Symbol)
function Base.getproperty(t::BasicSymbolic, f::Symbol)
if f === :op
Base.depwarn("`x.op` is deprecated, use `operation(x)` instead", :getproperty, force=true)
operation(t)
Expand Down
2 changes: 1 addition & 1 deletion src/parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function parse_expr_to_symbolic(ex, mod::Module)
else
x = parse_expr_to_symbolic(ex.args[1], mod)
ys = parse_expr_to_symbolic.(ex.args[2:end],(mod,))
return Term{Real}(x,[ys...])
return _Term(Real, x,[ys...])
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ macro register_symbolic(expr, define_promotion = true, Ts = :([]), wrap_arrays =
res = if !any($is_symbolic_or_array_of_symbolic, unwrapped_args)
$f(unwrapped_args...) # partial-eval if all args are unwrapped
else
$Term{$ret_type}($f, unwrapped_args)
$_Term($ret_type, $f, unwrapped_args)
end
if typeof.(args) == typeof.(unwrapped_args)
return res
Expand Down Expand Up @@ -115,7 +115,7 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs
elseif $ret_type == nothing || ($ret_type <: AbstractArray)
$array_term($(Expr(:parameters, [Expr(:kw, k, v) for (k, v) in defs]...)), $f, unwrapped_args...)
else
$Term{$ret_type}($f, unwrapped_args)
$_Term($ret_type, $f, unwrapped_args)
end

if typeof.(args) == typeof.(unwrapped_args)
Expand Down
10 changes: 5 additions & 5 deletions src/semipoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ end
# return a dictionary of exponents with respect to variables
function pdegrees(x)
if ismul(x)
return x.dict
return get_dict(x)
elseif isdiv(x)
num_dict = pdegrees(x.num)
den_dict = pdegrees(x.den)
Expand Down Expand Up @@ -136,7 +136,7 @@ Base.:nameof(m::SemiMonomial) = Symbol(:SemiMonomial, m.p, m.coeff)
isop(x, op) = iscall(x) && operation(x) === op
isop(op) = Base.Fix2(isop, op)

simpleterm(T, f, args, m) = Term{SymbolicUtils._promote_symtype(f, args)}(f, args)
simpleterm(T, f, args, m) = _Term(SymbolicUtils._promote_symtype(f, args), f, args)

function mark_and_exponentiate(expr, vars)
# Step 1
Expand Down Expand Up @@ -197,16 +197,16 @@ function mark_vars(expr, vars)
if op === (^) || op == (/)
args = arguments(expr)
@assert length(args) == 2
return Term{symtype(expr)}(op, map(mark_vars(vars), args))
return _Term(symtype(expr), op, map(mark_vars(vars), args))
end
args = arguments(expr)
if op === (+) || op === (*)
return Term{symtype(expr)}(op, map(mark_vars(vars), args))
return _Term(symtype(expr), op, map(mark_vars(vars), args))
elseif length(args) == 1
if op == sqrt
return mark_vars(args[1]^(1//2), vars)
elseif linearity_1(op)
return Term{symtype(expr)}(op, mark_vars(args[1], vars))
return _Term(symtype(expr), op, mark_vars(args[1], vars))
end
end
return SemiMonomial(1, expr)
Expand Down
Loading
Loading