Skip to content

Commit

Permalink
bad varlen config
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Dec 3, 2024
1 parent c0e7d31 commit b577610
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 30 deletions.
7 changes: 4 additions & 3 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,9 +1017,10 @@ def attention_prefill_backward_triton_impl(
print("dk:", dk, dk.shape)
print("dq:", dq, dq.shape)
print("copy_back:", copy_back)
print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None)
print("dropout_fraction bwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item())
write_dropout_mask(dropout_mask, "dropout_mask_bwd")
if use_dropout:
print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None)
print("dropout_fraction bwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item())
write_dropout_mask(dropout_mask, "dropout_mask_bwd")

if copy_back["dq"]:
dq_og.copy_(dq)
Expand Down
11 changes: 6 additions & 5 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,12 +574,12 @@ def attention_prefill_forward_triton_impl(
# to give a consistent starting point and then populate it with the output of softmax with the sign bit set according
# to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing
# only. This return holds no useful output aside from debugging.
if return_softmax:
use_dropout = (dropout_p > 0.0)
if use_dropout or return_softmax:
sd_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
dtype=torch.float32)
dropout_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
dtype=torch.float32)

scores_strides = (sd_mask.stride(0), sd_mask.stride(1), sd_mask.stride(2), sd_mask.stride(3))
else:
sd_mask = None
Expand Down Expand Up @@ -622,8 +622,9 @@ def attention_prefill_forward_triton_impl(
print("o:", o, o.shape)
print("softmax_lse:", softmax_lse, softmax_lse.shape)
print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None)
print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None)
print("dropout_fraction fwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item())
write_dropout_mask(dropout_mask, "dropout_mask_fwd")
if use_dropout:
print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None)
print("dropout_fraction fwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item())
write_dropout_mask(dropout_mask, "dropout_mask_fwd")

return o, softmax_lse, sd_mask.to(o.dtype) if return_softmax else None
45 changes: 23 additions & 22 deletions tests/test_flash_attn_triton_amd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,8 +1195,8 @@ def test_flash_attn_output(
# @pytest.mark.parametrize('kvpacked', [False])
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mha"])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize('mha_type', ["mha"])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("alibi", [False, True])
Expand All @@ -1205,28 +1205,29 @@ def test_flash_attn_output(
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32])
@pytest.mark.parametrize('d', [32])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 147),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
(32, 32),
# (1, 147),
# (113, 203),
# (128, 217),
# (113, 211),
# (108, 256),
# (256, 512),
# (512, 256),
# (1024, 1024),
# (1023, 1024),
# (1024, 1023),
# (2048, 2048),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.17])
# @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize('dropout_p', [0.17])
# @pytest.mark.parametrize("softcap", [0.0, 50.0])
@pytest.mark.parametrize("softcap", [0.0])
def test_flash_attn_varlen_output(
Expand Down Expand Up @@ -1254,20 +1255,20 @@ def test_flash_attn_varlen_output(
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
q = torch.ones(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if softcap > 0:
# Ensure the values of qk are at least within softcap range.
q = q * softcap

if kvpacked:
kv = torch.randn(
kv = torch.ones(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
else:
k = torch.randn(
k = torch.ones(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
v = torch.randn(
v = torch.ones(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)

Expand Down Expand Up @@ -1457,7 +1458,7 @@ def test_flash_attn_varlen_output(
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

g = torch.randn_like(out)
g = torch.ones_like(out)
if ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)):
if kvpacked:
(
Expand Down

0 comments on commit b577610

Please sign in to comment.