diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 72e467c9..ca2e3056 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -22,18 +22,9 @@ for (jlop, hloop) in ( end end -function Reactant.elem_apply( - ::typeof(NNlib.relu), lhs::Reactant.TracedRArray{ElType,Shape,N} -) where {ElType,Shape,N} - return max.(lhs, zero(ElType)) -end +NNlib.relu(x::Reactant.TracedRArray{T,(),0}) where {T} = max(x, zero(T)) -function Reactant.elem_apply( - ::typeof(NNlib.gelu), lhs::Reactant.TracedRArray{ElType,Shape,N} -) where {ElType,Shape,N} - # See https://arxiv.org/pdf/1606.08415v5 Section 2 - return lhs .* sigmoid.(ElType(1.702) .* lhs) -end +NNlib.gelu(x::Reactant.TracedRArray{T,(),0}) where {T} = x * sigmoid(T(1.702) * x) # TODO handle non finite cases function NNlib.softmax!(