diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml index f8c81d92e..21cfe22d7 100644 --- a/.github/workflows/amd_tests.yml +++ b/.github/workflows/amd_tests.yml @@ -8,7 +8,7 @@ on: branches: [main_perf] types: [checks_requested] push: - branches: [main_perf, micmelesse/upstream_pr] + branches: [main_perf] concurrency: group: ${{ github.ref }} @@ -17,70 +17,75 @@ concurrency: permissions: read-all jobs: - Runner-Preparation-AMD: - runs-on: ubuntu-latest - timeout-minutes: 30 - outputs: - matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }} - steps: - - name: Prepare runner matrix - id: set-matrix - run: | - if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then - echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx1100"]]' - else - echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]' - fi - Integration-Tests-AMD: - needs: Runner-Preparation-AMD - if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != '' runs-on: ${{ matrix.runner }} strategy: matrix: - runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}} + runner: [linux-mi300-gpu-1] + fail-fast: false # disables failing the entire job when one matrix entry fails container: - image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2 + image: rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root steps: - name: Checkout uses: actions/checkout@v4 - - name: Install Triton + - name: Show Device Info run: | + rocminfo | grep gfx + - name: Uninstall Triton + run : | pip uninstall -y triton - pip install matplotlib pandas pytest + rm -rf ~/.triton + rm -rf ./triton/python/build + - name: Install Triton + run: | git clone https://github.com/triton-lang/triton cd triton - git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4 - pip install --verbose -e python + git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4 + pip install ninja cmake wheel pybind11 # build-time dependencies + pip install matplotlib pandas pytest # triton bench dependencies + pip install --verbose --no-build-isolation ./python cd .. 
+ - name: Show Triton version + run: | + pip show triton - name: Build run: | export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install - - name: Flash Attention Tests Using Reference Impl - if: matrix.runner[1] == 'gfx90a' - run: | - export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - export FLASH_ATTENTION_TRITON_AMD_REF=1 - pytest tests/test_flash_attn_triton_amd.py - - name: Flash Attention Tests + + # CDNA Tests + - name: Flash Attention CDNA Tests + if: matrix.runner == 'linux-mi300-gpu-1' run: | export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py - name: AMD Tests - if: matrix.runner[1] == 'gfx90a' + if: matrix.runner == 'linux-mi300-gpu-1' run: | export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest -v -s flash_attn/flash_attn_triton_amd/test.py - name: AMD Bench - if: matrix.runner[1] == 'gfx90a' + if: matrix.runner == 'linux-mi300-gpu-1' run: | export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python flash_attn/flash_attn_triton_amd/bench.py - name: AMD Bench with Autotune - if: matrix.runner[1] == 'gfx90a' + if: matrix.runner == 'linux-mi300-gpu-1' run: | export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1 python flash_attn/flash_attn_triton_amd/bench.py + + # RDNA Tests + - name: Flash Attention Tests Using Reference Impl + if: matrix.runner == 'gfx1100' + run: | + export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + export FLASH_ATTENTION_TRITON_AMD_REF=1 + pytest tests/test_flash_attn_triton_amd.py + - name: Flash Attention RDNA Tests + if: matrix.runner == 'gfx1100' + run: | + export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + pytest tests/test_flash_attn_triton_amd.py::test_flash_attn_output tests/test_flash_attn_triton_amd.py::test_flash_attn_varlen_output tests/test_flash_attn_triton_amd.py::test_flash_attn_kvcache diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 1018349de..8974b2619 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -90,7 +90,11 @@ def _flash_attn_forward( window_size_right: int, softcap: float, alibi_slopes: Optional[torch.Tensor], - return_softmax: bool + return_softmax: bool, + descale_q=None, + descale_k=None, + descale_v=None, + descale_p=None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( @@ -107,6 +111,10 @@ def _flash_attn_forward( softcap, return_softmax, None, + descale_q, + descale_k, + descale_v, + descale_p ) return out, softmax_lse, S_dmask, rng_state @@ -164,6 +172,10 @@ def _flash_attn_varlen_forward( block_table: Optional[torch.Tensor] = None, leftpad_k: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + descale_q=None, + descale_k=None, + descale_v=None, + descale_p=None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( @@ -188,6 +200,10 @@ def _flash_attn_varlen_forward( softcap, return_softmax, None, + descale_q, + descale_k, + descale_v, + descale_p ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() @@ -804,6 +820,10 @@ def forward( alibi_slopes, deterministic, return_softmax, + descale_q, + descale_k, + descale_v, + descale_p ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -824,6 +844,10 @@ def forward( softcap=softcap, 
alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_p=descale_p, ) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.dropout_p = dropout_p @@ -867,7 +891,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @@ -890,6 +914,10 @@ def forward( deterministic, return_softmax, block_table, + descale_q, + descale_k, + descale_v, + descale_p ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -915,6 +943,10 @@ def forward( alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, block_table=block_table, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_p=descale_p ) ctx.save_for_backward( q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state @@ -966,7 +998,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( @@ -1116,6 +1148,10 @@ def flash_attn_func( alibi_slopes=None, deterministic=False, return_attn_probs=False, + descale_q=None, + descale_k=None, + descale_v=None, + descale_p=None ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads @@ -1177,6 +1213,10 @@ def flash_attn_func( alibi_slopes, deterministic, return_attn_probs, + descale_q, + descale_k, + descale_v, + descale_p ) @@ -1353,6 +1393,10 @@ def flash_attn_varlen_func( deterministic=False, return_attn_probs=False, block_table=None, + descale_q=None, + descale_k=None, + descale_v=None, + descale_p=None ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads @@ -1426,6 +1470,10 @@ def flash_attn_varlen_func( deterministic, return_attn_probs, block_table, + descale_q, + descale_k, + descale_v, + descale_p ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index c6366b8b5..19ae4b139 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -1,7 +1,7 @@ import torch import triton import triton.language as tl -from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask +from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, arch_supports_fp8, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH @@ -63,6 +63,7 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo def _attn_fwd_inner(acc, l_i, m_i, q, 
k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + descale_q, descale_k, descale_v, descale_p, IS_FP8: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, @@ -99,9 +100,17 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri size_n = start_n + OFFS_N[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) - + + # compute masks + q_mask = (OFFS_M[:, None] < actual_seqlen_q) + k_mask = ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) + p_mask = q_mask & k_mask + # -- compute qk ---- - qk += tl.dot(q, k) + if IS_FP8 : + qk += (tl.dot(q, k) * descale_q * descale_k) + else: + qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE if IS_CAUSAL: @@ -132,8 +141,6 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri else: p = tl.math.exp(q_shifted) - p_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) - # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: @@ -170,7 +177,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(v.type.element_ty), v) + + if IS_FP8: + p *= (1.0/ descale_p) # put p into fp8 range + acc += (tl.dot(p.to(v.type.element_ty), v) * descale_p * descale_v) + else: + acc += tl.dot(p.to(v.type.element_ty), v) + k_ptrs += BLOCK_N * stride_kn v_ptrs += BLOCK_N * stride_vk if bias_ptrs is not None: @@ -259,15 +272,17 @@ def get_autotune_configs(): use_cuda_graph=True, ) @triton.jit -def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, +def attn_fwd(Q, K, V, bias, + DESCALE_Q, DESCALE_K, DESCALE_V, DESCALE_P, stride_q_inv_scale_z, stride_kv_inv_scale_z, stride_p_inv_scale_z, + SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, + dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr): + ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr): start_m = tl.program_id(0) off_h_q = tl.program_id(1) 
off_z = tl.program_id(2) @@ -396,6 +411,15 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + # Load scale factors if IS_FP8. + if IS_FP8: + descale_q = tl.load(DESCALE_Q + off_z * stride_q_inv_scale_z + off_h_q) + descale_k = tl.load(DESCALE_K + off_z * stride_kv_inv_scale_z + off_h_k) + descale_v = tl.load(DESCALE_V + off_z * stride_kv_inv_scale_z + off_h_k) + descale_p = tl.load(DESCALE_P + off_z * stride_p_inv_scale_z + off_h_q) + else: + descale_q, descale_k, descale_v, descale_p = 1.0, 1.0, 1.0, 1.0 + # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) @@ -421,6 +445,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ sd_mask_ptrs, dropout_mask_ptrs, # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ block_min, block_max, 0, 0, 0, alibi_slope, + descale_q, descale_k, descale_v, descale_p, IS_FP8, # IS_CAUSAL, .... False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... @@ -448,7 +473,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks, - n_extra_tokens, alibi_slope, + n_extra_tokens, alibi_slope, descale_q, descale_k, descale_v, descale_p, IS_FP8, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD, @@ -467,7 +492,6 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ end_m_idx = (start_m + 1) * BLOCK_M start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k - acc = acc.to(Out.type.element_ty) if IS_CAUSAL: if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) @@ -537,7 +561,12 @@ def attention_prefill_forward_triton_impl( philox_offset, # misc return_softmax, - use_exp2): + use_exp2, + # fp8 + descale_q=None, + descale_k=None, + descale_v=None, + descale_p=None): if DEBUG: print() @@ -561,6 +590,50 @@ def attention_prefill_forward_triton_impl( print("return_scores:", return_softmax) print("use_exp2:", use_exp2) + is_fp8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz} + if is_fp8: + if DEBUG: + print("IS_FP8") + + type_max = torch.finfo(q.dtype).max + if layout == "bshd": + batch, _ , nheads_q, dim = q.shape + _, _ , nheads_k, _ = k.shape + elif layout == "bhsd": + batch, nheads_q,_, dim = q.shape + _, nheads_k, _, _ = k.shape + elif layout == "thd": + batch = len(cu_seqlens_q) - 1 + nheads_q = q.size(1) + nheads_k = k.size(1) + else: + raise ValueError("Unsupported layout") + + # Get strides for the kernel + descale_q_stride_z = descale_q.stride(0) + descale_k_stride_z = descale_k.stride(0) + descale_v_stride_z = descale_v.stride(0) + descale_p_stride_z = descale_p.stride(0) + else: + # For non-FP8 types, use dummy values (no scaling needed) + descale_q = descale_k = descale_v = descale_p = 1 + descale_q_stride_z = descale_k_stride_z = descale_v_stride_z = descale_p_stride_z = 0 + + + 
if DEBUG: + print("is_fp8:", is_fp8) + print("descale_q:", descale_q) + print("descale_k:", descale_k) + print("descale_v:", descale_v) + print("descale_p:", descale_p) + print("descale_q_stride_z:", descale_q_stride_z) + print("descale_k_stride_z:", descale_k_stride_z) + print("descale_v_stride_z:", descale_v_stride_z) + print("descale_p_stride_z:", descale_p_stride_z) + if is_fp8: + print(f"type_max: {type_max}") + + # check if varlen is_varlen = layout == "thd" @@ -619,14 +692,16 @@ def attention_prefill_forward_triton_impl( alibi_strides = (0, 0) - attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, + attn_fwd[grid](q, k, v, bias, + descale_q, descale_k, descale_v, descale_p, descale_q_stride_z, descale_k_stride_z, descale_p_stride_z, + sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p - > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax) + > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=is_fp8) if DEBUG: print() diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index 51037f236..d0b8e1d04 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -23,7 +23,11 @@ def fwd(q, window_size_right, softcap, return_softmax, - gen_): + gen_, + descale_q, + descale_k, + descale_v, + descale_p): if DEBUG: print() @@ -111,7 +115,11 @@ def fwd(q, metadata.philox_seed, metadata.philox_offset, metadata.return_scores, - metadata.use_exp2) + metadata.use_exp2, + descale_q, + descale_k, + descale_v, + descale_p) if DEBUG: print("fwd outputs") @@ -259,7 +267,11 @@ def varlen_fwd( window_size_right, softcap, return_softmax, - gen_): + gen_, + descale_q, + descale_k, + descale_v, + descale_p): if DEBUG: print() @@ -349,7 +361,11 @@ def varlen_fwd( metadata.philox_seed, metadata.philox_offset, metadata.return_scores, - metadata.use_exp2) + metadata.use_exp2, + descale_q, + descale_k, + descale_v, + descale_p) if DEBUG: print("varlen_fwd outputs") print("o:", o, o.shape) diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index 7548743c1..8d527e5d2 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -1,19 +1,22 @@ import torch import pytest -from .utils import DEBUG, MetaData, get_input_shapes, input_helper, varlen_input_helper, compute_alibi_tensor_ref +from .utils import DEBUG, MetaData, get_input_shapes, input_helper, varlen_input_helper, compute_alibi_tensor_ref, get_arch, arch_supports_fp8 from .interface_torch import attention_prefill, attention_decode from .fwd_ref import attention_forward_pytorch_ref_impl from .fwd_prefill import attention_prefill_forward_triton_impl from .bwd_prefill import attention_prefill_backward_triton_impl from .bwd_ref import attention_backward_pytorch_ref_impl from .fwd_decode import dequantize_kv_fp16, 
quantize_kv_int4
+from flash_attn import flash_attn_func, flash_attn_varlen_func
 # defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html
 ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose.
 # ATOL, RTOL = 1e-3, 1e-3 # catchs fa mismatch issues
 # ATOL, RTOL = 1e-4, 1e-3 # to strict. there will be small diffs
 # ATOL, RTOL = 1e-5, 1e-3 # # default fp16. there will be small diffs
+# ATOL_fp8, RTOL_fp8 = 1e-1, 1e-1 # too strict for larger tensors in fp8
+ATOL_fp8, RTOL_fp8 = 2.5e-1, 2.5e-1 # tests pass with dropout and causal in fp8
 EQUAL_NAN = True
 @pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [
@@ -525,6 +528,10 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou
 @pytest.mark.parametrize('sequence_parallel', [True, False])
 @pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans on larger tensors
 def test_op_prefill_bwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, use_exp2, layout, sequence_parallel, DEBUG_INPUT):
+    if get_arch() == "gfx90a":
+        if layout == "thd" and Z == 4 and HQ == 48 and HK == 48 and N_CTX_Q == 1024 and N_CTX_K == 1024:
+            pytest.skip("This config does not work on MI200 devices but works on MI300.")
+
     dtype = torch.float16
     torch.manual_seed(20) # seed from test_op_bwd
@@ -653,9 +660,10 @@ def test_op_prefill_bwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou
 @pytest.mark.parametrize('batch_size, seqlen_q, seqlen_k, group_q, group_k, dim', get_input_shapes())
 def test_op_fwd_decode(batch_size, seqlen_q, seqlen_k, group_q, group_k, dim, dtype=torch.bfloat16):
-    if DEBUG:
-        print()
-        print(f"batch_size = {batch_size}, seqlen_q = {seqlen_q}, seqlen_k = {seqlen_k}, group_q = {group_q}, group_k = {group_k}, dim = {dim}")
+    if get_arch() == "gfx90a":
+        if batch_size == 1 and seqlen_q == 1 and seqlen_k >= 65536:
+            pytest.skip("This config does not work on MI200 devices but works on MI300.")
+
     torch.manual_seed(20)
     query_group_head_size = (group_q + group_k - 1) // group_k
     q = (torch.empty((batch_size, seqlen_q, group_k, query_group_head_size, dim), dtype=dtype,
@@ -725,3 +733,312 @@ def test_op_fwd_decode_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16):
     dq_attn = (q @ dqk.transpose(-1, -2) * scale).softmax(-1)
     dq_ref_out = dq_attn @ dqv
     torch.testing.assert_close(dq_ref_out, tri_out, atol=1e-3, rtol=0)
+
+
+
+@pytest.mark.parametrize(
+    "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD",
+    [
+        (1, 1, 1, 1, 1, 1),
+        (1, 1, 1, 2, 4, 16),
+        (1, 2, 2, 2, 4, 16),
+        (1, 4, 1, 2, 4, 16),
+        (1, 4, 2, 2, 4, 16),
+        (1, 1, 1, 4, 2, 16),
+        (1, 1, 1, 4, 4, 16),
+        (1, 2, 2, 4, 4, 16),
+        (2, 1, 1, 4, 4, 16),
+        (2, 2, 2, 4, 4, 16),
+        (1, 1, 1, 128, 64, 16),
+        (2, 2, 2, 2, 128, 1),
+        (2, 3, 3, 2, 128, 16),
+        (3, 2, 2, 256, 512, 16),
+        (3, 3, 3, 128, 128, 64),
+        (2, 4, 4, 1024, 1024, 64),
+        (4, 6, 6, 108, 256, 224),
+        (4, 8, 8, 2048, 2048, 128),
+        (4, 16, 16, 4096, 4096, 64),
+        (2, 4, 4, 8192, 8192, 32),
+        # fa configs
+        (4, 6, 1, 113, 203, 256),
+        (4, 6, 1, 128, 217, 256),
+        (4, 6, 2, 113, 211, 128),
+        (4, 6, 2, 108, 256, 128),
+        (4, 6, 1, 256, 512, 64),
+        (4, 6, 1, 512, 256, 64),
+        (4, 6, 2, 1024, 1024, 32),
+        (4, 6, 2, 1023, 1024, 32),
+        (4, 6, 6, 1024, 1023, 32),
+        (4, 6, 6, 2048, 2048, 32),
+    ],
+)
+@pytest.mark.parametrize('causal', [False, True])
+@pytest.mark.parametrize('dropout_p', [0.0, 0.25])
+@pytest.mark.parametrize('DEBUG_INPUT', [False])
+@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device")
+def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, DEBUG_INPUT):
+    device = "cuda"
+    window_size = (-1, -1)
+    softcap = None
+    alibi_slopes = None
+    deterministic = False
+    layout = "bshd"
+
+    q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, torch.float32, layout, device=device, DEBUG_INPUT=DEBUG_INPUT)
+
+    # NOTE: use bf16 as the reference because it is fp32 with a truncated mantissa
+    # launch kernel in bf16
+    q_bfp16 = q.clone().to(torch.bfloat16)
+    k_bfp16 = k.clone().to(torch.bfloat16)
+    v_bfp16 = v.clone().to(torch.bfloat16)
+    out_bfp16, lse_bfp16, S_dmask_bfp16 = flash_attn_func(
+        q_bfp16,
+        k_bfp16,
+        v_bfp16,
+        dropout_p,
+        causal=causal,
+        window_size=window_size,
+        softcap=softcap,
+        alibi_slopes=alibi_slopes,
+        deterministic=deterministic,
+        return_attn_probs=True,
+    )
+    if DEBUG:
+        print("out_bfp16", out_bfp16)
+        print("lse_bfp16", lse_bfp16)
+        print("S_dmask_bfp16", S_dmask_bfp16)
+
+    # get shapes for descaling
+    batch, _ , nheads_q, dim = q.shape
+    _, _ , nheads_k, _ = k.shape
+
+    # compute max for each batch-head pair across seqlen and dim
+    q_max = torch.maximum(q.abs().amax(dim=(1, 3)), torch.tensor(1e-9)).unsqueeze(1).unsqueeze(-1)
+    k_max = torch.maximum(k.abs().amax(dim=(1, 3)), torch.tensor(1e-9)).unsqueeze(1).unsqueeze(-1)
+    v_max = torch.maximum(v.abs().amax(dim=(1, 3)), torch.tensor(1e-9)).unsqueeze(1).unsqueeze(-1)
+
+    # scale values to fp8 range
+    type_max = torch.finfo(torch.float8_e4m3fnuz).max
+    q_fp8 = (q * type_max / q_max).to(torch.float8_e4m3fnuz)
+    k_fp8 = (k * type_max / k_max).to(torch.float8_e4m3fnuz)
+    v_fp8 = (v * type_max / v_max).to(torch.float8_e4m3fnuz)
+
+    # compute descale values
+    descale_q = q_max / type_max
+    descale_k = k_max / type_max
+    descale_v = v_max / type_max
+    descale_p = torch.full_like(descale_q, 1.0 / type_max, dtype=torch.float32, device=q.device)
+
+    # launch kernel in fp8
+    out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_func(
+        q_fp8,
+        k_fp8,
+        v_fp8,
+        dropout_p,
+        causal=causal,
+        window_size=window_size,
+        softcap=softcap,
+        alibi_slopes=alibi_slopes,
+        deterministic=deterministic,
+        return_attn_probs=True,
+        descale_q=descale_q,
+        descale_k=descale_k,
+        descale_v=descale_v,
+        descale_p=descale_p,
+    )
+    if DEBUG:
+        print("out_fp8", out_fp8)
+        print("lse_fp8", lse_fp8)
+        print("S_dmask_fp8", S_dmask_fp8)
+
+    if DEBUG:
+        print("out_bfp16:", out_bfp16, out_bfp16.shape)
+        print("out_fp8:", out_fp8, out_fp8.shape)
+
+    torch.testing.assert_close(out_bfp16.to(torch.float32), out_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8)
+@pytest.mark.parametrize(
+    "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD",
+    [
+        (1, 1, 1, 1, 1, 1),
+        (1, 1, 1, 2, 4, 16),
+        (1, 2, 2, 2, 4, 16),
+        (1, 4, 1, 2, 4, 16),
+        (1, 4, 2, 2, 4, 16),
+        (1, 1, 1, 4, 2, 16),
+        (1, 1, 1, 4, 4, 16),
+        (1, 2, 2, 4, 4, 16),
+        (2, 1, 1, 4, 4, 16),
+        (2, 2, 2, 4, 4, 16),
+        (1, 1, 1, 128, 64, 16),
+        (2, 2, 2, 2, 128, 1),
+        (2, 3, 3, 2, 128, 16),
+        (3, 2, 2, 256, 512, 16),
+        (3, 3, 3, 128, 128, 64),
+        (2, 4, 4, 1024, 1024, 64),
+        (4, 6, 6, 108, 256, 224),
+        (4, 8, 8, 2048, 2048, 128),
+        (4, 16, 16, 4096, 4096, 64),
+        (2, 4, 4, 8192, 8192, 32),
+        # fa configs
+        (4, 6, 1, 113, 203, 256),
+        (4, 6, 1, 128, 217, 256),
+        (4, 6, 2, 113, 211, 128),
+        (4, 6, 2, 108, 256, 128),
+        (4, 6, 1, 256, 512, 64),
+        (4, 6, 1, 512, 256, 64),
+        (4, 6, 2, 1024, 1024, 32),
+        (4, 6, 2, 1023, 1024, 32),
+        (4, 6, 6, 1024, 1023, 32),
+        (4, 6, 6, 2048, 2048, 32),
+    ],
+)
+@pytest.mark.parametrize('causal', [False, True])
+@pytest.mark.parametrize('dropout_p', [0.0, 0.25])
+@pytest.mark.parametrize('DEBUG_INPUT', [False]) +@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") +def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, DEBUG_INPUT): + device = "cuda" + window_size = (-1, -1) + softcap = None + alibi_slopes = None + deterministic = False + layout = "thd" + + q, k, v, metadata = varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, torch.float32, DEBUG_INPUT=DEBUG_INPUT) + + # launch kernel in fp16 + q_bfp16 = q.clone().to(torch.bfloat16) + k_bfp16 = k.clone().to(torch.bfloat16) + v_bfp16 = v.clone().to(torch.bfloat16) + out_bfp16, lse_bfp16, S_dmask_bfp16 = flash_attn_varlen_func( + q_bfp16, + k_bfp16, + v_bfp16, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + if DEBUG: + print("out_bfp16", out_bfp16) + print("lse_bfp16", lse_bfp16) + print("S_dmask_bfp16", S_dmask_bfp16) + + + if DEBUG: + print("q:", q, q.shape) + print("k:", k, k.shape) + + # thd + batch = len(metadata.cu_seqlens_q) - 1 + nheads_q = q.size(1) + nheads_k = k.size(1) + + if DEBUG: + print("batch:", batch) + print("nheads_q:", nheads_q) + print("nheads_k:", nheads_k) + + q_maxes = [] + k_maxes = [] + v_maxes = [] + for i in range(batch): + q_start = metadata.cu_seqlens_q[i] + q_end = metadata.cu_seqlens_q[i + 1] + k_start = metadata.cu_seqlens_k[i] + k_end = metadata.cu_seqlens_k[i + 1] + + # compute max for each batch-head pair across seqlen and dim + q_max = torch.maximum(q[q_start:q_end].abs().amax(dim=(0,2)), torch.tensor(1e-9)).unsqueeze(-1) + k_max = torch.maximum(k[k_start:k_end].abs().amax(dim=(0,2)), torch.tensor(1e-9)).unsqueeze(-1) + v_max = torch.maximum(v[k_start:k_end].abs().amax(dim=(0,2)), torch.tensor(1e-9)).unsqueeze(-1) + + q_maxes.append(q_max) + k_maxes.append(k_max) + v_maxes.append(v_max) + q_maxes = torch.stack(q_maxes) + k_maxes = torch.stack(k_maxes) + v_maxes = torch.stack(v_maxes) + if DEBUG: + print("q", q, q.shape) + print("q_maxes:", q_maxes, q_maxes.shape) + print("k", k, k.shape) + print("k_maxes:", k_maxes, k_maxes.shape) + + # ---------------------------------------------------------------- + # --- FP8 conversion part --- + # ---------------------------------------------------------------- + type_max = torch.finfo(torch.float8_e4m3fnuz).max + q_fp8 = torch.empty_like(q, dtype=torch.float8_e4m3fnuz) + k_fp8 = torch.empty_like(k, dtype=torch.float8_e4m3fnuz) + v_fp8 = torch.empty_like(v, dtype=torch.float8_e4m3fnuz) + for i in range(batch): + q_start = metadata.cu_seqlens_q[i] + q_end = metadata.cu_seqlens_q[i + 1] + k_start = metadata.cu_seqlens_k[i] + k_end = metadata.cu_seqlens_k[i + 1] + + # shape [heads_q, 1], broadcast to [1, heads_q, 1] + q_scale = (type_max / q_maxes[i]).unsqueeze(0) # => [1, HQ, 1] + k_scale = (type_max / k_maxes[i]).unsqueeze(0) # => [1, HK, 1] + v_scale = (type_max / v_maxes[i]).unsqueeze(0) # => [1, HK, 1] + + # q, k, v are [L, heads, dim] slices + q_slice = q[q_start:q_end] # [seq_len_i, HQ, dim] + k_slice = k[k_start:k_end] # [seq_len_i, HK, dim] + v_slice = v[k_start:k_end] # [seq_len_i, HK, dim] + + # Convert them to FP8 + q_fp8[q_start:q_end] = (q_slice * q_scale).to(torch.float8_e4m3fnuz) + k_fp8[k_start:k_end] = (k_slice * k_scale).to(torch.float8_e4m3fnuz) + v_fp8[k_start:k_end] = (v_slice * v_scale).to(torch.float8_e4m3fnuz) + 
+ if DEBUG: + print("q_fp8:", q_fp8, q_fp8.shape) + print("k_fp8:", k_fp8, k_fp8.shape) + + # compute descale values + descale_q = q_maxes / type_max + descale_k = k_maxes / type_max + descale_v = v_maxes / type_max + descale_p = torch.full_like(descale_q, 1.0 / type_max, dtype=torch.float32, device=q.device) + + # launch kernel in fp8 + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_func( + q_fp8, + k_fp8, + v_fp8, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_p=descale_p, + ) + if DEBUG: + print("out_fp8", out_fp8) + print("lse_fp8", lse_fp8) + print("S_dmask_fp8", S_dmask_fp8) + + if DEBUG: + print("out_bfp16:", out_bfp16, out_bfp16.shape) + print("out_fp8:", out_fp8, out_fp8.shape) + + torch.testing.assert_close(out_bfp16.to(torch.float32), out_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8) diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 343425788..8465a1d7e 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -1,10 +1,10 @@ import csv -import json import math import torch import os import random +import functools import triton import triton.language as tl @@ -176,38 +176,73 @@ def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cud return q, k, v, input_metadata +def random_seqlens_composition(N, Z): + # generate a random composition of N into Z positive parts. + idx = torch.randperm(N - 1)[: Z - 1] + 1 + idx, _ = torch.sort(idx) + breakpoints = torch.cat([ + torch.tensor([0], dtype=torch.long), + idx, + torch.tensor([N], dtype=torch.long), + ]) + seqlens = (breakpoints[1:] - breakpoints[:-1]).to(torch.int32) + return seqlens + def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda", equal_seqlens=False, DEBUG_INPUT=False): torch.manual_seed(20) # Random or equal sequence lengths based on 'equal_seqlens' flag if not equal_seqlens: - max_seqlens_q = N_CTX_Q // Z - max_seqlens_k = N_CTX_K // Z - seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32) - seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32) + seqlens_q = random_seqlens_composition(N_CTX_Q, Z) + seqlens_k = random_seqlens_composition(N_CTX_K, Z) else: seqlens_q = torch.full((Z,), N_CTX_Q // Z, dtype=torch.int32) seqlens_k = torch.full((Z,), N_CTX_K // Z, dtype=torch.int32) - # Calculate cumulative sequence lengths + # calculate cumulative sequence lengths cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0)]) cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0)]) cu_seqlens_q = cu_seqlens_q.to(device=device).to(torch.int32) cu_seqlens_k = cu_seqlens_k.to(device=device).to(torch.int32) - # Total lengths + # total lengths total_q = cu_seqlens_q[-1].item() total_k = cu_seqlens_k[-1].item() if DEBUG_INPUT: - # Initialize q, k, v with deterministic values - q = torch.arange(total_q, dtype=dtype, device=device).view(total_q, 1, 1) - q = q.expand(total_q, HQ, D_HEAD).contiguous().requires_grad_() - k = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1) - k = k.expand(total_k, HK, D_HEAD).contiguous().requires_grad_() - v = torch.arange(total_k, 
dtype=dtype, device=device).view(total_k, 1, 1)
-        v = v.expand(total_k, HK, D_HEAD).contiguous().requires_grad_()
-        sm_scale = 1
+        sm_scale = 1.0
+
+        q = torch.empty(total_q, HQ, D_HEAD, dtype=dtype, device=device)
+        k = torch.empty(total_k, HK, D_HEAD, dtype=dtype, device=device)
+        v = torch.empty(total_k, HK, D_HEAD, dtype=dtype, device=device)
+        for i in range(Z):
+            q_start = cu_seqlens_q[i].item()
+            q_end = cu_seqlens_q[i+1].item()
+            q_length = q_end - q_start
+            k_start = cu_seqlens_k[i].item()
+            k_end = cu_seqlens_k[i+1].item()
+            k_length = k_end - k_start
+
+
+            q[q_start:q_end, :, :] = (
+                torch.arange(q_length, dtype=dtype, device=device)
+                .view(q_length, 1, 1)
+                .expand(q_length, HQ, D_HEAD)
+            )
+            k[k_start:k_end, :, :] = (
+                torch.arange(k_length, dtype=dtype, device=device)
+                .view(k_length, 1, 1)
+                .expand(k_length, HK, D_HEAD)
+            )
+            v[k_start:k_end, :, :] = (
+                torch.arange(k_length, dtype=dtype, device=device)
+                .view(k_length, 1, 1)
+                .expand(k_length, HK, D_HEAD)
+            )
+        q.requires_grad_()
+        k.requires_grad_()
+        v.requires_grad_()
+
     else:
         # Initialize q, k, v with random values
         q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device=device).requires_grad_()
@@ -217,6 +252,7 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda
     input_metadata = MetaData(sm_scale=sm_scale)
     input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
+
     return q, k, v, input_metadata
@@ -325,15 +361,22 @@ def get_input_shapes():
              for i in range(8, 18)] + [(max(1, 2**(16 - i)), 1, 2**i, 16, 2, 128) for i in range(8, 18)]
     return cases

+@functools.cache
 def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"

+@functools.cache
 def get_arch():
     return triton.runtime.driver.active.get_current_target().arch

+@functools.cache
 def is_cdna():
-    return is_hip() and get_arch() in ('gfx940', 'gfx941', 'gfx942', 'gfx90a', 'gfx908')
-
+    return is_hip() and get_arch() in ('gfx908', 'gfx90a', 'gfx940', 'gfx941', 'gfx942')
+@functools.cache
 def is_rdna():
-    return is_hip() and get_arch() in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201")
\ No newline at end of file
+    return is_hip() and get_arch() in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201")
+
+@functools.cache
+def arch_supports_fp8():
+    return is_hip() and get_arch() in ('gfx942',)
\ No newline at end of file
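
Note on the new descale arguments (not part of the patch): the sketch below mirrors test_op_prefill_fp8 above and shows how a caller quantizes q/k/v per (batch, head) and passes the matching descale factors to flash_attn_func. It assumes a gfx942 (MI300) device where torch.float8_e4m3fnuz is available and the Triton backend is enabled via FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"; the tensor shapes are illustrative.

import torch
from flash_attn import flash_attn_func

torch.manual_seed(0)
batch, seqlen, nheads, dim = 2, 128, 4, 64
q = torch.randn(batch, seqlen, nheads, dim, device="cuda", dtype=torch.float32)
k = torch.randn(batch, seqlen, nheads, dim, device="cuda", dtype=torch.float32)
v = torch.randn(batch, seqlen, nheads, dim, device="cuda", dtype=torch.float32)

type_max = torch.finfo(torch.float8_e4m3fnuz).max

# per-(batch, head) max over seqlen and head_dim; shape (batch, 1, nheads, 1)
q_max = torch.maximum(q.abs().amax(dim=(1, 3)), torch.tensor(1e-9, device=q.device)).unsqueeze(1).unsqueeze(-1)
k_max = torch.maximum(k.abs().amax(dim=(1, 3)), torch.tensor(1e-9, device=q.device)).unsqueeze(1).unsqueeze(-1)
v_max = torch.maximum(v.abs().amax(dim=(1, 3)), torch.tensor(1e-9, device=q.device)).unsqueeze(1).unsqueeze(-1)

# scale into the fp8 range and keep the factors that undo the scaling
q_fp8 = (q * type_max / q_max).to(torch.float8_e4m3fnuz)
k_fp8 = (k * type_max / k_max).to(torch.float8_e4m3fnuz)
v_fp8 = (v * type_max / v_max).to(torch.float8_e4m3fnuz)
descale_q, descale_k, descale_v = q_max / type_max, k_max / type_max, v_max / type_max
# p (the softmax probabilities) is rescaled by type_max inside the kernel,
# so its descale factor is a constant 1 / type_max per (batch, head)
descale_p = torch.full_like(descale_q, 1.0 / type_max)

out = flash_attn_func(q_fp8, k_fp8, v_fp8, 0.0, causal=True,
                      descale_q=descale_q, descale_k=descale_k,
                      descale_v=descale_v, descale_p=descale_p)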
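
For reference, the descaling algebra that _attn_fwd_inner in fwd_prefill.py applies (dequantize qk with descale_q * descale_k, push p back into the fp8 range with 1/descale_p before the pv dot, then dequantize with descale_p * descale_v) can be checked outside the kernel with plain PyTorch. This is a simplified single-head sketch with no masking, dropout, or online softmax; it assumes the same gfx942 / float8_e4m3fnuz setup as above.

import torch

torch.manual_seed(0)
fp8 = torch.float8_e4m3fnuz
type_max = torch.finfo(fp8).max

q = torch.randn(16, 32, device="cuda")
k = torch.randn(24, 32, device="cuda")
v = torch.randn(24, 32, device="cuda")
sm_scale = 32 ** -0.5

def to_fp8(x):
    # quantize a whole tensor with one scale; returns (fp8 tensor, descale factor)
    x_max = x.abs().amax().clamp(min=1e-9)
    return (x * type_max / x_max).to(fp8), x_max / type_max

q8, descale_q = to_fp8(q)
k8, descale_k = to_fp8(k)
v8, descale_v = to_fp8(v)
descale_p = 1.0 / type_max

# mirror the kernel: dequantize qk, softmax, requantize p, dequantize the pv product
qk = (q8.to(torch.float32) @ k8.to(torch.float32).T) * descale_q * descale_k
p = torch.softmax(qk * sm_scale, dim=-1)
acc = ((p / descale_p).to(fp8).to(torch.float32) @ v8.to(torch.float32)) * descale_p * descale_v

ref = torch.softmax((q @ k.T) * sm_scale, dim=-1) @ v
print((acc - ref).abs().max())  # error stays at the fp8 quantization level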