feat: more coverage for common NN operations #55

Merged
15 commits merged on Aug 5, 2024
10 changes: 4 additions & 6 deletions Project.toml
@@ -1,11 +1,6 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = [
"William Moses <wmoses@mit.edu>",
"Valentin Churavy <vchuravy@mit.edu>",
"Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>",
"Paul Berg <paul@plutojl.org>",
]
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>"]
version = "0.1.8"

[deps]
@@ -20,11 +15,13 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[extensions]
ReactantAdaptExt = "Adapt"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantNNlibExt = "NNlib"
ReactantStatisticsExt = "Statistics"

[compat]
Adapt = "4"
@@ -35,6 +32,7 @@ NNlib = "0.9"
PackageExtensionCompat = "1"
Preferences = "1.4"
Reactant_jll = "0.0.14"
Statistics = "1.9"
julia = "1.9"

[extras]
30 changes: 16 additions & 14 deletions ext/ReactantNNlibExt.jl
@@ -3,23 +3,25 @@ module ReactantNNlibExt
using NNlib
using Reactant

for (jlop, hloop) in ((:(NNlib.tanh), :tanh), (:(NNlib.tanh_fast), :tanh))
@eval begin
if $jlop != Base.tanh && $jlop != Base.FastMath.tanh_fast
function Reactant.elem_apply(
::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
return Reactant.TracedRArray{ElType,Shape,N}(
(),
Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1
),
)
end
end
for (jlop, hloop) in (
(:(NNlib.tanh_fast), :tanh),
(:(NNlib.sigmoid_fast), :logistic),
(:(NNlib.sigmoid), :logistic),
)
@eval function $(jlop)(x::Reactant.TracedRArray{T,(),0}) where {T}
return Reactant.TracedRArray{T,(),0}(
(),
Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
),
)
end
end

NNlib.relu(x::Reactant.TracedRArray{T,(),0}) where {T} = max(x, zero(T))

NNlib.gelu(x::Reactant.TracedRArray{T,(),0}) where {T} = x * sigmoid(T(1.702) * x)

# TODO handle non finite cases
function NNlib.softmax!(
out::Reactant.TracedRArray{T,Shape,N}, x::AbstractArray; dims=1
19 changes: 19 additions & 0 deletions ext/ReactantStatisticsExt.jl
@@ -0,0 +1,19 @@
module ReactantStatisticsExt

using Reactant: TracedRArray
using Statistics: Statistics

function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N}
denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
return mapreduce(identity, +, A; dims) / denom
end

function Statistics.var(
A::TracedRArray{T,Shape,N}; dims=:, mean=nothing, corrected=true
) where {T,Shape,N}
mean === nothing && (mean = Statistics.mean(A; dims))
denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected
return mapreduce(abs2, +, A .- mean; dims) / denom
end

end
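
A minimal usage sketch for this extension, mirroring the tests added in test/basic.jl below; the helper names are illustrative, and the Reactant.compile / ConcreteRArray calls are the ones exercised by this PR's tests.

# Usage sketch (not part of the extension file itself).
using Reactant, Statistics

x = randn(2, 3, 4)
x_ra = Reactant.ConcreteRArray(x)                 # wrap the data for tracing

mean_dims(x) = mean(x; dims=(1, 3))               # dispatches to the Statistics.mean overload above
mean_dims_compiled = Reactant.compile(mean_dims, (x_ra,))

mean_dims_compiled(x_ra) ≈ mean_dims(x)           # expected to match plain Julia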
4 changes: 2 additions & 2 deletions src/Reactant.jl
@@ -56,13 +56,13 @@ function Base.isapprox(x::ConcreteRArray{ElType,(),0}, y; kwargs...) where {ElTy
end

function Base.isapprox(x, y::ConcreteRArray{ElType,(),0}; kwargs...) where {ElType}
return Base.isapprox(to_float(x), y; kwargs...)
return Base.isapprox(x, to_float(y); kwargs...)
end

function Base.isapprox(
x::ConcreteRArray{ElType,(),0}, y::ConcreteRArray{ElType2,(),0}; kwargs...
) where {ElType,ElType2}
return Base.isapprox(to_float(x), y; kwargs...)
return Base.isapprox(to_float(x), to_float(y); kwargs...)
end

function Base.print_array(io::IO, X::ConcreteRArray)
75 changes: 64 additions & 11 deletions src/overloads.jl
@@ -59,28 +59,49 @@ for (jlop, hloop, RT) in (
)
end

function $jlop(lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,Shape,N}(
function $jlop(
lhs::TracedRArray{ElType,(),0}, rhs::TracedRArray{ElType,(),0}
) where {ElType}
return TracedRArray{$RT,(),0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end
end

function $jlop(lhs, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,Shape,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
for otherType in (Number, Any, TracedRArray{S,(),0} where {S})
@eval begin
function $jlop(
lhs::TracedRArray{ElType,Shape,N}, rhs::$otherType
) where {ElType,Shape,N}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,Shape,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end

function $jlop(
lhs::$otherType, rhs::TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,Shape,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end
end
end
end

Base.abs2(x::Reactant.TracedRArray{T,(),0}) where {T} = x * conj(x)

function Base.literal_pow(
::Base.RefValue{typeof(^)}, x::Reactant.TracedRArray{T,(),0}, ::Base.RefValue{Val{P}}
) where {T,P}
@@ -137,9 +158,41 @@ for (jlop, hloop, RT) in (
),
)
end

# Base defines ::AbstractArray / ::Number, so we need this to avoid ambiguity
function $jlop(lhs::TracedRArray{ElType,Shape,0}, rhs::Number) where {ElType,Shape}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,Shape,0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end

function $jlop(lhs::Number, rhs::TracedRArray{ElType,Shape,0}) where {ElType,Shape}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,Shape,0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end
end
end

function Base.ifelse(
pred::TracedRArray{Bool,(),0}, x::TracedRArray{T1,(),0}, y::TracedRArray{T2,(),0}

Review thread anchored on this ifelse signature:

Member: Can we make this generalize to any shape/size, not just 0?

Collaborator (author): Won't the broadcasting handle the shape automatically?

Member: Yes, but someone could also call ifelse(true, ones(4,4), zeros(4,4)) or ifelse(trues(4,4), ones(4,4), zeros(4,4)), etc., outside a broadcast [though yes, the 0-dim one will generalize to anything in a broadcast].

Member: Though I don't think the latter case is legal in Julia at the moment, so just generalizing to ifelse(true, ones(4,4), zeros(4,4)) probably makes sense. (A sketch of such a generalization appears after this ifelse method, below.)

Collaborator (author): Something is wrong with this version I defined:

julia> f(x) = ifelse.(true, x, x)
f (generic function with 1 method)

julia> Reactant.@code_hlo optimize=false f(x)
Module:
module {
  func.func private @ifelse_broadcast_scalar(%arg0: tensor<i1>, %arg1: tensor<f64>) -> (tensor<i1>, tensor<f64>, tensor<f64>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<i1>) -> tensor<i1>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f64>) -> tensor<f64>
    %2 = stablehlo.select %0, %1, %1 : tensor<i1>, tensor<f64>
    %3 = stablehlo.transpose %0, dims = [] : (tensor<i1>) -> tensor<i1>
    %4 = stablehlo.transpose %1, dims = [] : (tensor<f64>) -> tensor<f64>
    %5 = stablehlo.transpose %2, dims = [] : (tensor<f64>) -> tensor<f64>
    return %3, %4, %5 : tensor<i1>, tensor<f64>, tensor<f64>
  }
  func.func @main(%arg0: tensor<3x2xf64>) -> (tensor<3x2xf64>, tensor<3x2xf64>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x2xf64>) -> tensor<2x3xf64>
    %c = stablehlo.constant dense<true> : tensor<2x3xi1>
    %1:3 = enzyme.batch @ifelse_broadcast_scalar(%c, %0) {batch_shape = array<i64: 2, 3>} : (tensor<2x3xi1>, tensor<2x3xf64>) -> (tensor<2x3xi1>, tensor<2x3xf64>, tensor<2x3xf64>)
    %2 = stablehlo.transpose %1#2, dims = [1, 0] : (tensor<2x3xf64>) -> tensor<3x2xf64>
    %3 = stablehlo.transpose %1#1, dims = [1, 0] : (tensor<2x3xf64>) -> tensor<3x2xf64>
    return %2, %3 : tensor<3x2xf64>, tensor<3x2xf64>
  }
}

julia> Reactant.@code_hlo f(x)
Module:
module attributes {transform.with_named_sequence} {
  func.func @main(%arg0: tensor<3x2xf64>) {
    return
  }
}

) where {T1,T2}
return TracedRArray{promote_type(T1, T2),(),0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
),
)
end
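
A shape-general variant was discussed in the review thread above. The sketch below reuses this file's existing pattern (stablehlo.select also accepts a rank-0 i1 predicate alongside same-shaped operands) and restricts both operands to a shared element type to sidestep promotion; it is illustrative only and not part of the merged diff.

# Hedged sketch (not in this PR): ifelse generalized to same-shaped traced operands,
# keeping a traced rank-0 Bool predicate and a single element type T.
function Base.ifelse(
    pred::TracedRArray{Bool,(),0}, x::TracedRArray{T,Shape,N}, y::TracedRArray{T,Shape,N}
) where {T,Shape,N}
    return TracedRArray{T,Shape,N}(
        (),
        MLIR.IR.result(
            MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
        ),
    )
end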

function Base.:*(
lhs::TracedRArray{ElType,Shape,2}, rhs::TracedRArray{ElType,Shape2,2}
) where {ElType,Shape,Shape2}
1 change: 1 addition & 0 deletions test/Project.toml
@@ -5,6 +5,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
36 changes: 36 additions & 0 deletions test/basic.jl
@@ -1,6 +1,7 @@
using Reactant
using Test
using Enzyme
using Statistics

# Reactant.set_default_backend("gpu")

@@ -152,3 +153,38 @@ end

@test contains(res_repr, "stablehlo.dot_general")
end

@testset "Statistics: `mean` & `var`" begin
x = randn(2, 3, 4)
x_ca = Reactant.ConcreteRArray(x)

mean_fn1(x) = mean(x)
mean_fn2(x) = mean(x; dims=1)
mean_fn3(x) = mean(x; dims=(1, 2))
mean_fn4(x) = mean(x; dims=(1, 3))

mean_fn1_compiled = Reactant.compile(mean_fn1, (x_ca,))
mean_fn2_compiled = Reactant.compile(mean_fn2, (x_ca,))
mean_fn3_compiled = Reactant.compile(mean_fn3, (x_ca,))
mean_fn4_compiled = Reactant.compile(mean_fn4, (x_ca,))

@test mean_fn1(x) ≈ mean_fn1_compiled(x_ca)
@test mean_fn2(x) ≈ mean_fn2_compiled(x_ca)
@test mean_fn3(x) ≈ mean_fn3_compiled(x_ca)
@test mean_fn4(x) ≈ mean_fn4_compiled(x_ca)

var_fn1(x) = var(x)
var_fn2(x) = var(x; dims=1)
var_fn3(x) = var(x; dims=(1, 2), corrected=false)
var_fn4(x) = var(x; dims=(1, 3), corrected=false)

var_fn1_compiled = Reactant.compile(var_fn1, (x_ca,))
var_fn2_compiled = Reactant.compile(var_fn2, (x_ca,))
var_fn3_compiled = Reactant.compile(var_fn3, (x_ca,))
var_fn4_compiled = Reactant.compile(var_fn4, (x_ca,))

@test var_fn1(x) ≈ var_fn1_compiled(x_ca)
@test var_fn2(x) ≈ var_fn2_compiled(x_ca)
@test var_fn3(x) ≈ var_fn3_compiled(x_ca)
@test var_fn4(x) ≈ var_fn4_compiled(x_ca)
end
33 changes: 32 additions & 1 deletion test/bcast.jl
@@ -1,6 +1,6 @@

using Reactant

using Enzyme, NNlib
using Reactant.MLIR

@noinline function no(@nospecialize(x))
@@ -56,3 +56,34 @@ function test()
end
end
test()

@testset "Activation Functions" begin
sumabs2(f, x) = sum(abs2, f.(x))

function ∇sumabs2(f, x)
dx = Enzyme.make_zero(x)
Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
return dx
end

x_act = randn(Float32, 10, 10)
x_act_ca = Reactant.ConcreteRArray(x_act)

@testset "Activation: $act" for act in (
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
)
f_compile = Reactant.compile(sumabs2, (act, x_act))

y_simple = sumabs2(act, x_act)
y_compile = f_compile(act, x_act_ca)

∂x_enz = Enzyme.make_zero(x_act)
Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))

∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca))

∂x_compile = ∇sumabs2_compiled(act, x_act_ca)

@test y_simple ≈ y_compile
end
end
11 changes: 7 additions & 4 deletions test/nn_lux.jl
@@ -9,6 +9,7 @@ truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-ele
# Define our model, a multi-layer perceptron with one hidden layer of size 3:
model = Lux.Chain(
Lux.Dense(2 => 3, tanh), # activation function inside layer
Lux.BatchNorm(3, gelu),
Lux.Dense(3 => 2),
softmax,
)
@@ -17,8 +18,7 @@ ps, st = Lux.setup(Xoshiro(123), model)
using BenchmarkTools

origout, _ = model(noisy, ps, st)
@show origout[3]
@btime model($noisy, $ps, $st) # 52.731 μs (10 allocations: 32.03 KiB)
@btime model($noisy, $ps, $st) # 68.444 μs (46 allocations: 45.88 KiB)

cmodel = Reactant.make_tracer(IdDict(), model, (), Reactant.ArrayToConcrete)
cps = Reactant.make_tracer(IdDict(), ps, (), Reactant.ArrayToConcrete)
@@ -31,8 +31,9 @@ f = Reactant.compile((a, b, c, d) -> first(a(b, c, d)), (cmodel, cnoisy, cps, cs
# # @show @code_typed f(cmodel,cnoisy)
# # @show @code_llvm f(cmodel,cnoisy)
comp = f(cmodel, cnoisy, cps, cst)
@show comp[3]
@btime f($cmodel, $cnoisy, $cps, $cst) # 4.430 μs (5 allocations: 160 bytes)
@btime f($cmodel, $cnoisy, $cps, $cst) # 21.790 μs (6 allocations: 224 bytes)

@test comp ≈ origout atol = 1e-5 rtol = 1e-2

# To train the model, we use batches of 64 samples, and one-hot encoding:

Expand Down Expand Up @@ -81,6 +82,8 @@ compiled_gradient = Reactant.compile(
gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst)
)

@test length(compiled_gradient(cmodel, cnoisy, ctarget, cps, cst)) == 2

# # Training loop, using the whole data set 1000 times:
# losses = []
# for epoch in 1:1_000
5 changes: 4 additions & 1 deletion test/runtests.jl
@@ -46,4 +46,7 @@ include("nn.jl")
include("struct.jl")
include("closure.jl")
include("compile.jl")
include("nn_lux.jl")

if VERSION ≥ v"1.10-" # Lux isn't supported on 1.9
include("nn_lux.jl")
end