Skip to content

Commit

Permalink
Fix Term construction
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenszhu committed Sep 27, 2024
1 parent d5d5d58 commit 79f524d
Show file tree
Hide file tree
Showing 12 changed files with 30 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import TermInterface: maketerm, iscall, operation, arguments, metadata

import SymbolicUtils: BasicSymbolic, FnType, @rule, Rewriters, substitute, symtype,
promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv, _Sym,
get_dict
_Term, get_dict

using SymbolicUtils.Code

Expand Down
22 changes: 11 additions & 11 deletions src/array-lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function Base.getindex(x::SymArray, idx...)
throw(BoundsError(x, idx))
end
end
res = Term{eltype(symtype(x))}(getindex, [x, Tuple(ii)...]; metadata = meta)
res = _Term(eltype(symtype(x)), getindex, [x, Tuple(ii)...]; metadata = meta)
elseif all(i -> symtype(i) <: Integer, idx)
shape(x) !== Unknown() && @boundscheck begin
if length(idx) > 1
Expand All @@ -43,7 +43,7 @@ function Base.getindex(x::SymArray, idx...)
end
end
end
res = Term{eltype(symtype(x))}(getindex, [x, idx...]; metadata = meta)
res = _Term(eltype(symtype(x)), getindex, [x, idx...]; metadata = meta)
elseif length(idx) == 1 && symtype(first(idx)) <: CartesianIndex
i = first(idx)
ii = i isa CartesianIndex ? Tuple(i) : arguments(i)
Expand All @@ -70,7 +70,7 @@ function Base.getindex(x::SymArray, idx...)
end
end

term = Term{Any}(getindex, [x, idx...]; metadata = meta)
term = _Term(Any, getindex, [x, idx...]; metadata = meta)
T = eltype(symtype(x))
N = ndims(x) - count(i -> symtype(i) <: Integer, idx)
res = ArrayOp(atype(symtype(x)){T,N},
Expand Down Expand Up @@ -197,12 +197,12 @@ function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast})
# then you get pairs, and index matcher cannot
# recurse into pairs
Atype = propagate_atype(broadcast, bc.f, args...)
args = map(x -> x isa Base.RefValue ? Term{Any}(Ref, [x[]]) : x, args)
args = map(x -> x isa Base.RefValue ? _Term(Any, Ref, [x[]]) : x, args)
ArrayOp(Atype{symtype(expr),ndim},
(subscripts...,),
expr,
+,
Term{Any}(broadcast, [bc.f, args...]))
_Term(Any, broadcast, [bc.f, args...]))
end

# On wrapper:
Expand Down Expand Up @@ -270,15 +270,15 @@ function symeltype(A)
end
# TODO: add more such methods
function getindex(A::AbstractArray, i::Symbolic{<:Integer}, ii::Symbolic{<:Integer}...)
Term{symeltype(A)}(getindex, [A, i, ii...])
_Term(symeltype(A), getindex, [A, i, ii...])
end

function getindex(A::AbstractArray, i::Int, j::Symbolic{<:Integer})
Term{symeltype(A)}(getindex, [A, i, j])
_Term(symeltype(A), getindex, [A, i, j])
end

function getindex(A::AbstractArray, j::Symbolic{<:Integer}, i::Int)
Term{symeltype(A)}(getindex, [A, j, i])
_Term(symeltype(A), getindex, [A, j, i])
end

function getindex(A::Arr, i::Int, j::Symbolic{<:Integer})
Expand Down Expand Up @@ -341,7 +341,7 @@ function _map(f, x, xs...)
(idx...,),
expr,
+,
Term{Any}(map, [f, x, xs...]))
_Term(Any, map, [f, x, xs...]))
end

