From 530f4079351fa0b88349dcf3e0cea84888999790 Mon Sep 17 00:00:00 2001 From: Clint Greene Date: Wed, 29 May 2024 16:18:14 -0500 Subject: [PATCH] Fix stride issue in flash_attn_interface --- flash_attn/flash_attn_interface.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 4f55f0c57..b97d0a1af 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -60,9 +60,15 @@ def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, rng_state=None): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + maybe_contiguous = lambda x: x.contiguous() if not x.is_contiguous() else x # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + + if out.stride() != dout.stride(): + out = out.as_strided(dout.size(),dout.stride()) + if dq.stride() != q.stride(): + dq = dq.as_strided(q.size(),q.stride()) + dq, dk, dv, softmax_d, = flash_attn_cuda.bwd( dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None, rng_state @@ -73,7 +79,7 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, rng_state=None): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + maybe_contiguous = lambda x: x.contiguous() if not x.is_contiguous() else x # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd( @@ -232,7 +238,7 @@ def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): @staticmethod def backward(ctx, dout, *args): q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + dq, dk, dv = torch.empty_strided(q.size(),q.stride(), dtype=q.dtype, device=q.device), torch.empty_strided(k.size(), k.stride(), dtype=k.dtype, device=k.device), torch.empty_strided(v.size(), v.stride(), dtype=v.dtype, device=v.device) _flash_attn_backward( dout, q, k, v, out, softmax_lse, dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal,