Skip to content

Commit

Permalink
fix: overload scalar ops
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 2, 2024
1 parent 2c65ae7 commit 343a45d
Showing 1 changed file with 2 additions and 11 deletions.
13 changes: 2 additions & 11 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,9 @@ for (jlop, hloop) in (
end
end

function Reactant.elem_apply(
::typeof(NNlib.relu), lhs::Reactant.TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
return max.(lhs, zero(ElType))
end
NNlib.relu(x::Reactant.TracedRArray{T,(),0}) where {T} = max(x, zero(T))

function Reactant.elem_apply(
::typeof(NNlib.gelu), lhs::Reactant.TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
# See https://arxiv.org/pdf/1606.08415v5 Section 2
return lhs .* sigmoid.(ElType(1.702) .* lhs)
end
NNlib.gelu(x::Reactant.TracedRArray{T,(),0}) where {T} = x * sigmoid(T(1.702) * x)

# TODO handle non finite cases
function NNlib.softmax!(
Expand Down

0 comments on commit 343a45d

Please sign in to comment.