From 906395ea5bdb739928cc47e51ad538d293aceca2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Tue, 16 Jul 2024 15:16:08 +0200 Subject: [PATCH] Replace Cassette for an overlayed `MethodTable` (#40) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Replace Cassette with an overlayed `MethodTable` * Fix typo in `ReactantInterpreter` constructor * Rename `get_interence_world` to `get_world_counter` * Complete `AbstractInterpreter` API implementation * Refactor phrase * Run traced function over custom interpreter * Fix `OpaqueClosure` wrapping of interpreted closures * Add comment * Add `@hlo_override` macro * Format code * Rename `@hlo_override` to `@reactant_override` * Fix interpreter for Julia 1.11 * Add `Symbol` case for `make_tracer` * Try fix ghost argument in `OpaqueClosure` * Fix argtypes of `OpaqueClosure` on closures with Julia 1.9 * Replace `*` for `sum` in `@code_hlo` test * Fix return type type-unstability in `Base.:*` * Refactor includes --------- Co-authored-by: Sergio Sánchez Ramírez --- Project.toml | 2 - README.md | 2 +- src/Interpreter.jl | 419 +++++++++++++++++++++++++++++++++++++++++++++ src/Reactant.jl | 1 + src/overloads.jl | 368 +-------------------------------------- src/utils.jl | 34 +++- 6 files changed, 454 insertions(+), 372 deletions(-) create mode 100644 src/Interpreter.jl diff --git a/Project.toml b/Project.toml index ac7a3ac0..4290635b 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,6 @@ version = "0.1.7" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" -Cassette = "7057c7e9-c182-5462-911a-8362d720325c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" @@ -31,7 +30,6 @@ ReactantNNlibExt = "NNlib" Adapt = "4" ArrayInterface = "7.10" CEnum = "0.4, 0.5" -Cassette = "0.3" Enzyme = "0.11, 0.12" 
NNlib = "0.9" PackageExtensionCompat = "1" diff --git a/README.md b/README.md index 0ddaf6e3..676d71c1 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ > [!WARNING] > This package is under active development at the moment and may change its API and supported end systems at any time. End-users are advised to wait until a corresponding release with broader availability is made. Package developers are suggested to try out Reactant for integration, but should be advised of the same risks. -Reactant takes Julia function and compile it into MLIR and run fancy optimizations on top of it, including using EnzymeMLIR for automatic differentiation, and create relevant executables for CPU/GPU/TPU via XLA. It presently operates as a tracing system based off of Cassette. Compiled functions will assume the same control flow pattern as was original taken by objects used at compile time, and control flow (e.g. if, for) as well as any type instabilities will be removed. The benefits of this approach is immediately making all such code available for advanced optimization with little developer effort. This system and corresponding semantics is subject to change to a (potentially partial) source rewriter in the future. +Reactant takes a Julia function, compiles it into MLIR, and runs fancy optimizations on top of it, including using EnzymeMLIR for automatic differentiation, and creates relevant executables for CPU/GPU/TPU via XLA. It presently operates as a tracing system. Compiled functions will assume the same control flow pattern as was originally taken by objects used at compile time, and control flow (e.g. if, for) as well as any type instabilities will be removed. The benefit of this approach is immediately making all such code available for advanced optimization with little developer effort. This system and its corresponding semantics are subject to change to a (potentially partial) source rewriter in the future. 
Reactant provides two new array types at its core, a ConcreteRArray and a TracedRArray. A ConcreteRArray is an underlying buffer to whatever device data you wish to store and can be created by converting from a regular Julia Array. diff --git a/src/Interpreter.jl b/src/Interpreter.jl new file mode 100644 index 00000000..4c12dc93 --- /dev/null +++ b/src/Interpreter.jl @@ -0,0 +1,419 @@ +# Taken from https://github.com/JuliaLang/julia/pull/52964/files#diff-936d33e524bcd097015043bd6410824119be5c210d43185c4d19634eb4912708 +# Other references: +# - https://github.com/JuliaLang/julia/blob/0fd1f04dc7d4b905b0172b7130e9b1beab9bc4c9/test/compiler/AbstractInterpreter.jl#L228-L234 +# - https://github.com/JuliaLang/julia/blob/v1.10.4/test/compiler/newinterp.jl#L9 + +const CC = Core.Compiler +using Enzyme + +const HAS_INTEGRATED_CACHE = VERSION >= v"1.11.0-DEV.1552" + +Base.Experimental.@MethodTable(ReactantMethodTable) + +macro reactant_override(expr) + return :(Base.Experimental.@overlay ReactantMethodTable $(expr)) +end + +@static if !HAS_INTEGRATED_CACHE + struct ReactantCache + dict::IdDict{Core.MethodInstance,Core.CodeInstance} + end + ReactantCache() = ReactantCache(IdDict{Core.MethodInstance,Core.CodeInstance}()) + + function CC.get(wvc::CC.WorldView{ReactantCache}, mi::Core.MethodInstance, default) + return get(wvc.cache.dict, mi, default) + end + function CC.getindex(wvc::CC.WorldView{ReactantCache}, mi::Core.MethodInstance) + return getindex(wvc.cache.dict, mi) + end + function CC.haskey(wvc::CC.WorldView{ReactantCache}, mi::Core.MethodInstance) + return haskey(wvc.cache.dict, mi) + end + function CC.setindex!( + wvc::CC.WorldView{ReactantCache}, ci::Core.CodeInstance, mi::Core.MethodInstance + ) + return setindex!(wvc.cache.dict, ci, mi) + end +end + +struct ReactantInterpreter <: CC.AbstractInterpreter + world::UInt + inf_params::CC.InferenceParams + opt_params::CC.OptimizationParams + inf_cache::Vector{CC.InferenceResult} + @static if !HAS_INTEGRATED_CACHE + 
code_cache::ReactantCache + end + + @static if HAS_INTEGRATED_CACHE + function ReactantInterpreter(; + world::UInt=Base.get_world_counter(), + inf_params::CC.InferenceParams=CC.InferenceParams(), + opt_params::CC.OptimizationParams=CC.OptimizationParams(), + inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[], + ) + return new(world, inf_params, opt_params, inf_cache) + end + else + function ReactantInterpreter(; + world::UInt=Base.get_world_counter(), + inf_params::CC.InferenceParams=CC.InferenceParams(), + opt_params::CC.OptimizationParams=CC.OptimizationParams(), + inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[], + code_cache=ReactantCache(), + ) + return new(world, inf_params, opt_params, inf_cache, code_cache) + end + end +end + +@static if HAS_INTEGRATED_CACHE + CC.get_inference_world(interp::ReactantInterpreter) = interp.world +else + CC.get_world_counter(interp::ReactantInterpreter) = interp.world +end + +CC.InferenceParams(interp::ReactantInterpreter) = interp.inf_params +CC.OptimizationParams(interp::ReactantInterpreter) = interp.opt_params +CC.get_inference_cache(interp::ReactantInterpreter) = interp.inf_cache + +@static if HAS_INTEGRATED_CACHE + # TODO what does this do? 
taken from https://github.com/JuliaLang/julia/blob/v1.11.0-rc1/test/compiler/newinterp.jl + @eval CC.cache_owner(interp::ReactantInterpreter) = + $(QuoteNode(gensym(:ReactantInterpreterCache))) +else + function CC.code_cache(interp::ReactantInterpreter) + return CC.WorldView(interp.code_cache, CC.WorldRange(interp.world)) + end +end + +function CC.method_table(interp::ReactantInterpreter) + return CC.OverlayMethodTable(interp.world, ReactantMethodTable) +end + +const enzyme_out = 0 +const enzyme_dup = 1 +const enzyme_const = 2 +const enzyme_dupnoneed = 3 +const enzyme_outnoneed = 4 +const enzyme_constnoneed = 5 + +@inline act_from_type(x, reverse, needs_primal=true) = + throw(AssertionError("Unhandled activity $(typeof(x))")) +@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) = + act_from_type(Enzyme.Const, reverse, needs_primal) +@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) = + act_from_type(Enzyme.Duplicated, reverse, needs_primal) +@inline act_from_type(::Enzyme.DuplicatedNoNeed, reverse, needs_primal=true) = + reverse ? 
enzyme_out : enzyme_dupnoneed +@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) = + act_from_type(Enzyme.Active, reverse, needs_primal) # forward to the Type{<:Enzyme.Active} method +@inline act_from_type(::Type{<:Enzyme.Const}, reverse, needs_primal) = + if needs_primal + enzyme_const + else + enzyme_constnoneed + end +@inline act_from_type(::Type{<:Enzyme.Duplicated}, reverse, needs_primal) = + if reverse + if needs_primal + enzyme_out + else + enzyme_outnoneed + end + else + if needs_primal + enzyme_dup + else + enzyme_dupnoneed + end + end +@inline act_from_type(::Type{<:Enzyme.Active}, reverse, needs_primal) = + if needs_primal + enzyme_out + else + enzyme_outnoneed + end + +function push_val!(ad_inputs, x, path) + for p in path + x = getfield(x, p) + end + x = x.mlir_data + return push!(ad_inputs, x) +end + +function push_acts!(ad_inputs, x::Const, path, reverse) + return push_val!(ad_inputs, x.val, path) +end + +function push_acts!(ad_inputs, x::Active, path, reverse) + return push_val!(ad_inputs, x.val, path) +end + +function push_acts!(ad_inputs, x::Duplicated, path, reverse) + push_val!(ad_inputs, x.val, path) + if !reverse + push_val!(ad_inputs, x.dval, path) + end +end + +function push_acts!(ad_inputs, x::DuplicatedNoNeed, path, reverse) + push_val!(ad_inputs, x.val, path) + if !reverse + push_val!(ad_inputs, x.dval, path) + end +end + +function set_act!(inp, path, reverse, tostore; emptypath=false) + x = if inp isa Enzyme.Active + inp.val + else + inp.dval + end + + for p in path + x = getfield(x, p) + end + + #if inp isa Enzyme.Active || !reverse + x.mlir_data = tostore + #else + # x.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(x.mlir_data, tostore), 1) + #end + + if emptypath + x.paths = () + end +end + +function set!(x, path, tostore; emptypath=false) + for p in path + x = getfield(x, p) + end + + x.mlir_data = tostore + + if emptypath + x.paths = () + end +end + +function get_argidx(x) + for path in x.paths + if length(path) == 0 + continue + end + if path[1] == 
:args + return path[2]::Int, path + end + end + throw(AssertionError("No path found for $x")) +end +function get_residx(x) + for path in x.paths + if length(path) == 0 + continue + end + if path[1] == :result + return path + end + end + throw(AssertionError("No path found $x")) +end + +function has_residx(x) + for path in x.paths + if length(path) == 0 + continue + end + if path[1] == :result + return true + end + end + return false +end + +function get_attribute_by_name(operation, name) + return MLIR.IR.Attribute(MLIR.API.mlirOperationGetAttributeByName(operation, name)) +end + +@reactant_override function Enzyme.autodiff( + ::CMode, f::FA, ::Type{A}, args::Vararg{Enzyme.Annotation,Nargs} +) where {CMode<:Enzyme.Mode,FA<:Enzyme.Annotation,A<:Enzyme.Annotation,Nargs} + reverse = CMode <: Enzyme.ReverseMode + + primf = f.val + primargs = ((v.val for v in args)...,) + + fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( + primf, primargs, (), string(f) * "_autodiff", false + ) + + activity = Int32[] + ad_inputs = MLIR.IR.Value[] + + for a in linear_args + idx, path = get_argidx(a) + if idx == 1 && fnwrap + push!(activity, act_from_type(f, reverse)) + push_acts!(ad_inputs, f, path[3:end], reverse) + else + if fnwrap + idx -= 1 + end + push!(activity, act_from_type(args[idx], reverse)) + push_acts!(ad_inputs, args[idx], path[3:end], reverse) + end + end + + outtys = MLIR.IR.Type[] + @inline needs_primal(::Type{<:Enzyme.ReverseMode{ReturnPrimal}}) where {ReturnPrimal} = + ReturnPrimal + for a in linear_results + if has_residx(a) + if needs_primal(CMode) + push!(outtys, transpose_ty(MLIR.IR.type(a.mlir_data))) + end + else + push!(outtys, transpose_ty(MLIR.IR.type(a.mlir_data))) + end + end + for (i, act) in enumerate(activity) + if act == enzyme_out || (reverse && (act == enzyme_dup || act == enzyme_dupnoneed)) + push!(outtys, in_tys[i])# transpose_ty(MLIR.IR.type(MLIR.IR.operand(ret, i)))) + end + end + + 
ret_activity = Int32[] + for a in linear_results + if has_residx(a) + act = act_from_type(A, reverse, needs_primal(CMode)) + push!(ret_activity, act) + if act == enzyme_out || act == enzyme_outnoneed + attr = fill(MLIR.IR.Attribute(eltype(a)(1)), mlir_type(a)) + cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) + push!(ad_inputs, cst) + end + else + idx, path = get_argidx(a) + if idx == 1 && fnwrap + act = act_from_type(f, reverse, true) + push!(ret_activity, act) + if act != enzyme_out && act != enzyme_outnoneed + continue + end + push_val!(ad_inputs, f.dval, path[3:end]) + else + if fnwrap + idx -= 1 + end + act = act_from_type(args[idx], reverse, true) + push!(ret_activity, act) + if act != enzyme_out && act != enzyme_outnoneed + continue + end + push_val!(ad_inputs, args[idx].dval, path[3:end]) + end + end + end + + function act_attr(val) + val = @ccall MLIR.API.mlir_c.enzymeActivityAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, val::Int32 + )::MLIR.API.MlirAttribute + return MLIR.IR.Attribute(val) + end + fname = get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + res = (reverse ? MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)( + [transpose_val(v) for v in ad_inputs]; + outputs=outtys, + fn=fname, + activity=MLIR.IR.Attribute([act_attr(a) for a in activity]), + ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]), + ) + + residx = 1 + + for a in linear_results + if has_residx(a) + if needs_primal(CMode) + path = get_residx(a) + set!(result, path[2:end], transpose_val(MLIR.IR.result(res, residx))) + residx += 1 + end + else + idx, path = get_argidx(a) + if idx == 1 && fnwrap + set!(f.val, path[3:end], transpose_val(MLIR.IR.result(res, residx))) + residx += 1 + else + if fnwrap + idx -= 1 + end + set!(args[idx].val, path[3:end], transpose_val(MLIR.IR.result(res, residx))) + residx += 1 + end + end + end + + restup = Any[(a isa Active) ? 
copy(a) : nothing for a in args] + for a in linear_args + idx, path = get_argidx(a) + if idx == 1 && fnwrap + if act_from_type(f, reverse) != enzyme_out + continue + end + if f isa Enzyme.Active + @assert false + residx += 1 + continue + end + set_act!(f, path[3:end], reverse, transpose_val(MLIR.IR.result(res, residx))) + else + if fnwrap + idx -= 1 + end + if act_from_type(args[idx], reverse) != enzyme_out + continue + end + if args[idx] isa Enzyme.Active + set_act!( + args[idx], + path[3:end], + false, + transpose_val(MLIR.IR.result(res, residx)); + emptypath=true, # keyword name as declared by set_act! + ) #=reverse=# + residx += 1 + continue + end + set_act!( + args[idx], path[3:end], reverse, transpose_val(MLIR.IR.result(res, residx)) + ) + end + residx += 1 + end + + func2.operation = MLIR.API.MlirOperation(C_NULL) + + if reverse + resv = if needs_primal(CMode) + result + else + nothing + end + return ((restup...,), resv) + else + if A <: Const + return result + else + dres = copy(result) + throw(AssertionError("TODO implement forward mode handler")) + if A <: Duplicated + return () + end + end + end +end diff --git a/src/Reactant.jl b/src/Reactant.jl index 5441cebd..ef919818 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -4,6 +4,7 @@ using PackageExtensionCompat include("mlir/MLIR.jl") include("XLA.jl") +include("Interpreter.jl") include("utils.jl") abstract type RArray{ElType,Shape,N} <: AbstractArray{ElType,N} end diff --git a/src/overloads.jl b/src/overloads.jl index 79255b8d..43b73f66 100644 --- a/src/overloads.jl +++ b/src/overloads.jl @@ -1,343 +1,3 @@ - -using Cassette - -using Enzyme - -Cassette.@context TraceCtx; - -const enzyme_out = 0 -const enzyme_dup = 1 -const enzyme_const = 2 -const enzyme_dupnoneed = 3 -const enzyme_outnoneed = 4 -const enzyme_constnoneed = 5 - -function get_argidx(x) - for path in x.paths - if length(path) == 0 - continue - end - if path[1] == :args - return path[2]::Int, path - end - end - throw(AssertionError("No path found for $x")) -end -function 
get_residx(x) - for path in x.paths - if length(path) == 0 - continue - end - if path[1] == :result - return path - end - end - throw(AssertionError("No path found $x")) -end - -function has_residx(x) - for path in x.paths - if length(path) == 0 - continue - end - if path[1] == :result - return true - end - end - return false -end - -@inline act_from_type(x, reverse, needs_primal=true) = - throw(AssertionError("Unhandled activity $(typeof(x))")) -@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) = - act_from_type(Enzyme.Const, reverse, needs_primal) -@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) = - act_from_type(Enzyme.Duplicated, reverse, needs_primal) -@inline act_from_type(::Enzyme.DuplicatedNoNeed, reverse, needs_primal=true) = - reverse ? enzyme_out : enzyme_dupnoneed -@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) = - act_from_tuple(Enzyme.Active, reverse, needs_primal) - -@inline act_from_type(::Type{<:Enzyme.Const}, reverse, needs_primal) = - if needs_primal - enzyme_const - else - enzyme_constnoneed - end -@inline act_from_type(::Type{<:Enzyme.Duplicated}, reverse, needs_primal) = - if reverse - if needs_primal - enzyme_out - else - enzyme_outnoneed - end - else - if needs_primal - enzyme_dup - else - enzyme_dupnoneed - end - end -@inline act_from_type(::Type{<:Enzyme.Active}, reverse, needs_primal) = - if needs_primal - enzyme_out - else - enzyme_outnoneed - end - -function push_val!(ad_inputs, x, path) - for p in path - x = getfield(x, p) - end - x = x.mlir_data - return push!(ad_inputs, x) -end - -function push_acts!(ad_inputs, x::Const, path, reverse) - return push_val!(ad_inputs, x.val, path) -end - -function push_acts!(ad_inputs, x::Active, path, reverse) - return push_val!(ad_inputs, x.val, path) -end - -function push_acts!(ad_inputs, x::Duplicated, path, reverse) - push_val!(ad_inputs, x.val, path) - if !reverse - push_val!(ad_inputs, x.dval, path) - end -end - -function 
push_acts!(ad_inputs, x::DuplicatedNoNeed, path, reverse) - push_val!(ad_inputs, x.val, path) - if !reverse - push_val!(ad_inputs, x.dval, path) - end -end - -function set_act!(inp, path, reverse, tostore; emptypath=false) - x = if inp isa Enzyme.Active - inp.val - else - inp.dval - end - - for p in path - x = getfield(x, p) - end - - #if inp isa Enzyme.Active || !reverse - x.mlir_data = tostore - #else - # x.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(x.mlir_data, tostore), 1) - #end - - if emptypath - x.paths = () - end -end - -function set!(x, path, tostore; emptypath=false) - for p in path - x = getfield(x, p) - end - - x.mlir_data = tostore - - if emptypath - x.paths = () - end -end - -function get_attribute_by_name(operation, name) - return MLIR.IR.Attribute(MLIR.API.mlirOperationGetAttributeByName(operation, name)) -end - -function Cassette.overdub( - ::TraceCtx, - ::typeof(Enzyme.autodiff), - ::CMode, - f::FA, - ::Type{A}, - args::Vararg{Enzyme.Annotation,Nargs}, -) where {CMode<:Enzyme.Mode,FA<:Enzyme.Annotation,A<:Enzyme.Annotation,Nargs} - reverse = CMode <: Enzyme.ReverseMode - - primf = f.val - primargs = ((v.val for v in args)...,) - - fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( - primf, primargs, (), string(f) * "_autodiff", false - ) - - activity = Int32[] - ad_inputs = MLIR.IR.Value[] - - for a in linear_args - idx, path = get_argidx(a) - if idx == 1 && fnwrap - push!(activity, act_from_type(f, reverse)) - push_acts!(ad_inputs, f, path[3:end], reverse) - else - if fnwrap - idx -= 1 - end - push!(activity, act_from_type(args[idx], reverse)) - push_acts!(ad_inputs, args[idx], path[3:end], reverse) - end - end - - outtys = MLIR.IR.Type[] - @inline needs_primal(::Type{<:Enzyme.ReverseMode{ReturnPrimal}}) where {ReturnPrimal} = - ReturnPrimal - for a in linear_results - if has_residx(a) - if needs_primal(CMode) - push!(outtys, transpose_ty(MLIR.IR.type(a.mlir_data))) - end - else - 
push!(outtys, transpose_ty(MLIR.IR.type(a.mlir_data))) - end - end - for (i, act) in enumerate(activity) - if act == enzyme_out || (reverse && (act == enzyme_dup || act == enzyme_dupnoneed)) - push!(outtys, in_tys[i])# transpose_ty(MLIR.IR.type(MLIR.IR.operand(ret, i)))) - end - end - - ret_activity = Int32[] - for a in linear_results - if has_residx(a) - act = act_from_type(A, reverse, needs_primal(CMode)) - push!(ret_activity, act) - if act == enzyme_out || act == enzyme_outnoneed - attr = fill(MLIR.IR.Attribute(eltype(a)(1)), mlir_type(a)) - cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) - push!(ad_inputs, cst) - end - else - idx, path = get_argidx(a) - if idx == 1 && fnwrap - act = act_from_type(f, reverse, true) - push!(ret_activity, act) - if act != enzyme_out && act != enzyme_outnoneed - continue - end - push_val!(ad_inputs, f.dval, path[3:end]) - else - if fnwrap - idx -= 1 - end - act = act_from_type(args[idx], reverse, true) - push!(ret_activity, act) - if act != enzyme_out && act != enzyme_outnoneed - continue - end - push_val!(ad_inputs, args[idx].dval, path[3:end]) - end - end - end - - function act_attr(val) - val = @ccall MLIR.API.mlir_c.enzymeActivityAttrGet( - MLIR.IR.context()::MLIR.API.MlirContext, val::Int32 - )::MLIR.API.MlirAttribute - return MLIR.IR.Attribute(val) - end - fname = get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - res = (reverse ? 
MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)( - [transpose_val(v) for v in ad_inputs]; - outputs=outtys, - fn=fname, - activity=MLIR.IR.Attribute([act_attr(a) for a in activity]), - ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]), - ) - - residx = 1 - - for a in linear_results - if has_residx(a) - if needs_primal(CMode) - path = get_residx(a) - set!(result, path[2:end], transpose_val(MLIR.IR.result(res, residx))) - residx += 1 - end - else - idx, path = get_argidx(a) - if idx == 1 && fnwrap - set!(f.val, path[3:end], transpose_val(MLIR.IR.result(res, residx))) - residx += 1 - else - if fnwrap - idx -= 1 - end - set!(args[idx].val, path[3:end], transpose_val(MLIR.IR.result(res, residx))) - residx += 1 - end - end - end - - restup = Any[(a isa Active) ? copy(a) : nothing for a in args] - for a in linear_args - idx, path = get_argidx(a) - if idx == 1 && fnwrap - if act_from_type(f, reverse) != enzyme_out - continue - end - if f isa Enzyme.Active - @assert false - residx += 1 - continue - end - set_act!(f, path[3:end], reverse, transpose_val(MLIR.IR.result(res, residx))) - else - if fnwrap - idx -= 1 - end - if act_from_type(args[idx], reverse) != enzyme_out - continue - end - if args[idx] isa Enzyme.Active - set_act!( - args[idx], - path[3:end], - false, - transpose_val(MLIR.IR.result(res, residx)); - emptypaths=true, - ) #=reverse=# - residx += 1 - continue - end - set_act!( - args[idx], path[3:end], reverse, transpose_val(MLIR.IR.result(res, residx)) - ) - end - residx += 1 - end - - func2.operation = MLIR.API.MlirOperation(C_NULL) - - if reverse - resv = if needs_primal(CMode) - result - else - nothing - end - return ((restup...,), resv) - else - if A <: Const - return result - else - dres = copy(result) - throw(AssertionError("TODO implement forward mode handler")) - if A <: Duplicated - return () - end - end - end -end - function promote_to(::Type{TracedRArray{ElType,Shape,N}}, rhs) where {ElType,Shape,N} if isa(rhs, 
TracedRArray) return TracedRArray{ElType,Shape,N}( @@ -480,8 +140,6 @@ for (jlop, hloop, RT) in ( end end -Cassette.overdub(context::TraceCtx, f::typeof(Enzyme.make_zero), args...) = f(args...) - function Base.:*( lhs::TracedRArray{ElType,Shape,2}, rhs::TracedRArray{ElType,Shape2,2} ) where {ElType,Shape,Shape2} @@ -505,11 +163,9 @@ function Base.:*( ), 1, ) - return TracedRArray{ElType,(Base.size(lhsty)[1], Base.size(rhsty)[2]),2}((), res) + return TracedRArray{ElType,(Shape[1], Shape2[2]),2}((), res) end -Cassette.overdub(context::TraceCtx, f::typeof(Base.:*), args...) = f(args...) - for (jlop, hloop) in ( (:(Base.:-), :negate), (:(Base.sin), :sine), @@ -527,7 +183,6 @@ for (jlop, hloop) in ( (), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1) ) end - Cassette.overdub(context::TraceCtx, f::typeof($jlop), args...) = f(args...) end end @@ -672,12 +327,9 @@ for (jlop, hloop, hlocomp, RT) in ( end end -Cassette.overdub(context::TraceCtx, f::typeof(elem_apply), args...) = f(args...) - @inline function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}}) return reshape(A, Base._reshape_uncolon(A, dims)) end -Cassette.overdub(context::TraceCtx, f::typeof(Base.reshape), args...) = f(args...) @inline function Base.reshape( A::ConcreteRArray{T,Shape,N}, dims::NTuple{NT,Int} @@ -696,7 +348,6 @@ Cassette.overdub(context::TraceCtx, f::typeof(Base.reshape), args...) = f(args.. end Base.copy(A::TracedRArray{T,Shape,N}) where {T,Shape,N} = TracedRArray((), A.mlir_data) -Cassette.overdub(context::TraceCtx, f::typeof(Base.copy), args...) = f(args...) @inline function Base.permutedims(A::TracedRArray{T,Shape,N}, perm) where {T,Shape,N} return TracedArray{T,tuple(Shape[i] for i in perm),N}( @@ -709,7 +360,6 @@ Cassette.overdub(context::TraceCtx, f::typeof(Base.copy), args...) = f(args...) ), ) end -Cassette.overdub(context::TraceCtx, f::typeof(Base.permutedims), args...) = f(args...) 
@inline function Base.reshape( A::TracedRArray{T,Shape,N}, dims::NTuple{NT,Int} @@ -767,7 +417,6 @@ BroadcastStyle(::Type{T}) where {T<:TracedRArray} = AbstractReactantArrayStyle{n function Base.similar(x::TracedRArray{T,Shape,N}, ::Type{T2}) where {T,Shape,N,T2} return TracedRArray{T2,Shape,N}((), nothing) end -Cassette.overdub(context::TraceCtx, f::typeof(Base.similar), args...) = f(args...) @inline function Base.similar( bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims @@ -817,7 +466,6 @@ end ) where {Style<:AbstractReactantArrayStyle} return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))) end -Cassette.overdub(context::TraceCtx, f::typeof(Base.materialize!), args...) = f(args...) @inline Base.copyto!(dest::TracedRArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict @@ -851,7 +499,6 @@ function Base.fill!(A::TracedRArray{T,Shape,N}, x) where {T,Shape,N} A.mlir_data = bcast.mlir_data return A end -Cassette.overdub(context::TraceCtx, f::typeof(Base.fill!), args...) = f(args...) @inline function broadcast_to_size(arg::T, rsize) where {T<:Number} TT = MLIR.IR.TensorType([Int64(s) for s in rsize], MLIR.IR.Type(typeof(arg))) @@ -910,18 +557,6 @@ end return dest end -function Cassette.overdub( - context::Cassette.Context, - ::Core.kwftype(typeof(Base.mapreduce)), - kwargs::Any, - ::typeof(Base.mapreduce), - args..., -) - return Base.mapreduce(args...; kwargs...) -end - -Cassette.overdub(context::Cassette.Context, f::typeof(Base.mapreduce), args...) = f(args...) - function Base.mapreduce( f, op, A::TracedRArray{ElType,Shape,N}; dims=:, init=nothing ) where {ElType,Shape,N} @@ -997,4 +632,3 @@ function Base.mapreducedim!(f, op, R::TracedRArray, A::Base.AbstractArrayOrBroad R.mlir_data = elem_apply(op, R, tmp).mlir_data return R end -Cassette.overdub(context::TraceCtx, f::typeof(Base.mapreducedim!), args...) = f(args...) 
diff --git a/src/utils.jl b/src/utils.jl index 5203b3db..771c23cf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -71,9 +71,39 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa arg.mlir_data = row_maj_arg end - return Cassette.overdub( - Cassette.disablehooks(TraceCtx()), f, traced_args...; kwargs... + # TODO replace with `Base.invoke_within` if julia#52964 lands + ir = first( + only( + # TODO fix it for kwargs + Base.code_ircode(f, map(typeof, traced_args); interp=ReactantInterpreter()), + ), ) + + # NOTE on Julia 1.9, it appends a ghost argument at the end + # solution: manually specify argument types + @static if VERSION < v"1.10" + empty!(ir.argtypes) + if f === Reactant.apply + append!( + ir.argtypes, + Any[ + Core.Const(f), + typeof(traced_args[1]), + Tuple{typeof.(traced_args[2:end])...}, + ], + ) + else + append!(ir.argtypes, Any[Core.Const(f), typeof.(traced_args)...]) + end + end + + oc = Core.OpaqueClosure(ir) + + if f === Reactant.apply + oc(traced_args[1], (traced_args[2:end]...,)) + else + oc(traced_args...) + end end seen_results = IdDict()