Skip to content

Commit

Permalink
Add the option for the macro and note (Dao-AILab#893)
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg authored Mar 28, 2024
1 parent 3e9414f commit 23e8fa5
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion csrc/flash_attn/src/softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,14 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
// The following macro will disable the use of fma.
// See: https://github.com/pytorch/pytorch/issues/121558 for more details
// This macro is set in PyTorch and not FlashAttention
#ifdef UNFUSE_FMA
tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
#else
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
#endif
}
}
}
Expand Down

0 comments on commit 23e8fa5

Please sign in to comment.