diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp index 5742e0f10..f2f9afe08 100644 --- a/csrc/flash_attn_ck/mha_bwd.cpp +++ b/csrc/flash_attn_ck/mha_bwd.cpp @@ -338,6 +338,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num dv_expanded = dv; } + /* NOTE(review): dq is zero-filled in place only when head_size == 64; this hunk does not show why. Presumably the backward path used for that head size accumulates into dq instead of overwriting it -- TODO confirm and state the reason here. */ if(head_size == 64) { + dq.zero_(); + } + auto gen = at::get_generator_or_default( gen_, at::cuda::detail::getDefaultCUDAGenerator());