From c57b3d0a6e8900af01a6af57978ed425137b866c Mon Sep 17 00:00:00 2001
From: Michael Melesse
Date: Wed, 4 Dec 2024 18:09:12 -0500
Subject: [PATCH] add RDNA CI (#105)

* Add RDNA CI

This is a combination of 4 commits.

try navi

try matrix

small change

try minimal change

* limit navi tests

* stop casting to fp32 which leads to oom on navi

* enable all causal

* revert all causal

* skip compiler bug on navi
---
 .github/workflows/amd_tests.yml                 |  8 ++++++--
 flash_attn/flash_attn_triton_amd/bwd_prefill.py | 11 ++++++-----
 tests/test_flash_attn_triton_amd.py             |  7 +++++--
 3 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml
index 4cab3ca5d..f8c81d92e 100644
--- a/.github/workflows/amd_tests.yml
+++ b/.github/workflows/amd_tests.yml
@@ -27,7 +27,7 @@ jobs:
         id: set-matrix
         run: |
           if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then
-            echo '::set-output name=matrix-HIP::[["self-hosted", "rocm"]]'
+            echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx1100"]]'
           else
             echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
           fi
@@ -59,6 +59,7 @@ jobs:
           export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
           python setup.py install
       - name: Flash Attention Tests Using Reference Impl
+        if: matrix.runner[1] == 'gfx90a'
         run: |
           export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
           export FLASH_ATTENTION_TRITON_AMD_REF=1
@@ -68,15 +69,18 @@ jobs:
           export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
           pytest tests/test_flash_attn_triton_amd.py
       - name: AMD Tests
+        if: matrix.runner[1] == 'gfx90a'
         run: |
           export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
           pytest -v -s flash_attn/flash_attn_triton_amd/test.py
       - name: AMD Bench
+        if: matrix.runner[1] == 'gfx90a'
         run: |
           export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
           python flash_attn/flash_attn_triton_amd/bench.py
       - name: AMD Bench with Autotune
+        if: matrix.runner[1] == 'gfx90a'
         run: |
           export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
           export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1
-          python flash_attn/flash_attn_triton_amd/bench.py
\ No newline at end of file
+          python flash_attn/flash_attn_triton_amd/bench.py
diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill.py b/flash_attn/flash_attn_triton_amd/bwd_prefill.py
index 286268a0c..66ab91e21 100644
--- a/flash_attn/flash_attn_triton_amd/bwd_prefill.py
+++ b/flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -149,8 +149,8 @@ def _bwd_kernel_one_col_block(
     # load k and v once per column block
     k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
     v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
-    k = tl.load(k_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
-    v = tl.load(v_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
+    k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
+    v = tl.load(v_ptrs, mask=kv_mask, other=0.0)

     # loop over rows
     for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M):
@@ -164,8 +164,8 @@ def _bwd_kernel_one_col_block(
         q_mask = mask_m[:, None] & mask_d[None, :]

         # load q, k, v, do on-chip
-        q = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.float32)
-        do = tl.load(do_ptrs, mask=q_mask, other=0.0).to(tl.float32)
+        q = tl.load(q_ptrs, mask=q_mask, other=0.0)
+        do = tl.load(do_ptrs, mask=q_mask, other=0.0)

         # recompute p = softmax(qk, dim=-1).T
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -192,7 +192,7 @@ def _bwd_kernel_one_col_block(
         # mask block in the cases where the data is smaller the block size
         p_mask = mask_m[:, None] & mask_n[None, :]
         p = tl.where(p_mask, p, 0.0)
-        p = p.to(tl.float32)
+        p = p.to(tl.float16)

         # compute dv
         dv += tl.dot(tl.trans(p), do)
@@ -205,6 +205,7 @@ def _bwd_kernel_one_col_block(
         Di = tl.load(d_ptrs, mask=mask_m)
         ds = (p * (dp - Di[:, None])) * sm_scale
         ds = tl.where(p_mask, ds, 0.0)
+        ds = ds.to(tl.float16)

         # compute dk = dot(ds.T, q)
         dk += tl.dot(tl.trans(ds), q)
diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py
index be3ad6f8e..fa19ac4d6 100644
--- a/tests/test_flash_attn_triton_amd.py
+++ b/tests/test_flash_attn_triton_amd.py
@@ -18,7 +18,7 @@
 from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import _get_block_size_n
 from flash_attn.layers.rotary import apply_rotary_emb
-from flash_attn.flash_attn_triton_amd.utils import DEBUG
+from flash_attn.flash_attn_triton_amd.utils import DEBUG, is_rdna

 # Test ROCM Triton Backend
 USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
@@ -901,7 +901,6 @@ def test_flash_attn_varlen_qkvpacked(
 # @pytest.mark.parametrize("local", [False, True])
 @pytest.mark.parametrize("local", [False])
 # @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
@@ -1571,6 +1570,10 @@ def test_flash_attn_varlen_output(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
+    if USE_TRITON_ROCM:
+        if is_rdna():
+            if seqlen_q == 1 and seqlen_k == 239 and d == 256:
+                pytest.skip("This config doesnot work on RDNA Devices.")
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
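
The test change imports is_rdna from flash_attn.flash_attn_triton_amd.utils, but the patch does not touch that helper, so its body is not shown above. Below is a minimal sketch of such a check, assuming a ROCm build of PyTorch that exposes gcnArchName on the device properties; the prefix list is an assumption, not the actual utils.py code.

    # Sketch only: not the actual flash_attn.flash_attn_triton_amd.utils.is_rdna.
    # Assumes a ROCm build of PyTorch where get_device_properties() exposes gcnArchName.
    import torch

    def is_rdna():
        # e.g. "gfx90a:sramecc+:xnack-" on MI200, "gfx1100" on Navi31
        arch = torch.cuda.get_device_properties(0).gcnArchName
        # RDNA2/RDNA3 parts report gfx10xx / gfx11xx architecture names
        return any(arch.startswith(prefix) for prefix in ("gfx10", "gfx11"))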
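
The bwd_prefill.py hunks implement the "stop casting to fp32" item: the k, v, q and do tiles stay in their storage dtype, and p and ds are cast to fp16 right before they feed tl.dot, whose accumulator is fp32 regardless of the input dtype, so precision of the reduction is kept while fp32 copies of whole tiles (and the extra register/LDS pressure that caused the reported OOM on Navi) are avoided. The standalone kernel below is a minimal sketch of that dtype pattern; the 16x16 shapes, block sizes and names are illustrative and not taken from bwd_prefill.py.

    # Minimal sketch: fp16 operands into tl.dot, fp32 accumulator, no fp32 tile copies.
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _fp16_dot(a_ptr, b_ptr, c_ptr, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        # fp16 loads: no .to(tl.float32) on the tiles themselves
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        acc = tl.dot(a, b)  # accumulates in fp32 even for fp16 inputs
        tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc.to(tl.float16))

    a = torch.randn(16, 16, device="cuda", dtype=torch.float16)
    b = torch.randn(16, 16, device="cuda", dtype=torch.float16)
    c = torch.empty(16, 16, device="cuda", dtype=torch.float16)
    _fp16_dot[(1,)](a, b, c, 16, 16, BLOCK_M=16, BLOCK_N=16, BLOCK_K=16)
    assert torch.allclose(c, (a.float() @ b.float()).half(), atol=1e-2, rtol=1e-2)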
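
The new skip covers a single configuration on RDNA devices (seqlen_q=1, seqlen_k=239, d=256) and attributes it to a compiler bug rather than to the library. The following is a standalone repro sketch for re-checking that shape on newer Triton/ROCm stacks; only those three sizes come from the skip condition, while the batch size, head count and causal flag are assumptions.

    # Hedged repro sketch for the configuration skipped on RDNA devices.
    import os
    os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"  # select the Triton AMD backend before import

    import torch
    from flash_attn import flash_attn_func

    # layout: (batch, seqlen, nheads, head_dim); batch=1, nheads=4 are assumed
    q = torch.randn(1, 1, 4, 256, device="cuda", dtype=torch.float16, requires_grad=True)
    k = torch.randn(1, 239, 4, 256, device="cuda", dtype=torch.float16, requires_grad=True)
    v = torch.randn(1, 239, 4, 256, device="cuda", dtype=torch.float16, requires_grad=True)

    out = flash_attn_func(q, k, v, causal=True)
    out.sum().backward()  # the backward prefill kernel is where the fp16 casts land
    print(out.shape, q.grad.shape)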