diff --git a/src/code.jl b/src/code.jl index 4e1589ed9..9eaac38e6 100644 --- a/src/code.jl +++ b/src/code.jl @@ -10,6 +10,7 @@ import ..SymbolicUtils import ..SymbolicUtils.Rewriters import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, symtype, sorted_arguments, metadata, isterm, term, maketerm +import SymbolicIndexingInterface: symbolic_type, NotSymbolic ##== state management ==## @@ -169,6 +170,14 @@ function substitute_name(O, st) end end +function _is_array_of_symbolics(O) + # O is an array, not a symbolic array, and either has a non-symbolic eltype or contains elements that are + # symbolic or arrays of symbolics + return O isa AbstractArray && symbolic_type(O) == NotSymbolic() && + (symbolic_type(eltype(O)) != NotSymbolic() || + any(x -> symbolic_type(x) != NotSymbolic() || _is_array_of_symbolics(x), O)) +end + function toexpr(O, st) if issym(O) O = substitute_name(O, st) @@ -176,6 +185,9 @@ function toexpr(O, st) end O = substitute_name(O, st) + if _is_array_of_symbolics(O) + return toexpr(MakeArray(O, typeof(O)), st) + end !iscall(O) && return O op = operation(O) expr′ = function_to_expr(op, O, st) diff --git a/test/code.jl b/test/code.jl index 946fde10c..7956aa59b 100644 --- a/test/code.jl +++ b/test/code.jl @@ -219,4 +219,12 @@ nanmath_st.rewrites[:nanmath] = true @test s1 == s2 end end + + let + @syms a b + + t = term(sum, [a, b, a + b, 3a + 2b, sqrt(b)]; type = Number) + f = eval(toexpr(Func([a, b], [], t))) + @test f(1.0, 2.0) ≈ 13.0 + sqrt(2) + end end