Skip to content

Commit

Permalink
Clear dq before launch kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
poyenc committed Dec 27, 2024
1 parent 12cd298 commit e3dd0e2
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions csrc/flash_attn_ck/mha_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
dv_expanded = dv;
}

if(head_size == 64) {
dq.zero_();
}

auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());

Expand Down

0 comments on commit e3dd0e2

Please sign in to comment.