@inline _mapreduce(f, g, x, dims, kw) = mapreduce(f, g, x; dims=dims, kw...)
Expand All @@ -359,15 +359,15 @@ end
expr = f(x[idx...])
T = symtype(g(expr, expr))
if dims === (:)
return Term{T}(_mapreduce, [f, g, x, dims, (kw...,)])
return _Term(T, _mapreduce, [f, g, x, dims, (kw...,)])
end

Atype = propagate_atype(_mapreduce, f, g, x, dims, (kw...,))
ArrayOp(Atype{T,ndims(x)},
(out_idx...,),
expr,
g,
Term{Any}(_mapreduce, [f, g, x, dims, (kw...,)]))
_Term(Any, _mapreduce, [f, g, x, dims, (kw...,)]))
end false

for (ff, opts) in [sum => (identity, +, false),
Expand Down
4 changes: 2 additions & 2 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ function array_term(f, args...;
end
end
S = container_type{eltype, ndims}
setmetadata(Term{S}(f, Any[args...]), ArrayShapeCtx, shape)
setmetadata(_Term(S, f, Any[args...]), ArrayShapeCtx, shape)
end

"""
Expand Down Expand Up @@ -680,7 +680,7 @@ function scalarize_op(f::typeof(_det), arr)
end

@wrapped function LinearAlgebra.det(x::AbstractMatrix; laplace=true)
Term{eltype(x)}(_det, [x, laplace])
_Term(eltype(x), _det, [x, laplace])
end false


Expand Down
4 changes: 2 additions & 2 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ function occursin_info(x, expr, fail = true)
if all(_isfalse, args)
return false
end
Term{Real}(true, args)
_Term(Real, true, args)
end
end

Expand Down Expand Up @@ -636,7 +636,7 @@ end

isidx(x) = x isa TermCombination

basic_mkterm(t, g, args, m) = metadata(Term{Any}(g, args), m)
basic_mkterm(t, g, args, m) = metadata(_Term(Any, g, args), m)

let
# we do this in a let block so that Revise works on the list of rules
Expand Down
2 changes: 1 addition & 1 deletion src/difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct Difference <: Operator
update::Bool
Difference(t; dt, update=false) = new(value(t), dt, update)
end
(D::Difference)(t) = Term{symtype(t)}(D, [t])
(D::Difference)(t) = _Term(symtype(t), D, [t])
(D::Difference)(t::Num) = Num(D(value(t)))
SymbolicUtils.promote_symtype(::Difference, t) = t
"""
Expand Down
2 changes: 1 addition & 1 deletion src/extra_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function _binomial(nothing, n, k)
end), unwrapped_args))
Base.binomial(unwrapped_args...)
else
SymbolicUtils.Term{Int}(Base.binomial, unwrapped_args)
_Term(Int, Base.binomial, unwrapped_args)
end
if typeof.(args) == typeof.(unwrapped_args)
return res
Expand Down
2 changes: 1 addition & 1 deletion src/integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ struct Integral{T <: Symbolics.VarDomainPairing} <: Function
Integral(domain) = new{typeof(domain)}(domain)
end

(I::Integral)(x) = Term{SymbolicUtils.symtype(x)}(I, [x])
(I::Integral)(x) = _Term(SymbolicUtils.symtype(x), I, [x])
(I::Integral)(x::Num) = Num(I(Symbolics.value(x)))
SymbolicUtils.promote_symtype(::Integral, x) = x

