set fall-back config to 64x64
micmelesse committed Aug 20, 2024
1 parent 9c0864d commit db3644d
Showing 2 changed files with 14 additions and 20 deletions.
23 changes: 8 additions & 15 deletions flash_attn/flash_attn_triton_kernel_prefill_amd.py
@@ -334,28 +334,21 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri

@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1,
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=8),
# TODO: This config fails with head_size not pow2 with data mismatches. Check why.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
# Fall-back config.
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
# triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
# num_warps=4),
],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
key=['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'],
use_cuda_graph=True,
)
@triton.jit
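What the kernel hunk changes: triton.autotune benchmarks every listed triton.Config the first time the kernel is launched and then caches the winner per distinct combination of the values named in `key`. This commit keeps an explicit 64x64 "Fall-back config." entry and widens the key from ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'] to one that also includes the max sequence lengths, actual head dimension, varlen mode, and Q/K head counts, so a config tuned for one problem shape is not silently reused for a very different one. Below is a minimal, self-contained sketch of that autotune pattern on a toy kernel, assuming a CUDA/ROCm device; the kernel body, argument names, and block sizes are illustrative only, not the repository's attention kernel.

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=1, num_warps=4),
        # Fall-back config: conservative 64x64 tiles that remain valid for any problem size.
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_stages=1, num_warps=4),
    ],
    # The winning config is benchmarked once and cached per distinct (M, N) pair;
    # a new value of either argument triggers re-tuning.
    key=['M', 'N'],
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    offs = offs_m[:, None] * N + offs_n[None, :]
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, 2.0 * x, mask=mask)

x = torch.randn(300, 200, device='cuda')
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(300, meta['BLOCK_M']), triton.cdiv(200, meta['BLOCK_N']))
scale_kernel[grid](x, out, 300, 200)

AMD-specific knobs such as 'waves_per_eu' and 'PRE_LOAD_V' from the hunk above are simply extra entries in the Config dictionary and are left out of this sketch for portability.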
11 changes: 6 additions & 5 deletions tests/test_flash_attn.py
@@ -43,7 +43,8 @@ def skip_config(*args, reproducible=True, skip_pct = 0.95):
else:
skip_seed = time.time()

print("skip_seed:", skip_seed)
if DEBUG:
print("skip_seed:", skip_seed)
random.seed(config_str)


@@ -1049,8 +1050,8 @@ def test_flash_attn_output(
if softcap != 0.0:
pytest.skip("softcap not supported on AMD yet")

if causal == True:
pytest.skip("causal not supported on AMD yet")
# if causal == True:
# pytest.skip("causal not supported on AMD yet")

if local == True:
pytest.skip("local sliding window attention not supported on AMD yet")
@@ -1380,8 +1381,8 @@ def test_flash_attn_varlen_output(
if softcap != 0.0:
pytest.skip("softcap not supported on AMD yet")

if causal == True:
pytest.skip("causal not supported on AMD yet")
# if causal == True:
# pytest.skip("causal not supported on AMD yet")

if test_backward == True:
pytest.skip("Backward Attention not supported on AMD yet")
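On the test side, the diff gates the skip_seed print behind DEBUG and re-enables the causal cases by commenting out the unconditional pytest.skip calls. The helper it touches, skip_config, exists to prune a large parametrize grid reproducibly (the real helper seeds the module-level random with the config string). A minimal sketch of that idea follows, assuming an environment-variable DEBUG flag and a simplified, local RNG; the flag name and the commented call site are hypothetical.

import os
import random
import time

DEBUG = os.environ.get("DEBUG", "0") == "1"  # assumed flag; the repo defines its own DEBUG constant

def skip_config(*args, reproducible=True, skip_pct=0.95):
    """Return True for roughly skip_pct of parameter combinations.

    With reproducible=True the decision depends only on the arguments, so the
    same subset of the test grid is skipped on every run; otherwise it varies
    from run to run.
    """
    config_str = "_".join(str(a) for a in args)
    if reproducible:
        skip_seed = config_str
    else:
        skip_seed = time.time()
    if DEBUG:
        print("skip_seed:", skip_seed)
    rng = random.Random(skip_seed)
    return rng.random() < skip_pct

# Typical use inside a parametrized test (hypothetical call site):
#   if skip_config(dtype, d, seqlen_q, seqlen_k, causal):
#       pytest.skip("pruned to keep the sweep short")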
