From 2a19b7ad17e2f40b19c490d83eb43f4edb8be23a Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 26 Jul 2024 14:52:47 -0700
Subject: [PATCH 01/15] feat: more coverage for common NN activations

---
 ext/ReactantNNlibExt.jl | 39 +++++++++++++++++++++++++++------------
 src/overloads.jl        |  2 ++
 2 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index 7f935141..b09df773 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -3,23 +3,38 @@ module ReactantNNlibExt
 using NNlib
 using Reactant
 
-for (jlop, hloop) in ((:(NNlib.tanh), :tanh), (:(NNlib.tanh_fast), :tanh))
+for (jlop, hloop) in (
+    (:(NNlib.tanh_fast), :tanh),
+    (:(NNlib.sigmoid_fast), :logistic),
+    (:(NNlib.sigmoid), :logistic),
+)
     @eval begin
-        if $jlop != Base.tanh && $jlop != Base.FastMath.tanh_fast
-            function Reactant.elem_apply(
-                ::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N}
-            ) where {ElType,Shape,N}
-                return Reactant.TracedRArray{ElType,Shape,N}(
-                    (),
-                    Reactant.MLIR.IR.result(
-                        Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1
-                    ),
-                )
-            end
+        function Reactant.elem_apply(
+            ::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N}
+        ) where {ElType,Shape,N}
+            return Reactant.TracedRArray{ElType,Shape,N}(
+                (),
+                Reactant.MLIR.IR.result(
+                    Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1
+                ),
+            )
         end
     end
 end
 
+function Reactant.elem_apply(
+    ::typeof(NNlib.relu), lhs::Reactant.TracedRArray{ElType,Shape,N}
+) where {ElType,Shape,N}
+    return (lhs .> zero(ElType)) .* lhs # base case uses ifelse, so we compile the product
+end
+
+function Reactant.elem_apply(
+    ::typeof(NNlib.gelu), lhs::Reactant.TracedRArray{ElType,Shape,N}
+) where {ElType,Shape,N}
+    # See https://arxiv.org/pdf/1606.08415v5 Section 2
+    return lhs .* sigmoid.(ElType(1.702) .* lhs)
+end
+
 # TODO handle non finite cases
 function NNlib.softmax!(
     out::Reactant.TracedRArray{T,Shape,N}, x::AbstractArray; dims=1
diff --git a/src/overloads.jl b/src/overloads.jl
index 43b73f66..593d6e22 100644
--- a/src/overloads.jl
+++ b/src/overloads.jl
@@ -81,6 +81,8 @@ for (jlop, hloop, RT) in (
     end
 end
 
+Base.abs2(x::Reactant.TracedRArray{T,(),0}) where {T} = x * x
+
 function Base.literal_pow(
     ::Base.RefValue{typeof(^)}, x::Reactant.TracedRArray{T,(),0}, ::Base.RefValue{Val{P}}
 ) where {T,P}
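
A note on the gelu lowering in PATCH 01: it uses the sigmoid approximation gelu(x) ≈ x * sigmoid(1.702 * x) from the cited paper (Hendrycks & Gimpel, Section 2). As a rough plain-Julia sanity check, independent of Reactant and not part of the patch, the sigmoid form stays within roughly 1e-2 of NNlib's tanh-based approximation (all helper names below are made up for illustration):

    σ(x) = 1 / (1 + exp(-x))
    gelu_sigmoid(x) = x * σ(1.702f0 * x)   # the form used in the patch
    gelu_tanh(x) = 0.5f0 * x * (1 + tanh(0.7978845608f0 * (x + 0.044715f0 * x^3)))
    xs = -3f0:0.1f0:3f0
    maximum(abs, gelu_sigmoid.(xs) .- gelu_tanh.(xs))   # on the order of 1e-2
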
From e72b02f64a42b5a1b614b12c4d89f0e012085e22 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 26 Jul 2024 15:09:14 -0700
Subject: [PATCH 02/15] feat: support `mean`.

---
 Project.toml     |  8 ++------
 src/Reactant.jl  |  1 +
 src/overloads.jl | 16 ++++++++++++++++
 3 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/Project.toml b/Project.toml
index e0d6f99a..a4e4b0d9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,11 +1,6 @@
 name = "Reactant"
 uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
-authors = [
-    "William Moses ",
-    "Valentin Churavy ",
-    "Sergio Sánchez Ramírez ",
-    "Paul Berg ",
-]
+authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg "]
 version = "0.1.8"
 
 [deps]
@@ -15,6 +10,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
 PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [weakdeps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
diff --git a/src/Reactant.jl b/src/Reactant.jl
index dfca1390..73d4fa39 100644
--- a/src/Reactant.jl
+++ b/src/Reactant.jl
@@ -1,6 +1,7 @@
 module Reactant
 
 using PackageExtensionCompat
+using Statistics: Statistics
 
 include("mlir/MLIR.jl")
 include("XLA.jl")
diff --git a/src/overloads.jl b/src/overloads.jl
index 593d6e22..5d85c638 100644
--- a/src/overloads.jl
+++ b/src/overloads.jl
@@ -130,6 +130,17 @@ for (jlop, hloop, RT) in (
         )
     end
 
+    # Base defines ::AbstractArray / ::Number, so we need this to avoid ambiguity
+    function $jlop(lhs::TracedRArray{ElType,Shape,0}, rhs::Number) where {ElType,Shape}
+        rhs = promote_to(lhs, rhs)
+        return TracedRArray{$RT,Shape,0}(
+            (),
+            MLIR.IR.result(
+                MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
+            ),
+        )
+    end
+
     function $jlop(lhs, rhs::TracedRArray{ElType,Shape,0}) where {ElType,Shape}
         lhs = promote_to(rhs, lhs)
         return TracedRArray{$RT,Shape,0}(
@@ -188,6 +199,11 @@ for (jlop, hloop) in (
     end
 end
 
+function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N}
+    denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
+    return mapreduce(identity, +, A; dims) / denom
+end
+
 function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
     fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(
         f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true
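
The `mean` overload in PATCH 02 reduces with `+` and divides by the product of the sizes along the reduced dimensions. A sketch of that denominator logic on a plain array (ordinary Julia, no tracing involved):

    using Statistics
    A = ones(2, 3, 4)
    dims = (1, 3)
    # same expression as in the overload: the count of elements being averaged
    denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)  # 2 * 4 = 8
    sum(A; dims) ./ denom == mean(A; dims)  # true
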
From d40476754d6abe71f8d358685b44cb9cb1d86c8c Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 26 Jul 2024 15:20:55 -0700
Subject: [PATCH 03/15] feat: support `var`.

---
 src/overloads.jl | 61 +++++++++++++++++++++++++++++++++++-------------
 1 file changed, 45 insertions(+), 16 deletions(-)

diff --git a/src/overloads.jl b/src/overloads.jl
index 5d85c638..4df35bc5 100644
--- a/src/overloads.jl
+++ b/src/overloads.jl
@@ -59,24 +59,43 @@ for (jlop, hloop, RT) in (
         )
     end
 
-    function $jlop(lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N}
-        rhs = promote_to(lhs, rhs)
-        return TracedRArray{$RT,Shape,N}(
+    function $jlop(
+        lhs::TracedRArray{ElType,(),0}, rhs::TracedRArray{ElType,(),0}
+    ) where {ElType}
+        return TracedRArray{$RT,(),0}(
             (),
             MLIR.IR.result(
                 MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
             ),
         )
     end
+    end
 
-    function $jlop(lhs, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
-        lhs = promote_to(rhs, lhs)
-        return TracedRArray{$RT,Shape,N}(
-            (),
-            MLIR.IR.result(
-                MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
-            ),
-        )
+    for otherType in (Number, Any, TracedRArray{S,(),0} where {S})
+        @eval begin
+            function $jlop(
+                lhs::TracedRArray{ElType,Shape,N}, rhs::$otherType
+            ) where {ElType,Shape,N}
+                rhs = promote_to(lhs, rhs)
+                return TracedRArray{$RT,Shape,N}(
+                    (),
+                    MLIR.IR.result(
+                        MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
+                    ),
+                )
+            end
+
+            function $jlop(
+                lhs::$otherType, rhs::TracedRArray{ElType,Shape,N}
+            ) where {ElType,Shape,N}
+                lhs = promote_to(rhs, lhs)
+                return TracedRArray{$RT,Shape,N}(
+                    (),
+                    MLIR.IR.result(
+                        MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
+                    ),
+                )
+            end
         end
     end
 end
@@ -199,11 +218,6 @@ for (jlop, hloop) in (
     end
 end
 
-function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N}
-    denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
-    return mapreduce(identity, +, A; dims) / denom
-end
-
 function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
     fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(
         f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true
@@ -650,3 +664,18 @@ function Base.mapreducedim!(f, op, R::TracedRArray, A::Base.AbstractArrayOrBroad
     R.mlir_data = elem_apply(op, R, tmp).mlir_data
     return R
 end
+
+
+# Stdlib overloads
+## Statistics
+function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N}
+    denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
+    return mapreduce(identity, +, A; dims) / denom
+end
+function Statistics.var(
+    A::TracedRArray{T,Shape,N}; dims=:, mean=nothing, corrected=true
+) where {T,Shape,N}
+    mean === nothing && (mean = Statistics.mean(A; dims))
+    denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected
+    return mapreduce(abs2, +, A .- mean; dims) / denom
+end
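
The `var` overload in PATCH 03 subtracts the (possibly user-supplied) mean, reduces with `abs2`, and divides by n - corrected. The identity it relies on, checked on a plain array (illustration only, not traced code):

    using Statistics
    A = randn(2, 3, 4)
    m = mean(A; dims=1)
    # corrected=true divides by n - 1, the Statistics.var default
    sum(abs2, A .- m; dims=1) ./ (size(A, 1) - 1) ≈ var(A; dims=1)  # true
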
From 03dc542f4f477f7b68addcbceededa7c042087a5 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 26 Jul 2024 16:06:20 -0700
Subject: [PATCH 04/15] feat: add overload for `ifelse`

---
 ext/ReactantNNlibExt.jl |  2 +-
 src/overloads.jl        | 15 ++++++++++++++-
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index b09df773..40fd7819 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -25,7 +25,7 @@ end
 function Reactant.elem_apply(
     ::typeof(NNlib.relu), lhs::Reactant.TracedRArray{ElType,Shape,N}
 ) where {ElType,Shape,N}
-    return (lhs .> zero(ElType)) .* lhs # base case uses ifelse, so we compile the product
+    return ifelse.((lhs .> zero(ElType)), lhs, zero(ElType))
 end
 
 function Reactant.elem_apply(
diff --git a/src/overloads.jl b/src/overloads.jl
index 4df35bc5..80e8f058 100644
--- a/src/overloads.jl
+++ b/src/overloads.jl
@@ -172,6 +172,20 @@ for (jlop, hloop, RT) in (
     end
 end
 
+function elem_apply(
+    ::typeof(Base.ifelse),
+    pred::TracedRArray{Bool,Shape,N},
+    x::TracedRArray{ElType1,Shape,N},
+    y::TracedRArray{ElType2,Shape,N},
+) where {ElType1,ElType2,Shape,N}
+    return TracedRArray{promote_type(ElType1, ElType2),Shape,N}(
+        (),
+        MLIR.IR.result(
+            MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
+        ),
+    )
+end
+
 function Base.:*(
     lhs::TracedRArray{ElType,Shape,2}, rhs::TracedRArray{ElType,Shape2,2}
 ) where {ElType,Shape,Shape2}

From 54d8db673580bb4598ce1df313c6dbcbc5ed07d2 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 26 Jul 2024 16:06:43 -0700
Subject: [PATCH 05/15] chore: relax compat

---
 Project.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/Project.toml b/Project.toml
index a4e4b0d9..c29da595 100644
--- a/Project.toml
+++ b/Project.toml
@@ -31,6 +31,7 @@ NNlib = "0.9"
 PackageExtensionCompat = "1"
 Preferences = "1.4"
 Reactant_jll = "0.0.14"
+Statistics = "1.9"
 julia = "1.9"
 
 [extras]
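
PATCH 04 lowers `ifelse` to `stablehlo.select` and rewrites relu from `(x > 0) * x` to a three-argument select, avoiding the extra multiply. The two forms agree on ordinary arrays; a quick sketch independent of Reactant:

    x = [-1.0, 2.0, -3.0]
    ifelse.(x .> 0, x, zero(eltype(x))) == (x .> 0) .* x  # true
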
From 0032f43541d5eebf65cadd478c7af615b56de71b Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 26 Jul 2024 17:43:57 -0700
Subject: [PATCH 06/15] test: activation functions and their adjoints

---
 test/Project.toml |  1 +
 test/bcast.jl     | 37 ++++++++++++++++++++++++++++++++++++-
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/test/Project.toml b/test/Project.toml
index 10bc878e..6a5bf901 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -5,6 +5,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
diff --git a/test/bcast.jl b/test/bcast.jl
index 9d05200a..bec6facd 100644
--- a/test/bcast.jl
+++ b/test/bcast.jl
@@ -1,6 +1,6 @@
 using Reactant
 
-
+using Enzyme, NNlib
 using Reactant.MLIR
 
 @noinline function no(@nospecialize(x))
@@ -56,3 +56,38 @@ function test()
     end
 end
 test()
+
+@testset "Activation Functions" begin
+    sumabs2(f, x) = sum(abs2, f.(x))
+
+    function ∇sumabs2(f, x)
+        dx = Enzyme.make_zero(x)
+        Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
+        return dx
+    end
+
+    x_act = randn(Float32, 10, 10)
+    x_act_ca = Reactant.ConcreteRArray(x_act)
+
+    @testset "Activation: $act" for act in (
+        identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu
+    )
+        f_compile = Reactant.compile(sumabs2, (act, x_act))
+
+        y_simple = sumabs2(act, x_act)
+        y_compile = f_compile(act, x_act_ca)
+
+        if act !== relu
+            ∂x_enz = Enzyme.make_zero(x_act)
+            Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))
+
+            ∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca))
+
+            ∂x_compile = ∇sumabs2_compiled(act, x_act_ca)
+
+            @test y_simple ≈ y_compile
+        else
+            @test_broken Reactant.compile(∇sumabs2, (act, x_act_ca)) isa Any
+        end
+    end
+end
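
PATCH 06 checks the compiled adjoints against Enzyme run on the plain Julia code. For intuition, the same derivative can be sanity-checked with a central finite difference. A scalar-only sketch; `central_diff` is a hypothetical helper, not part of the test suite:

    # d/dx tanh(x)^2 = 2*tanh(x)*(1 - tanh(x)^2)
    central_diff(f, x; h=1e-5) = (f(x + h) - f(x - h)) / (2h)
    isapprox(central_diff(x -> abs2(tanh(x)), 0.3),
             2 * tanh(0.3) * (1 - tanh(0.3)^2); atol=1e-8)  # true
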
From 3a09389259d89d90526003c68d0b411ffd099113 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 26 Jul 2024 18:07:57 -0700
Subject: [PATCH 07/15] test: `mean` and `var`

---
 src/Reactant.jl |  4 ++--
 test/basic.jl   | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/src/Reactant.jl b/src/Reactant.jl
index 73d4fa39..c8ab789d 100644
--- a/src/Reactant.jl
+++ b/src/Reactant.jl
@@ -57,13 +57,13 @@ function Base.isapprox(x::ConcreteRArray{ElType,(),0}, y; kwargs...) where {ElTy
 end
 
 function Base.isapprox(x, y::ConcreteRArray{ElType,(),0}; kwargs...) where {ElType}
-    return Base.isapprox(to_float(x), y; kwargs...)
+    return Base.isapprox(x, to_float(y); kwargs...)
 end
 
 function Base.isapprox(
     x::ConcreteRArray{ElType,(),0}, y::ConcreteRArray{ElType2,(),0}; kwargs...
 ) where {ElType,ElType2}
-    return Base.isapprox(to_float(x), y; kwargs...)
+    return Base.isapprox(to_float(x), to_float(y); kwargs...)
 end
 
 function Base.print_array(io::IO, X::ConcreteRArray)
diff --git a/test/basic.jl b/test/basic.jl
index 07870de9..199ed4bd 100644
--- a/test/basic.jl
+++ b/test/basic.jl
@@ -152,3 +152,38 @@ end
 
     @test contains(res_repr, "stablehlo.dot_general")
 end
+
+@testset "Statistics: `mean` & `var`" begin
+    x = randn(2, 3, 4)
+    x_ca = Reactant.ConcreteRArray(x)
+
+    mean_fn1(x) = mean(x)
+    mean_fn2(x) = mean(x; dims=1)
+    mean_fn3(x) = mean(x; dims=(1, 2))
+    mean_fn4(x) = mean(x; dims=(1, 3))
+
+    mean_fn1_compiled = Reactant.compile(mean_fn1, (x_ca,))
+    mean_fn2_compiled = Reactant.compile(mean_fn2, (x_ca,))
+    mean_fn3_compiled = Reactant.compile(mean_fn3, (x_ca,))
+    mean_fn4_compiled = Reactant.compile(mean_fn4, (x_ca,))
+
+    @test mean_fn1(x) ≈ mean_fn1_compiled(x_ca)
+    @test mean_fn2(x) ≈ mean_fn2_compiled(x_ca)
+    @test mean_fn3(x) ≈ mean_fn3_compiled(x_ca)
+    @test mean_fn4(x) ≈ mean_fn4_compiled(x_ca)
+
+    var_fn1(x) = var(x)
+    var_fn2(x) = var(x; dims=1)
+    var_fn3(x) = var(x; dims=(1, 2), corrected=false)
+    var_fn4(x) = var(x; dims=(1, 3), corrected=false)
+
+    var_fn1_compiled = Reactant.compile(var_fn1, (x_ca,))
+    var_fn2_compiled = Reactant.compile(var_fn2, (x_ca,))
+    var_fn3_compiled = Reactant.compile(var_fn3, (x_ca,))
+    var_fn4_compiled = Reactant.compile(var_fn4, (x_ca,))
+
+    @test var_fn1(x) ≈ var_fn1_compiled(x_ca)
+    @test var_fn2(x) ≈ var_fn2_compiled(x_ca)
+    @test var_fn3(x) ≈ var_fn3_compiled(x_ca)
+    @test var_fn4(x) ≈ var_fn4_compiled(x_ca)
+end
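
The `isapprox` changes in PATCH 07 fix two mixed-type methods that converted the wrong argument, leaving the `ConcreteRArray` side unconverted. A toy reproduction of the pattern with a hypothetical wrapper type (not Reactant's actual types):

    struct Boxed              # stand-in for a 0-dimensional ConcreteRArray
        x::Float64
    end
    to_float(b::Boxed) = b.x
    # Buggy shape: `to_float(x)` converts the plain number and forwards `y` boxed.
    # Fixed shape, as in the patch: convert the boxed argument itself.
    Base.isapprox(x::Number, y::Boxed; kwargs...) = isapprox(x, to_float(y); kwargs...)
    isapprox(1.0, Boxed(1.0 + 1e-12))  # true
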
From 8a90efaa5fa518b126c6a01c9de0dad253211992 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 26 Jul 2024 18:17:42 -0700
Subject: [PATCH 08/15] test: add BatchNorm to the lux test

---
 test/basic.jl  |  1 +
 test/nn_lux.jl | 11 +++++++----
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/test/basic.jl b/test/basic.jl
index 199ed4bd..20ef425d 100644
--- a/test/basic.jl
+++ b/test/basic.jl
@@ -1,6 +1,7 @@
 using Reactant
 using Test
 using Enzyme
+using Statistics
 
 # Reactant.set_default_backend("gpu")
 
diff --git a/test/nn_lux.jl b/test/nn_lux.jl
index e9bf1b20..3521efc3 100644
--- a/test/nn_lux.jl
+++ b/test/nn_lux.jl
@@ -9,6 +9,7 @@ truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-ele
 # Define our model, a multi-layer perceptron with one hidden layer of size 3:
 model = Lux.Chain(
     Lux.Dense(2 => 3, tanh), # activation function inside layer
+    Lux.BatchNorm(3, gelu),
     Lux.Dense(3 => 2),
     softmax,
 )
@@ -17,8 +18,7 @@ ps, st = Lux.setup(Xoshiro(123), model)
 using BenchmarkTools
 
 origout, _ = model(noisy, ps, st)
-@show origout[3]
-@btime model($noisy, $ps, $st) # 52.731 μs (10 allocations: 32.03 KiB)
+@btime model($noisy, $ps, $st) # 68.444 μs (46 allocations: 45.88 KiB)
 
 cmodel = Reactant.make_tracer(IdDict(), model, (), Reactant.ArrayToConcrete)
 cps = Reactant.make_tracer(IdDict(), ps, (), Reactant.ArrayToConcrete)
@@ -31,8 +31,9 @@ f = Reactant.compile((a, b, c, d) -> first(a(b, c, d)), (cmodel, cnoisy, cps, cs
 # # @show @code_typed f(cmodel,cnoisy)
 # # @show @code_llvm f(cmodel,cnoisy)
 comp = f(cmodel, cnoisy, cps, cst)
-@show comp[3]
-@btime f($cmodel, $cnoisy, $cps, $cst) # 4.430 μs (5 allocations: 160 bytes)
+@btime f($cmodel, $cnoisy, $cps, $cst) # 21.790 μs (6 allocations: 224 bytes)
+
+@test comp ≈ origout atol = 1e-5 rtol = 1e-2
 
 # To train the model, we use batches of 64 samples, and one-hot encoding:
 
@@ -81,6 +82,8 @@ compiled_gradient = Reactant.compile(
     gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst)
 )
 
+@test length(compiled_gradient(cmodel, cnoisy, ctarget, cps, cst)) == 2
+
 # # Training loop, using the whole data set 1000 times:
 # losses = []
 # for epoch in 1:1_000

From a12a95f6012ca9e3555d945ce35112eedefdac06 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 27 Jul 2024 09:52:44 -0700
Subject: [PATCH 09/15] fix: update `relu` and `abs2`

---
 ext/ReactantNNlibExt.jl |  2 +-
 src/overloads.jl        |  2 +-
 test/bcast.jl           | 16 ++++++----------
 3 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index 40fd7819..72e467c9 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -25,7 +25,7 @@ end
 function Reactant.elem_apply(
     ::typeof(NNlib.relu), lhs::Reactant.TracedRArray{ElType,Shape,N}
 ) where {ElType,Shape,N}
-    return ifelse.((lhs .> zero(ElType)), lhs, zero(ElType))
+    return max.(lhs, zero(ElType))
 end
 
 function Reactant.elem_apply(
diff --git a/src/overloads.jl b/src/overloads.jl
index 80e8f058..7fb4f0c8 100644
--- a/src/overloads.jl
+++ b/src/overloads.jl
@@ -100,7 +100,7 @@ for (jlop, hloop, RT) in (
     end
 end
 
-Base.abs2(x::Reactant.TracedRArray{T,(),0}) where {T} = x * x
+Base.abs2(x::Reactant.TracedRArray{T,(),0}) where {T} = x * conj(x)
 
 function Base.literal_pow(
     ::Base.RefValue{typeof(^)}, x::Reactant.TracedRArray{T,(),0}, ::Base.RefValue{Val{P}}
diff --git a/test/bcast.jl b/test/bcast.jl
index bec6facd..f4942e2e 100644
--- a/test/bcast.jl
+++ b/test/bcast.jl
@@ -70,24 +70,20 @@ test()
     x_act_ca = Reactant.ConcreteRArray(x_act)
 
     @testset "Activation: $act" for act in (
-        identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu
+        identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
     )
         f_compile = Reactant.compile(sumabs2, (act, x_act))
 
         y_simple = sumabs2(act, x_act)
         y_compile = f_compile(act, x_act_ca)
 
-        if act !== relu
-            ∂x_enz = Enzyme.make_zero(x_act)
-            Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))
+        ∂x_enz = Enzyme.make_zero(x_act)
+        Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))
 
-            ∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca))
+        ∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca))
 
-            ∂x_compile = ∇sumabs2_compiled(act, x_act_ca)
+        ∂x_compile = ∇sumabs2_compiled(act, x_act_ca)
 
-            @test y_simple ≈ y_compile
-        else
-            @test_broken Reactant.compile(∇sumabs2, (act, x_act_ca)) isa Any
-        end
+        @test y_simple ≈ y_compile
     end
 end
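
The `abs2` change in PATCH 09 swaps `x * x` for `x * conj(x)`, which only differs once complex element types appear. A one-line illustration in plain Julia:

    z = 1.0 + 2.0im
    z * z        # -3.0 + 4.0im, not the squared modulus
    z * conj(z)  # 5.0 + 0.0im, matching abs2(z)
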
From 2a67dd0558b048f8e1ad4d93662cec3d0849c177 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 27 Jul 2024 10:05:40 -0700
Subject: [PATCH 10/15] fix: dispatch directly on `ifelse`

---
 src/overloads.jl | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/src/overloads.jl b/src/overloads.jl
index 7fb4f0c8..42bd9955 100644
--- a/src/overloads.jl
+++ b/src/overloads.jl
@@ -172,13 +172,10 @@ for (jlop, hloop, RT) in (
     end
 end
 
-function elem_apply(
-    ::typeof(Base.ifelse),
-    pred::TracedRArray{Bool,Shape,N},
-    x::TracedRArray{ElType1,Shape,N},
-    y::TracedRArray{ElType2,Shape,N},
-) where {ElType1,ElType2,Shape,N}
-    return TracedRArray{promote_type(ElType1, ElType2),Shape,N}(
+function Base.ifelse(
+    pred::TracedRArray{Bool,(),0}, x::TracedRArray{T1,(),0}, y::TracedRArray{T2,(),0}
+) where {T1,T2}
+    return TracedRArray{promote_type(T1, T2),(),0}(
         (),
         MLIR.IR.result(
             MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1

From 6c5ecaf47b48af4742e31e6c9c9b918a879ac304 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 27 Jul 2024 12:41:41 -0700
Subject: [PATCH 11/15] test: skip Lux tests pre-1.10

---
 test/runtests.jl | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/test/runtests.jl b/test/runtests.jl
index 77ecc063..bf98d3a1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -46,4 +46,7 @@ include("nn.jl")
 include("struct.jl")
 include("closure.jl")
 include("compile.jl")
-include("nn_lux.jl")
+
+if VERSION ≥ v"1.10-" # Lux isn't supported on 1.9
+    include("nn_lux.jl")
+end

From d5d75f3c9f14034bbc726944055b71791256f7e3 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 1 Aug 2024 17:15:02 -0700
Subject: [PATCH 12/15] fix: overload scalar ops

---
 ext/ReactantNNlibExt.jl | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index 72e467c9..ca2e3056 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -22,18 +22,9 @@ for (jlop, hloop) in (
     end
 end
 
-function Reactant.elem_apply(
-    ::typeof(NNlib.relu), lhs::Reactant.TracedRArray{ElType,Shape,N}
-) where {ElType,Shape,N}
-    return max.(lhs, zero(ElType))
-end
+NNlib.relu(x::Reactant.TracedRArray{T,(),0}) where {T} = max(x, zero(T))
 
-function Reactant.elem_apply(
-    ::typeof(NNlib.gelu), lhs::Reactant.TracedRArray{ElType,Shape,N}
-) where {ElType,Shape,N}
-    # See https://arxiv.org/pdf/1606.08415v5 Section 2
-    return lhs .* sigmoid.(ElType(1.702) .* lhs)
-end
+NNlib.gelu(x::Reactant.TracedRArray{T,(),0}) where {T} = x * sigmoid(T(1.702) * x)
 
 # TODO handle non finite cases
 function NNlib.softmax!(
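
Patches 10 and 12 replace whole-array `elem_apply` methods with 0-dimensional (scalar) overloads; broadcasting then lifts the scalar rule over arrays during tracing. A sketch of why a scalar definition suffices, using a hypothetical `myrelu` on ordinary arrays:

    myrelu(x::Real) = max(x, zero(x))   # scalar rule only
    myrelu.([-1.0 2.0; 3.0 -4.0])       # broadcasting supplies the array case
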
From 3d93c667000872df27762e302ecc82e4494465e4 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 4 Aug 2024 14:51:03 -0700
Subject: [PATCH 13/15] refactor: move statistics into extension

---
 Project.toml                 |  3 ++-
 ext/ReactantStatisticsExt.jl | 19 +++++++++++++++++++
 src/Reactant.jl              |  1 -
 src/overloads.jl             | 14 --------------
 4 files changed, 21 insertions(+), 16 deletions(-)
 create mode 100644 ext/ReactantStatisticsExt.jl

diff --git a/Project.toml b/Project.toml
index c29da595..590ac2dc 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,17 +10,18 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
 PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
-Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [weakdeps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [extensions]
 ReactantAdaptExt = "Adapt"
 ReactantArrayInterfaceExt = "ArrayInterface"
 ReactantNNlibExt = "NNlib"
+ReactantStatisticsExt = "Statistics"
 
 [compat]
 Adapt = "4"
diff --git a/ext/ReactantStatisticsExt.jl b/ext/ReactantStatisticsExt.jl
new file mode 100644
index 00000000..2dd813ae
--- /dev/null
+++ b/ext/ReactantStatisticsExt.jl
@@ -0,0 +1,19 @@
+module ReactantStatisticsExt
+
+using Reactant: TracedRArray
+using Statistics: Statistics
+
+function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N}
+    denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
+    return mapreduce(identity, +, A; dims) / denom
+end
+
+function Statistics.var(
+    A::TracedRArray{T,Shape,N}; dims=:, mean=nothing, corrected=true
+) where {T,Shape,N}
+    mean === nothing && (mean = Statistics.mean(A; dims))
+    denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected
+    return mapreduce(abs2, +, A .- mean; dims) / denom
+end
+
+end
diff --git a/src/Reactant.jl b/src/Reactant.jl
index c8ab789d..7c70b8cb 100644
--- a/src/Reactant.jl
+++ b/src/Reactant.jl
@@ -1,7 +1,6 @@
 module Reactant
 
 using PackageExtensionCompat
-using Statistics: Statistics
 
 include("mlir/MLIR.jl")
 include("XLA.jl")
diff --git a/src/overloads.jl b/src/overloads.jl
index 42bd9955..7fb0fb90 100644
--- a/src/overloads.jl
+++ b/src/overloads.jl
@@ -675,17 +675,3 @@ function Base.mapreducedim!(f, op, R::TracedRArray, A::Base.AbstractArrayOrBroad
     R.mlir_data = elem_apply(op, R, tmp).mlir_data
     return R
 end
-
-# Stdlib overloads
-## Statistics
-function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N}
-    denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
-    return mapreduce(identity, +, A; dims) / denom
-end
-function Statistics.var(
-    A::TracedRArray{T,Shape,N}; dims=:, mean=nothing, corrected=true
-) where {T,Shape,N}
-    mean === nothing && (mean = Statistics.mean(A; dims))
-    denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected
-    return mapreduce(abs2, +, A .- mean; dims) / denom
-end

From d0e1dbde338e5eca22df5df0d87eeca01caa2eeb Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 4 Aug 2024 14:54:15 -0700
Subject: [PATCH 14/15] refactor: remove more elem_apply

---
 ext/ReactantNNlibExt.jl | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index ca2e3056..23763e4c 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -8,17 +8,13 @@ for (jlop, hloop) in (
     (:(NNlib.sigmoid_fast), :logistic),
     (:(NNlib.sigmoid), :logistic),
 )
-    @eval begin
-        function Reactant.elem_apply(
-            ::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N}
-        ) where {ElType,Shape,N}
-            return Reactant.TracedRArray{ElType,Shape,N}(
-                (),
-                Reactant.MLIR.IR.result(
-                    Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1
-                ),
-            )
-        end
+    @eval function $(jlop)(x::Reactant.TracedRArray{T,(),0}) where {T}
+        return Reactant.TracedRArray{T,(),0}(
+            (),
+            Reactant.MLIR.IR.result(
+                Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
+            ),
+        )
     end
 end

From 791d1b573e1e765df2943a2fa77ef34fefaa93e5 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 4 Aug 2024 15:38:57 -0700
Subject: [PATCH 15/15] fix: ambiguity error

---
 src/overloads.jl | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/src/overloads.jl b/src/overloads.jl
index 7fb0fb90..7f98627d 100644
--- a/src/overloads.jl
+++ b/src/overloads.jl
@@ -149,6 +149,16 @@ for (jlop, hloop, RT) in (
         )
     end
 
+    function $jlop(lhs, rhs::TracedRArray{ElType,Shape,0}) where {ElType,Shape}
+        lhs = promote_to(rhs, lhs)
+        return TracedRArray{$RT,Shape,0}(
+            (),
+            MLIR.IR.result(
+                MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
+            ),
+        )
+    end
+
     # Base defines ::AbstractArray / ::Number, so we need this to avoid ambiguity
     function $jlop(lhs::TracedRArray{ElType,Shape,0}, rhs::Number) where {ElType,Shape}
         rhs = promote_to(lhs, rhs)
@@ -160,7 +170,7 @@ for (jlop, hloop, RT) in (
         )
     end
 
-    function $jlop(lhs, rhs::TracedRArray{ElType,Shape,0}) where {ElType,Shape}
+    function $jlop(lhs::Number, rhs::TracedRArray{ElType,Shape,0}) where {ElType,Shape}
         lhs = promote_to(rhs, lhs)
         return TracedRArray{$RT,Shape,0}(
             (),
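
PATCH 15 resolves a method ambiguity by giving the mixed scalar/array methods concrete `Number` signatures. A toy version of the kind of ambiguity involved, built on a hypothetical wrapper type rather than `TracedRArray`:

    struct Wrap
        x::Float64
    end
    import Base: +
    +(a::Wrap, b) = Wrap(a.x + b)   # catch-all on the right
    +(a, b::Wrap) = Wrap(a + b.x)   # catch-all on the left
    # `Wrap(1.0) + Wrap(2.0)` is now ambiguous: neither method is more specific.
    # The cure is an exact two-sided method:
    +(a::Wrap, b::Wrap) = Wrap(a.x + b.x)
    Wrap(1.0) + Wrap(2.0)           # Wrap(3.0)
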