Expand Down
2 changes: 1 addition & 1 deletion src/parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function parse_expr_to_symbolic(ex, mod::Module)
else
x = parse_expr_to_symbolic(ex.args[1], mod)
ys = parse_expr_to_symbolic.(ex.args[2:end],(mod,))
return Term{Real}(x,[ys...])
return _Term(Real, x,[ys...])
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ macro register_symbolic(expr, define_promotion = true, Ts = :([]), wrap_arrays =
res = if !any($is_symbolic_or_array_of_symbolic, unwrapped_args)
$f(unwrapped_args...) # partial-eval if all args are unwrapped
else
$Term{$ret_type}($f, unwrapped_args)
$_Term($ret_type, $f, unwrapped_args)
end
if typeof.(args) == typeof.(unwrapped_args)
return res
Expand Down Expand Up @@ -115,7 +115,7 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs
elseif $ret_type == nothing || ($ret_type <: AbstractArray)
$array_term($(Expr(:parameters, [Expr(:kw, k, v) for (k, v) in defs]...)), $f, unwrapped_args...)
else
$Term{$ret_type}($f, unwrapped_args)
$_Term($ret_type, $f, unwrapped_args)
end

if typeof.(args) == typeof.(unwrapped_args)
Expand Down
8 changes: 4 additions & 4 deletions src/semipoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ Base.:nameof(m::SemiMonomial) = Symbol(:SemiMonomial, m.p, m.coeff)
isop(x, op) = iscall(x) && operation(x) === op
isop(op) = Base.Fix2(isop, op)

simpleterm(T, f, args, m) = Term{SymbolicUtils._promote_symtype(f, args)}(f, args)
simpleterm(T, f, args, m) = _Term(SymbolicUtils._promote_symtype(f, args), f, args)

function mark_and_exponentiate(expr, vars)
# Step 1
Expand Down Expand Up @@ -197,16 +197,16 @@ function mark_vars(expr, vars)
if op === (^) || op == (/)
args = arguments(expr)
@assert length(args) == 2
return Term{symtype(expr)}(op, map(mark_vars(vars), args))
return _Term(symtype(expr), op, map(mark_vars(vars), args))
end
args = arguments(expr)
if op === (+) || op === (*)
return Term{symtype(expr)}(op, map(mark_vars(vars), args))
return _Term(symtype(expr), op, map(mark_vars(vars), args))
elseif length(args) == 1
if op == sqrt
return mark_vars(args[1]^(1//2), vars)
elseif linearity_1(op)
return Term{symtype(expr)}(op, mark_vars(args[1], vars))
return _Term(symtype(expr), op, mark_vars(args[1], vars))
end
end
return SemiMonomial(1, expr)
Expand Down
4 changes: 2 additions & 2 deletions test/macro.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
using Symbolics
import Symbolics: getsource, getdefaultval, wrap, unwrap, getname
import SymbolicUtils: Term, symtype, FnType, BasicSymbolic, promote_symtype
import SymbolicUtils: symtype, FnType, BasicSymbolic, promote_symtype, _Term
using LinearAlgebra
using Test

@variables t
Symbolics.@register_symbolic fff(t)
@test isequal(fff(t), Symbolics.Num(Symbolics.Term{Real}(fff, [Symbolics.value(t)])))
@test isequal(fff(t), Symbolics.Num(_Term(Real, fff, [Symbolics.value(t)])))

const SymMatrix{T,N} = Symmetric{T, AbstractArray{T, N}}
many_vars = @variables t=0 a=1 x[1:4]=2 y(t)[1:4]=3 w[1:4] = 1:4 z(t)[1:4] = 2:5 p(..)[1:4]
Expand Down
4 changes: 2 additions & 2 deletions test/overloads.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Symbolics: Sym, FnType, Term, value, scalarize
using Symbolics: FnType, _Term, value, scalarize
using Symbolics
using LinearAlgebra
using SparseArrays: sparse
Expand Down Expand Up @@ -163,7 +163,7 @@ z2 = c + d * im
@test conj(a) === a
@test imag(a) === Num(0)

@test isequal(sign(x), Num(SymbolicUtils.Term{Int}(sign, [Symbolics.value(x)])))
@test isequal(sign(x), Num(_Term(Int, sign, [Symbolics.value(x)])))
@test sign(Num(1)) isa Num
@test isequal(sign(Num(1)), Num(1))
@test isequal(sign(Num(-1)), Num(-1))
Expand Down

0 comments on commit 79f524d

Please sign in to comment.