add RDNA CI (#105)
* Add RDNA CI

  This is a combination of 4 commits:

  - try navi
  - try matrix
  - small change
  - try minimal change

* limit navi tests

* stop casting to fp32, which leads to OOM on Navi

* enable all causal

* revert all causal

* skip compiler bug on Navi
micmelesse authored Dec 4, 2024
1 parent 1fcc51b commit c57b3d0
Showing 3 changed files with 17 additions and 9 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/amd_tests.yml
@@ -27,7 +27,7 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then
-echo '::set-output name=matrix-HIP::[["self-hosted", "rocm"]]'
+echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx1100"]]'
else
echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
fi
@@ -59,6 +59,7 @@ jobs:
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python setup.py install
- name: Flash Attention Tests Using Reference Impl
+if: matrix.runner[1] == 'gfx90a'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_REF=1
@@ -68,15 +69,18 @@
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
pytest tests/test_flash_attn_triton_amd.py
- name: AMD Tests
+if: matrix.runner[1] == 'gfx90a'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
pytest -v -s flash_attn/flash_attn_triton_amd/test.py
- name: AMD Bench
+if: matrix.runner[1] == 'gfx90a'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python flash_attn/flash_attn_triton_amd/bench.py
- name: AMD Bench with Autotune
+if: matrix.runner[1] == 'gfx90a'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1
-python flash_attn/flash_attn_triton_amd/bench.py
+python flash_attn/flash_attn_triton_amd/bench.py
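
Note: the workflow drives the Triton backend entirely through environment variables. FLASH_ATTENTION_TRITON_AMD_ENABLE turns the backend on, while FLASH_ATTENTION_TRITON_AMD_REF and FLASH_ATTENTION_TRITON_AMD_AUTOTUNE select the reference implementation and autotuning respectively. A minimal sketch of how a script could read these toggles (the variable names come from the workflow above; the defaults and the helper itself are illustrative assumptions, not the library's actual code):

import os

# Minimal sketch: read the CI toggles exported by the workflow above.
# The variable names are real; the defaults and this helper are illustrative only.
def read_ci_toggles():
    enable = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
    use_ref = os.getenv("FLASH_ATTENTION_TRITON_AMD_REF", "0") == "1"
    autotune = os.getenv("FLASH_ATTENTION_TRITON_AMD_AUTOTUNE", "0") == "1"
    return enable, use_ref, autotune

if __name__ == "__main__":
    enable, use_ref, autotune = read_ci_toggles()
    print(f"backend enabled: {enable}, reference impl: {use_ref}, autotune: {autotune}")
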
11 changes: 6 additions & 5 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -149,8 +149,8 @@ def _bwd_kernel_one_col_block(
# load k and v once per column block
k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
-k = tl.load(k_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
-v = tl.load(v_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
+k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
+v = tl.load(v_ptrs, mask=kv_mask, other=0.0)

# loop over rows
for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M):
@@ -164,8 +164,8 @@ def _bwd_kernel_one_col_block(
q_mask = mask_m[:, None] & mask_d[None, :]

# load q, k, v, do on-chip
-q = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.float32)
-do = tl.load(do_ptrs, mask=q_mask, other=0.0).to(tl.float32)
+q = tl.load(q_ptrs, mask=q_mask, other=0.0)
+do = tl.load(do_ptrs, mask=q_mask, other=0.0)

# recompute p = softmax(qk, dim=-1).T
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -192,7 +192,7 @@ def _bwd_kernel_one_col_block(
# mask block in the cases where the data is smaller the block size
p_mask = mask_m[:, None] & mask_n[None, :]
p = tl.where(p_mask, p, 0.0)
-p = p.to(tl.float32)
+p = p.to(tl.float16)

# compute dv
dv += tl.dot(tl.trans(p), do)
@@ -205,6 +205,7 @@ def _bwd_kernel_one_col_block(
Di = tl.load(d_ptrs, mask=mask_m)
ds = (p * (dp - Di[:, None])) * sm_scale
ds = tl.where(p_mask, ds, 0.0)
+ds = ds.to(tl.float16)

# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds), q)
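
For context on the casts removed above: keeping the loaded k, v, q, and do tiles in fp16 instead of promoting them to fp32 halves the on-chip footprint of each block, which is what relieves the out-of-memory failures on Navi; tl.dot accumulates in fp32 by default, so the dot products keep their higher-precision accumulation. A back-of-the-envelope illustration (the block sizes below are hypothetical and chosen only to show the arithmetic):

# Rough per-tile memory for one loaded block, assuming hypothetical sizes of
# BLOCK_N = 128 rows and a head dimension of 128 (the real kernel's block
# shapes may differ).
BLOCK_N, HEAD_DIM = 128, 128

bytes_fp16 = BLOCK_N * HEAD_DIM * 2  # fp16: 2 bytes per element
bytes_fp32 = BLOCK_N * HEAD_DIM * 4  # fp32: 4 bytes per element

print(f"one k tile: {bytes_fp16 // 1024} KiB in fp16 vs {bytes_fp32 // 1024} KiB after a cast to fp32")
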
7 changes: 5 additions & 2 deletions tests/test_flash_attn_triton_amd.py
@@ -18,7 +18,7 @@
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size_n
from flash_attn.layers.rotary import apply_rotary_emb
-from flash_attn.flash_attn_triton_amd.utils import DEBUG
+from flash_attn.flash_attn_triton_amd.utils import DEBUG, is_rdna

# Test ROCM Triton Backend
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
@@ -901,7 +901,6 @@ def test_flash_attn_varlen_qkvpacked(
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
@@ -1571,6 +1570,10 @@ def test_flash_attn_varlen_output(
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
+if USE_TRITON_ROCM:
+    if is_rdna():
+        if seqlen_q == 1 and seqlen_k == 239 and d == 256:
+            pytest.skip("This config does not work on RDNA devices.")
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
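
The skip above relies on the newly imported is_rdna() helper from flash_attn.flash_attn_triton_amd.utils, whose real implementation lives in that module. Purely as a hypothetical sketch of how such a check could look on a ROCm build of PyTorch that exposes the GPU architecture string:

import torch

def is_rdna_sketch() -> bool:
    # Hypothetical stand-in for flash_attn.flash_attn_triton_amd.utils.is_rdna;
    # assumes a ROCm PyTorch build where gcnArchName reports strings like "gfx1100".
    if not torch.cuda.is_available():
        return False
    arch = torch.cuda.get_device_properties(0).gcnArchName
    return arch.split(":")[0] in ("gfx1030", "gfx1100", "gfx1101", "gfx1102")
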
