From 72c93f605d08b5710b169cf2e1ea6a4a2bbce706 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sun, 15 Sep 2024 19:40:50 -0500 Subject: [PATCH 01/10] Enzyme: adapt to pending version breaking update [only downstream] --- Project.toml | 2 +- ext/EnzymeCoreExt.jl | 250 ++++++++++++++++++++++++------------------- 2 files changed, 143 insertions(+), 109 deletions(-) diff --git a/Project.toml b/Project.toml index 21c476cbca..863829c5b6 100644 --- a/Project.toml +++ b/Project.toml @@ -56,7 +56,7 @@ CUDA_Runtime_jll = "0.15" ChainRulesCore = "1" Crayons = "4" DataFrames = "1" -EnzymeCore = "0.7.3" +EnzymeCore = "0.8" ExprTools = "0.1" GPUArrays = "10.0.1" GPUCompiler = "0.24, 0.25, 0.26, 0.27" diff --git a/ext/EnzymeCoreExt.jl b/ext/EnzymeCoreExt.jl index f8c8fe2c7d..44f44c4ce7 100644 --- a/ext/EnzymeCoreExt.jl +++ b/ext/EnzymeCoreExt.jl @@ -37,14 +37,14 @@ function metaf(fn, args::Vararg{Any, N}) where N nothing end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(cufunction)}, ::Type{<:Duplicated}, f::Const{F}, tt::Const{TT}; kwargs...) where {F,TT} res = ofn.val(f.val, tt.val; kwargs...) return Duplicated(res, res) end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(cufunction)}, ::Type{BatchDuplicated{T,N}}, f::Const{F}, tt::Const{TT}; kwargs...) where {F,TT,T,N} res = ofn.val(f.val, tt.val; kwargs...) @@ -53,24 +53,32 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)}, end) end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(cudaconvert)}, ::Type{RT}, x::IT) where {RT, IT} - if RT <: Duplicated - Duplicated(ofn.val(x.val), ofn.val(x.dval)) - elseif RT <: Const - ofn.val(x.val)::eltype(RT) - elseif RT <: DuplicatedNoNeed - ofn.val(x.val)::eltype(RT) - else - tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i - Base.@_inline_meta - ofn.val(x.dval[i])::eltype(RT) + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + Duplicated(ofn.val(x.val), ofn.val(x.dval)) + else + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + ofn.val(x.dval[i])::eltype(RT) + end + BatchDuplicated(ofn.val(x.val), tup) end - if RT <: BatchDuplicated - BatchDuplicated(ofv.val(x.val), tup) + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + ofn.val(x.dval)::eltype(RT) else - tup + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + ofn.val(x.dval[i])::eltype(RT) + end end + elseif EnzymeRules.needs_primal(config) + ofn.val(uval.val)::eltype(RT) + else + nothing end end @@ -93,7 +101,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(cudac else nothing end - return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing}(primal, shadow, nothing) + return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing) end function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(cudaconvert)}, ::Type{RT}, tape, x::IT) where {RT, IT} @@ -101,64 +109,85 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(cudaconvert)}, end -function EnzymeCore.EnzymeRules.forward(ofn::Const{Type{CT}}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{Type{CT}}, ::Type{RT}, uval::EnzymeCore.Annotation{UndefInitializer}, args...) where {CT <: CuArray, RT} primargs = ntuple(Val(length(args))) do i Base.@_inline_meta args[i].val end - if RT <: Duplicated - shadow = ofn.val(uval.val, primargs...)::CT - fill!(shadow, 0) - Duplicated(ofn.val(uval.val, primargs...), shadow) - elseif RT <: Const - ofn.val(uval.val, primargs...) - elseif RT <: DuplicatedNoNeed - shadow = ofn.val(uval.val, primargs...)::CT - fill!(shadow, 0) - shadow::CT - else - tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i - Base.@_inline_meta + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 shadow = ofn.val(uval.val, primargs...)::CT fill!(shadow, 0) - shadow::CT + Duplicated(ofn.val(uval.val, primargs...), shadow) + else + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + shadow = ofn.val(uval.val, primargs...)::CT + fill!(shadow, 0) + shadow::CT + end + BatchDuplicated(ofn.val(uval.val, primargs...), tup) end - if RT <: BatchDuplicated - BatchDuplicated(ofv.val(uval.val), tup) + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + shadow = ofn.val(uval.val, primargs...)::CT + fill!(shadow, 0) + shadow else + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + shadow = ofn.val(uval.val, primargs...)::CT + fill!(shadow, 0) + shadow::CT + end tup end + elseif EnzymeRules.needs_primal(config) + ofn.val(uval.val, primargs...) + else + nothing end end -function EnzymeCore.EnzymeRules.forward(ofn::Const{Type{CT}}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{Type{CT}}, ::Type{RT}, uval::EnzymeCore.Annotation{DR}, args...; kwargs...) where {CT <: CuArray, DR <: CUDA.DataRef, RT} primargs = ntuple(Val(length(args))) do i Base.@_inline_meta args[i].val end - if RT <: Duplicated - shadow = ofn.val(uval.val, primargs...; kwargs...) - Duplicated(ofn.val(uval.dval, primargs...; kwargs...), shadow) - elseif RT <: Const - ofn.val(uval.val, primargs...; kwargs...) - elseif RT <: DuplicatedNoNeed - ofn.val(uval.dval, primargs...; kwargs...) - else - tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i - Base.@_inline_meta - shadow = ofn.val(uval.dval[i], primargs...; kwargs...) + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + shadow = ofn.val(uval.val, primargs...; kwargs...) + Duplicated(ofn.val(uval.val, primargs...; kwargs...), shadow) + else + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + ofn.val(uval.val, primargs...; kwargs...) + end + BatchDuplicated(ofn.val(uval.val, primargs...; kwargs...), tup) end - if RT <: BatchDuplicated - BatchDuplicated(ofv.val(uval.val), tup) + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + shadow = ofn.val(uval.val, primargs...; kwargs...) + shadow else + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + ofn.val(uval.val, primargs...; kwargs...) + end tup end + elseif EnzymeRules.needs_primal(config) + ofn.val(uval.val, primargs...; kwargs...) + else + nothing end end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(synchronize)}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(synchronize)}, ::Type{RT}, args::Vararg{EnzymeCore.Annotation, N}; kwargs...) where {RT, N} pargs = ntuple(Val(N)) do i Base.@_inline_meta @@ -166,26 +195,33 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(synchronize)}, end res = ofn.val(pargs...; kwargs...) - if RT <: Duplicated - return Duplicated(res, res) - elseif RT <: Const - return res - elseif RT <: DuplicatedNoNeed - return res - else - tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i - Base.@_inline_meta - res + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + Duplicated(res, res) + else + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + res + end + BatchDuplicated(ofn.val(uval.val, primargs...; kwargs...), tup) end - if RT <: BatchDuplicated - return BatchDuplicated(res, tup) + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + res else - return tup + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + res + end end + elseif EnzymeRules.needs_primal(config) + res + else + nothing end end -function EnzymeCore.EnzymeRules.forward(ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}}, +function EnzymeCore.EnzymeRules.forward(config, ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}}, ::Type{Const{Nothing}}, args...; kwargs...) where {F,TT} @@ -223,7 +259,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(cufun else nothing end - return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? CT : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? CT : NTuple{EnzymeRules.width(config), CT}) : Nothing), Nothing}(primal, shadow, nothing) + return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing) end function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Const{typeof(cufunction)},::Type{RT}, subtape, f, tt; kwargs...) where RT @@ -350,7 +386,7 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Annotation{CUDA. end end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(Base.fill!)}, ::Type{RT}, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x) where {RT, T <: CUDA.MemsetCompatTypes} +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(Base.fill!)}, ::Type{RT}, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x) where {RT, T <: CUDA.MemsetCompatTypes} if A isa Const || A isa Duplicated || A isa BatchDuplicated ofn.val(A.val, x.val) end @@ -365,16 +401,14 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(Base.fill!)}, ::Type{R end end - if RT <: Duplicated - return A - elseif RT <: Const - return A.val - elseif RT <: DuplicatedNoNeed - return A.dval - elseif RT <: BatchDuplicated - return A + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + A + elseif EnzymeRules.needs_shadow(config) + A.dval + elseif EnzymeRules.needs_primal(config) + A.val else - return A.dval + nothing end end @@ -469,7 +503,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{Type{CT}}, : else nothing end - return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing}(primal, shadow, nothing) + return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing) end function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{Type{CT}}, ::Type{RT}, tape, A::EnzymeCore.Annotation{UndefInitializer}, args::Vararg{EnzymeCore.Annotation, N}) where {CT <: CuArray, RT, N} @@ -503,7 +537,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{Type{CT}}, : else nothing end - return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing}(primal, shadow, nothing) + return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing) end function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{Type{CT}}, ::Type{RT}, tape, A::EnzymeCore.Annotation{DR}, args::Vararg{EnzymeCore.Annotation, N}; kwargs...) where {CT <: CuArray, DR <: CUDA.DataRef, RT, N} @@ -517,7 +551,7 @@ function EnzymeCore.EnzymeRules.noalias(::Type{CT}, ::UndefInitializer, args...) return nothing end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays.mapreducedim!)}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(GPUArrays.mapreducedim!)}, ::Type{RT}, f::EnzymeCore.Const{typeof(Base.identity)}, op::EnzymeCore.Const{typeof(Base.add_sum)}, @@ -544,16 +578,14 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays.mapreducedim end end - if RT <: Duplicated - return R - elseif RT <: Const - return R.val - elseif RT <: DuplicatedNoNeed - return R.dval - elseif RT <: BatchDuplicated - return R + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + R + elseif EnzymeRules.needs_shadow(config) + R.dval + elseif EnzymeRules.needs_primal(config) + R.val else - return R.dval + nothing end end @@ -605,34 +637,36 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(GPUArrays.mapr return (nothing, nothing, nothing, nothing) end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays._mapreduce)}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(GPUArrays._mapreduce)}, ::Type{RT}, f::EnzymeCore.Const{typeof(Base.identity)}, op::EnzymeCore.Const{typeof(Base.add_sum)}, A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T, D} - if RT <: Const + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + shadow = ofn.val(f.val, op.val, A.dval; dims, init) + Duplicated(ofn.val(f.val, op.val, A.val; dims, init), shadow) + else + tup = ntuple(Val(EnzymeRules.batch_width(RT))) do i + Base.@_inline_meta + ofn.val(f.val, op.val, A.dval[i]; dims, init) + end + BatchDuplicated(ofn.val(f.val, op.val, A.val; dims, init), tup) + end + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + ofn.val(f.val, op.val, A.dval; dims, init) + else + ntuple(Val(EnzymeRules.batch_width(RT))) do i + Base.@_inline_meta + ofn.val(f.val, op.val, A.dval[i]; dims, init) + end + end + elseif EnzymeRules.needs_primal(config) ofn.val(f.val, op.val, A.val; dims, init) - elseif RT <: Duplicated - ( - ofn.val(f.val, op.val, A.val; dims, init), - ofn.val(f.val, op.val, A.dval; dims, init) - ) - elseif RT <: DuplicatedNoNeed - ofn.val(f.val, op.val, A.dval; dims, init) - elseif RT <: BatchDuplicated - ( - ofn.val(f.val, op.val, A.val; dims, init), - ntuple(Val(EnzymeRules.batch_width(RT))) do i - Base.@_inline_meta - ofn.val(f.val, op.val, A.dval[i]; dims, init) - end - ) else - @assert RT <: BatchDuplicatedNoNeed - ntuple(Val(EnzymeRules.batch_width(RT))) do i - Base.@_inline_meta - ofn.val(f.val, op.val, A.dval[i]; dims, init) - end + nothing end end From 89ed7d223897a398a58be5eb7397ae5846c2cc7e Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Sep 2024 17:21:25 -0500 Subject: [PATCH 02/10] Update EnzymeCoreExt.jl --- ext/EnzymeCoreExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/EnzymeCoreExt.jl b/ext/EnzymeCoreExt.jl index 44f44c4ce7..4075a78412 100644 --- a/ext/EnzymeCoreExt.jl +++ b/ext/EnzymeCoreExt.jl @@ -1,5 +1,4 @@ # compatibility with EnzymeCore - module EnzymeCoreExt using CUDA From 7899822145d1e7f8bed15a9d54cf9a143e6a500a Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 19 Sep 2024 17:49:40 -0400 Subject: [PATCH 03/10] Fix [only downstream] --- Project.toml | 2 +- ext/EnzymeCoreExt.jl | 30 ++++++++++++++++-------------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 863829c5b6..3c4246ca67 100644 --- a/Project.toml +++ b/Project.toml @@ -56,7 +56,7 @@ CUDA_Runtime_jll = "0.15" ChainRulesCore = "1" Crayons = "4" DataFrames = "1" -EnzymeCore = "0.8" +EnzymeCore = "0.8.1" ExprTools = "0.1" GPUArrays = "10.0.1" GPUCompiler = "0.24, 0.25, 0.26, 0.27" diff --git a/ext/EnzymeCoreExt.jl b/ext/EnzymeCoreExt.jl index 4075a78412..f445adc7d5 100644 --- a/ext/EnzymeCoreExt.jl +++ b/ext/EnzymeCoreExt.jl @@ -31,8 +31,8 @@ function EnzymeCore.compiler_job_from_backend(::CUDABackend, @nospecialize(F::Ty return GPUCompiler.CompilerJob(mi, CUDA.compiler_config(CUDA.device())) end -function metaf(fn, args::Vararg{Any, N}) where N - EnzymeCore.autodiff_deferred(Forward, fn, Const, args...) +function metaf(config, fn, args::Vararg{Any, N}) where N + EnzymeCore.autodiff_deferred(EnzymeCore.set_runtime_activity(Forward, config), fn, Const, args...) nothing end @@ -226,10 +226,10 @@ function EnzymeCore.EnzymeRules.forward(config, ofn::EnzymeCore.Annotation{CUDA. GC.@preserve args begin args = ((cudaconvert(a) for a in args)...,) - T2 = (F, (typeof(a) for a in args)...) + T2 = (typeof(config), F, (typeof(a) for a in args)...) TT2 = Tuple{T2...} cuf = cufunction(metaf, TT2) - res = cuf(ofn.val.f, args...; kwargs...) + res = cuf(config, ofn.val.f, args...; kwargs...) end return nothing @@ -265,9 +265,10 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Const{typeof(cuf return (nothing, nothing) end -function meta_augf(f, tape::CuDeviceArray{TapeType}, ::Val{ModifiedBetween}, args::Vararg{Any, N}) where {N, ModifiedBetween, TapeType} +function meta_augf(config, f, tape::CuDeviceArray{TapeType}, args::Vararg{Any, N}) where {N, TapeType} + ModifiedBetween = overwritten(config) forward, _ = EnzymeCore.autodiff_deferred_thunk( - ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), + ReverseSplitModified(EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config), Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const{Nothing}, @@ -305,7 +306,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::EnzymeCore.Annotat ModifiedBetween = overwritten(config) TapeType = EnzymeCore.tape_type( EnzymeCore.compiler_job_from_backend(CUDABackend(), typeof(Base.identity), Tuple{Float64}), - ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), + ReverseSplitModified(EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config), Val(ModifiedBetween)), Const{F}, Const{Nothing}, map(typeof, args)..., @@ -316,18 +317,19 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::EnzymeCore.Annotat GC.@preserve args subtape, begin subtape2 = cudaconvert(subtape) - T2 = (F, typeof(subtape2), Val{ModifiedBetween}, (typeof(a) for a in args)...) + T2 = (typeof(config), F, typeof(subtape2), (typeof(a) for a in args)...) TT2 = Tuple{T2...} cuf = cufunction(meta_augf, TT2) - res = cuf(ofn.val.f, subtape2, Val(ModifiedBetween), args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...) + res = cuf(ofn.val.f, subtape2, args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...) end return AugmentedReturn{Nothing,Nothing,CuArray}(nothing, nothing, subtape) end -function meta_revf(f, tape::CuDeviceArray{TapeType}, ::Val{ModifiedBetween}, args::Vararg{Any, N}) where {N, ModifiedBetween, TapeType} +function meta_revf(config, f, tape::CuDeviceArray{TapeType}, args::Vararg{Any, N}) where {N, TapeType} + ModifiedBetween = overwritten(config) _, reverse = EnzymeCore.autodiff_deferred_thunk( - ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), + ReverseSplitModified(EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config), Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const{Nothing}, @@ -363,7 +365,7 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Annotation{CUDA. args = ((cudaconvert(arg) for arg in args0)...,) ModifiedBetween = overwritten(config) TapeType = EnzymeCore.tape_type( - ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), + ReverseSplitModified(EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config), Val(ModifiedBetween)), Const{F}, Const{Nothing}, map(typeof, args)..., @@ -373,10 +375,10 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Annotation{CUDA. GC.@preserve args0 subtape, begin subtape2 = cudaconvert(subtape) - T2 = (F, typeof(subtape2), Val{ModifiedBetween}, (typeof(a) for a in args)...) + T2 = (typeof(config), F, typeof(subtape2), (typeof(a) for a in args)...) TT2 = Tuple{T2...} cuf = cufunction(meta_revf, TT2) - res = cuf(ofn.val.f, subtape2, Val(ModifiedBetween), args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...) + res = cuf(ofn.val.f, subtape2, args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...) end return ntuple(Val(length(args0))) do i From 4526b527b09a8ea9ad057c951acb460a45800828 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 19 Sep 2024 18:27:25 -0400 Subject: [PATCH 04/10] fix [only downstream] --- .buildkite/pipeline.yml | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 2b82534b31..472e35fb98 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -221,11 +221,16 @@ steps: # to check compatibility, also add Enzyme to the main environment # (or Pkg.test, which merges both environments, could fail) Pkg.activate(".") + # Try to co-develop Enzyme and KA, if that fails, try just to dev Enzyme try - Pkg.develop("Enzyme") + Pkg.develop([PackageSpec("Enzyme"), PackageSpec("KernelAbstractions")]) catch err - @error "Could not install Enzyme" exception=(err,catch_backtrace()) - exit(3) + try + Pkg.develop([PackageSpec("Enzyme")]) + catch err + @error "Could not install Enzyme" exception=(err,catch_backtrace()) + exit(3) + end end end From 2156063faf7b8918cbc0de21d09ffc01345bbfa9 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 19 Sep 2024 19:14:30 -0400 Subject: [PATCH 05/10] fix fwd --- ext/EnzymeCoreExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/EnzymeCoreExt.jl b/ext/EnzymeCoreExt.jl index f445adc7d5..21ffac12ec 100644 --- a/ext/EnzymeCoreExt.jl +++ b/ext/EnzymeCoreExt.jl @@ -32,7 +32,7 @@ function EnzymeCore.compiler_job_from_backend(::CUDABackend, @nospecialize(F::Ty end function metaf(config, fn, args::Vararg{Any, N}) where N - EnzymeCore.autodiff_deferred(EnzymeCore.set_runtime_activity(Forward, config), fn, Const, args...) + EnzymeCore.autodiff_deferred(EnzymeCore.set_runtime_activity(Forward, config), Const(fn), Const, args...) nothing end From 7cbb095471cdaf63b998a36051ed9f73363c08bc Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 20 Sep 2024 15:45:01 -0500 Subject: [PATCH 06/10] Update EnzymeCoreExt.jl [only downstream] --- ext/EnzymeCoreExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/EnzymeCoreExt.jl b/ext/EnzymeCoreExt.jl index 21ffac12ec..19749022df 100644 --- a/ext/EnzymeCoreExt.jl +++ b/ext/EnzymeCoreExt.jl @@ -320,7 +320,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::EnzymeCore.Annotat T2 = (typeof(config), F, typeof(subtape2), (typeof(a) for a in args)...) TT2 = Tuple{T2...} cuf = cufunction(meta_augf, TT2) - res = cuf(ofn.val.f, subtape2, args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...) + res = cuf(config, ofn.val.f, subtape2, args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...) end return AugmentedReturn{Nothing,Nothing,CuArray}(nothing, nothing, subtape) @@ -378,7 +378,7 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Annotation{CUDA. T2 = (typeof(config), F, typeof(subtape2), (typeof(a) for a in args)...) TT2 = Tuple{T2...} cuf = cufunction(meta_revf, TT2) - res = cuf(ofn.val.f, subtape2, args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...) + res = cuf(config, ofn.val.f, subtape2, args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...) end return ntuple(Val(length(args0))) do i From 2918a80f6df28dac5690826e905db2c602cfca84 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 20 Sep 2024 16:01:57 -0500 Subject: [PATCH 07/10] Update EnzymeCoreExt.jl --- ext/EnzymeCoreExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/EnzymeCoreExt.jl b/ext/EnzymeCoreExt.jl index 19749022df..003394420b 100644 --- a/ext/EnzymeCoreExt.jl +++ b/ext/EnzymeCoreExt.jl @@ -67,12 +67,12 @@ function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(cudaconvert)}, end elseif EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 - ofn.val(x.dval)::eltype(RT) + ofn.val(x.dval)::EnzymeCore.shadow_type(config, RT) else - ntuple(Val(EnzymeRules.width(config))) do i + (ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta ofn.val(x.dval[i])::eltype(RT) - end + end)::EnzymeCore.shadow_type(config, RT) end elseif EnzymeRules.needs_primal(config) ofn.val(uval.val)::eltype(RT) From 64171898580ca0665c967a162e2ab10c441dc258 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 21 Sep 2024 01:32:41 -0400 Subject: [PATCH 08/10] ix [only downstream] --- ext/EnzymeCoreExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/EnzymeCoreExt.jl b/ext/EnzymeCoreExt.jl index 003394420b..51b6b20e38 100644 --- a/ext/EnzymeCoreExt.jl +++ b/ext/EnzymeCoreExt.jl @@ -133,7 +133,7 @@ function EnzymeCore.EnzymeRules.forward(config, ofn::Const{Type{CT}}, if EnzymeRules.width(config) == 1 shadow = ofn.val(uval.val, primargs...)::CT fill!(shadow, 0) - shadow + shadow::shadow_type(config, RT) else tup = ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta @@ -141,7 +141,7 @@ function EnzymeCore.EnzymeRules.forward(config, ofn::Const{Type{CT}}, fill!(shadow, 0) shadow::CT end - tup + tup::shadow_type(config, RT) end elseif EnzymeRules.needs_primal(config) ofn.val(uval.val, primargs...) From caff24d4dfde7397e95fccab7213f15f0f466793 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 21 Sep 2024 02:31:05 -0500 Subject: [PATCH 09/10] Update Project.toml [only downstream] --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 3c4246ca67..c4c834f77f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "CUDA" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "5.5.0" +version = "5.5.1" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" @@ -56,7 +56,7 @@ CUDA_Runtime_jll = "0.15" ChainRulesCore = "1" Crayons = "4" DataFrames = "1" -EnzymeCore = "0.8.1" +EnzymeCore = "0.8.2" ExprTools = "0.1" GPUArrays = "10.0.1" GPUCompiler = "0.24, 0.25, 0.26, 0.27" From f3cb7310d79991c4f712b8731257dc72fc3cae1a Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 21 Sep 2024 13:00:25 -0500 Subject: [PATCH 10/10] Update enzyme.jl [only downstream] --- test/extensions/enzyme.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/extensions/enzyme.jl b/test/extensions/enzyme.jl index 0ad29c91eb..093472ffd3 100644 --- a/test/extensions/enzyme.jl +++ b/test/extensions/enzyme.jl @@ -78,10 +78,10 @@ end alloc(x) = CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}(undef, (x,)) @testset "Forward allocate" begin - dup = Enzyme.autodiff(Forward, alloc, Duplicated, Const(10)) - @test all(dup[2] .≈ 0.0) + dup = Enzyme.autodiff(ForwardWithPrimal, alloc, Duplicated, Const(10)) + @test all(dup[1] .≈ 0.0) - dup = Enzyme.autodiff(Forward, alloc, DuplicatedNoNeed, Const(10)) + dup = Enzyme.autodiff(Forward, alloc, Duplicated, Const(10)) @test all(dup[1] .≈ 0.0) end