Fix swiglu backwards return type (Dao-AILab#1337)
ntenenz authored Nov 16, 2024
1 parent 641db75 commit 7153673
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion flash_attn/ops/activations.py
@@ -110,7 +110,7 @@ def sqrelu_bwd(g, x):
 }
 """
 swiglu_bwd_codestring = """
-template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
+template <typename T> void swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
     float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
     dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
     dy = float(x) * x_sigmoid * float(g);
