From 77fe3a4bbe6840d3fbb8b4959234b2c67393fed9 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Fri, 3 Jan 2025 05:05:24 -0600 Subject: [PATCH 01/11] disable navi --- .github/workflows/amd_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml index f8c81d92e..226248e51 100644 --- a/.github/workflows/amd_tests.yml +++ b/.github/workflows/amd_tests.yml @@ -27,7 +27,7 @@ jobs: id: set-matrix run: | if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then - echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx1100"]]' + echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]' else echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]' fi From 92fb04095afaf1259cb9840ce147391a1719d834 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Fri, 3 Jan 2025 13:41:56 -0600 Subject: [PATCH 02/11] start test --- flash_attn/flash_attn_triton_amd/test.py | 70 ++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index 7548743c1..d8d7149e3 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -8,6 +8,7 @@ from .bwd_prefill import attention_prefill_backward_triton_impl from .bwd_ref import attention_backward_pytorch_ref_impl from .fwd_decode import dequantize_kv_fp16, quantize_kv_int4 +from flash_attn import flash_attn_func # defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose. @@ -471,6 +472,75 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou print("output_ref:", output_ref, output_ref.shape) torch.testing.assert_close(output_triton, output_ref, atol=ATOL, rtol=RTOL) + +@pytest.mark.parametrize( + "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + # (1, 1, 1, 1, 1, 1), + (1, 1, 1, 2, 4, 16), + # (1, 2, 2, 2, 4, 16), + # (1, 4, 1, 2, 4, 16), + # (1, 4, 2, 2, 4, 16), + # # (1, 1, 1, 4, 2, 16), + # (1, 1, 1, 4, 4, 16), + # (1, 2, 2, 4, 4, 16), + # (2, 1, 1, 4, 4, 16), + # (2, 2, 2, 4, 4, 16), + # (1, 1, 1, 128, 64, 16), + # (2, 2, 2, 2, 128, 1), + # (2, 3, 3, 2, 128, 16), + # (3, 2, 2, 256, 512, 16), + # (3, 3, 3, 128, 128, 64), + # (2, 4, 4, 1024, 1024, 64), + # (4, 6, 6, 108, 256, 224), + # (4, 8, 8, 2048, 2048, 128), + # (4, 16, 16, 4096, 4096, 64), + # (2, 4, 4, 8192, 8192, 32), + # # fa configs + # (4, 6, 1, 113, 203, 256), + # (4, 6, 1, 128, 217, 256), + # (4, 6, 2, 113, 211, 128), + # (4, 6, 2, 108, 256, 128), + # (4, 6, 1, 256, 512, 64), + # (4, 6, 1, 512, 256, 64), + # (4, 6, 2, 1024, 1024, 32), + # (4, 6, 2, 1023, 1024, 32), + # (4, 6, 6, 1024, 1023, 32), + # (4, 6, 6, 2048, 2048, 32), + ], +) +@pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize('layout', ["bshd"]) # expects bshd args +@pytest.mark.parametrize('dtype', [torch.float8_e4m3fn, torch.float16]) +@pytest.mark.parametrize('DEBUG_INPUT', [False]) +def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, dtype, DEBUG_INPUT): + device = "cuda" + window_size = (-1, -1) + softcap = 0.0 + alibi_slopes = None + deterministic = False + + q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT) + + + out, lse, S_dmask = flash_attn_func( + q, + k, + v, + dropout_p, + 
causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + print("out", out) + print("lse", lse) + print("S_dmask", S_dmask) + @pytest.mark.parametrize( "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", [ (1, 1, 1, 1, 1, 1), From 957b0e68bf8c4e42009ace6646b81c19a772c713 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Fri, 3 Jan 2025 14:12:06 -0600 Subject: [PATCH 03/11] test fp16 against fp8 --- flash_attn/flash_attn_triton_amd/test.py | 41 ++++++++++++++++++------ 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index d8d7149e3..ff4d14269 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -512,22 +512,21 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize('dropout_p', [0.0]) @pytest.mark.parametrize('layout', ["bshd"]) # expects bshd args -@pytest.mark.parametrize('dtype', [torch.float8_e4m3fn, torch.float16]) @pytest.mark.parametrize('DEBUG_INPUT', [False]) -def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, dtype, DEBUG_INPUT): +def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, DEBUG_INPUT): device = "cuda" window_size = (-1, -1) softcap = 0.0 alibi_slopes = None deterministic = False - q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT) + q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, torch.float32, layout, device=device, DEBUG_INPUT=DEBUG_INPUT) - out, lse, S_dmask = flash_attn_func( - q, - k, - v, + out_fp16, lse_fp16, S_dmask_fp16 = flash_attn_func( + q.clone().to(torch.float16), + k.clone().to(torch.float16), + v.clone().to(torch.float16), dropout_p, causal=causal, window_size=window_size, @@ -536,10 +535,32 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, deterministic=deterministic, return_attn_probs=True, ) + if DEBUG: + print("out_fp16", out_fp16) + print("lse_fp16", lse_fp16) + print("S_dmask_fp16", S_dmask_fp16) + + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_func( + q.clone().to(torch.float8_e4m3fnuz), + k.clone().to(torch.float8_e4m3fnuz), + v.clone().to(torch.float8_e4m3fnuz), + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + if DEBUG: + print("out_fp8", out_fp8) + print("lse_fp8", lse_fp8) + print("S_dmask_fp8", S_dmask_fp8) - print("out", out) - print("lse", lse) - print("S_dmask", S_dmask) + if DEBUG: + print("out_fp16:", out_fp16, out_fp16.shape) + print("out_fp8:", out_fp8, out_fp8.shape) + torch.testing.assert_close(out_fp16.to(torch.float32), out_fp8.to(torch.float32), atol=ATOL, rtol=RTOL) @pytest.mark.parametrize( "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", [ From a290a6d5b15fd6238530fe9944e2c136dc4795c1 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Tue, 7 Jan 2025 06:47:43 -0600 Subject: [PATCH 04/11] save scaling code so far --- .../flash_attn_triton_amd/fwd_prefill.py | 60 ++++++++++++++++--- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index c6366b8b5..4cbfb1729 100644 --- 
a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -1,7 +1,7 @@ import torch import triton import triton.language as tl -from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask +from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, MetaData, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH @@ -63,6 +63,7 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, @@ -103,6 +104,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # -- compute qk ---- qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE + if IS_FP8: + qk_scaled *= q_scale * k_scale # descale qk after matmul if quantized if IS_CAUSAL: causal_boundary = start_n + offs_n_causal @@ -135,7 +138,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri p_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) # CAVEAT: Must update l_ij before applying dropout - l_ij = tl.sum(p, 1) + l_ij = tl.sum(p, 1) # p is fp32 at this point if ENABLE_DROPOUT: if tl_DROPOUT_USE_PYTORCH: dropout_mask = tl.load(dropout_mask_ptrs, mask=p_mask) @@ -170,7 +173,10 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(v.type.element_ty), v) + if IS_FP8: + acc += tl.dot((p * p_inv_scale).to(v.type.element_ty), v) * p_scale * v_scale + else: + acc += tl.dot(p.to(v.type.element_ty), v) k_ptrs += BLOCK_N * stride_kn v_ptrs += BLOCK_N * stride_vk if bias_ptrs is not None: @@ -259,15 +265,17 @@ def get_autotune_configs(): use_cuda_graph=True, ) @triton.jit -def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, +def attn_fwd(Q, K, V, bias, + Q_SCALE, K_SCALE, V_SCALE, P_SCALE, P_INV_SCALE, stride_qscale_z, stride_kvscale_z, stride_pscale_z, stride_pinvscale_z, + SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: 
tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr): + ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr): start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) @@ -396,6 +404,16 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + # Load scale factors if IS_FP8. + if IS_FP8: + q_scale = tl.load(Q_SCALE + off_z * stride_qscale_z + off_h_q) + k_scale = tl.load(K_SCALE + off_z * stride_kvscale_z + off_h_k) + v_scale = tl.load(V_SCALE + off_z * stride_kvscale_z + off_h_k) + p_scale = tl.load(P_SCALE + off_z * stride_pscale_z + off_h_q) + p_inv_scale = tl.load(P_INV_SCALE + off_z * stride_pinvscale_z + off_h_q) + else: + q_scale, k_scale, v_scale, p_scale, p_inv_scale = 1.0, 1.0, 1.0, 1.0, 1.0 + # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) @@ -421,6 +439,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ sd_mask_ptrs, dropout_mask_ptrs, # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ block_min, block_max, 0, 0, 0, alibi_slope, + q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8, # IS_CAUSAL, .... False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... @@ -449,6 +468,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... 
PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD, @@ -539,6 +559,22 @@ def attention_prefill_forward_triton_impl( return_softmax, use_exp2): + + if q.dtype in FP8_TYPES: + is_fp8 = True + q_scale = fp8_metadata.q_scale + k_scale = fp8_metadata.k_scale + v_scale = fp8_metadata.v_scale + p_scale = fp8_metadata.p_scale + p_inv_scale = fp8_metadata.p_inv_scale + q_scale_stride_z = q_scale.stride(0) + kv_scale_stride_z = k_scale.stride(0) + p_scale_stride_z = p_scale.stride(0) + p_inv_scale_stride_z = p_inv_scale.stride(0) + else: + q_scale = k_scale = v_scale = p_scale = p_inv_scale = 1 + q_scale_stride_z = kv_scale_stride_z = p_scale_stride_z = p_inv_scale_stride_z = 0 + if DEBUG: print() print("attention_prefill_forward_triton_impl") @@ -546,6 +582,11 @@ def attention_prefill_forward_triton_impl( print("k:", k, k.shape) print("v:", v, v.shape) print("o:", o, o.shape) + print("q_scale:", q_scale) + print("k_scale:", k_scale) + print("v_scale:", v_scale) + print("p_scale:", p_scale) + print("p_inv_scale:", p_inv_scale) print("sm_scale:", sm_scale) print("alibi_slopes:", alibi_slopes) print("causal:", causal) @@ -618,15 +659,16 @@ def attention_prefill_forward_triton_impl( else: alibi_strides = (0, 0) - - attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, + attn_fwd[grid](q, k, v, bias, + q_scale, k_scale, v_scale, p_scale, p_inv_scale, q_scale_stride_z, kv_scale_stride_z, p_scale_stride_z, p_inv_scale_stride_z, + sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p - > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax) + > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=is_fp8) if DEBUG: print() From 354230037b1aca2c28a77c8ba857a9026c33bcc6 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Tue, 7 Jan 2025 07:19:25 -0600 Subject: [PATCH 05/11] global scaling --- .../flash_attn_triton_amd/fwd_prefill.py | 77 +++++++++++++------ 1 file changed, 55 insertions(+), 22 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 4cbfb1729..955e0fc2c 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -1,7 +1,7 @@ import torch import triton import triton.language as tl -from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, MetaData, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask +from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH @@ -559,22 +559,6 @@ def attention_prefill_forward_triton_impl( return_softmax, use_exp2): - - if q.dtype in FP8_TYPES: - is_fp8 = True - q_scale = 
fp8_metadata.q_scale - k_scale = fp8_metadata.k_scale - v_scale = fp8_metadata.v_scale - p_scale = fp8_metadata.p_scale - p_inv_scale = fp8_metadata.p_inv_scale - q_scale_stride_z = q_scale.stride(0) - kv_scale_stride_z = k_scale.stride(0) - p_scale_stride_z = p_scale.stride(0) - p_inv_scale_stride_z = p_inv_scale.stride(0) - else: - q_scale = k_scale = v_scale = p_scale = p_inv_scale = 1 - q_scale_stride_z = kv_scale_stride_z = p_scale_stride_z = p_inv_scale_stride_z = 0 - if DEBUG: print() print("attention_prefill_forward_triton_impl") @@ -582,11 +566,6 @@ def attention_prefill_forward_triton_impl( print("k:", k, k.shape) print("v:", v, v.shape) print("o:", o, o.shape) - print("q_scale:", q_scale) - print("k_scale:", k_scale) - print("v_scale:", v_scale) - print("p_scale:", p_scale) - print("p_inv_scale:", p_inv_scale) print("sm_scale:", sm_scale) print("alibi_slopes:", alibi_slopes) print("causal:", causal) @@ -602,6 +581,60 @@ def attention_prefill_forward_triton_impl( print("return_scores:", return_softmax) print("use_exp2:", use_exp2) + # Define FP8 types we support + FP8_TYPES = {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz} + + # Simple check if tensors are FP8 + is_fp8 = q.dtype in FP8_TYPES + + if is_fp8: + # Convert to float32 for scale computation + q_float32 = q.detach().to(torch.float32) + k_float32 = k.detach().to(torch.float32) + v_float32 = v.detach().to(torch.float32) + + # Get shapes for scaling + batch = q.size(0) if layout != "thd" else len(cu_seqlens_q) - 1 + nheads_q = q.size(1) if layout == "bhsd" else q.size(2) + nheads_k = k.size(1) if layout == "bhsd" else k.size(2) + + # Compute global max values + eps = 1e-9 + q_max = max(q_float32.abs().max().item(), eps) + k_max = max(k_float32.abs().max().item(), eps) + v_max = max(v_float32.abs().max().item(), eps) + + # Create scale tensors with the global values + q_scale = torch.full((batch, nheads_q), q_max, dtype=torch.float32, device=q.device) + k_scale = torch.full((batch, nheads_k), k_max, dtype=torch.float32, device=k.device) + v_scale = torch.full((batch, nheads_k), v_max, dtype=torch.float32, device=v.device) + + # Simple p_scale for softmax computation + p_scale = torch.full((batch, nheads_q), 1.0, dtype=torch.float32, device=q.device) + p_inv_scale = torch.full((batch, nheads_q), 1.0, dtype=torch.float32, device=q.device) + + # Get strides for the kernel + q_scale_stride_z = q_scale.stride(0) + kv_scale_stride_z = k_scale.stride(0) + p_scale_stride_z = p_scale.stride(0) + p_inv_scale_stride_z = p_inv_scale.stride(0) + else: + # For non-FP8 types, use dummy values (no scaling needed) + q_scale = k_scale = v_scale = p_scale = p_inv_scale = 1 + q_scale_stride_z = kv_scale_stride_z = p_scale_stride_z = p_inv_scale_stride_z = 0 + + if DEBUG: + print("is_fp8:", is_fp8) + print("q_scale:", q_scale) + print("k_scale:", k_scale) + print("v_scale:", v_scale) + print("p_scale:", p_scale) + print("p_inv_scale:", p_inv_scale) + print("q_scale_stride_z:", q_scale_stride_z) + print("kv_scale_stride_z:", kv_scale_stride_z) + print("p_scale_stride_z:", p_scale_stride_z) + print("p_inv_scale_stride_z:", p_inv_scale_stride_z) + # check if varlen is_varlen = layout == "thd" From 0121712cfebdddd4bc05e66db7e39aa6f5a4f23d Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Tue, 7 Jan 2025 07:54:42 -0600 Subject: [PATCH 06/11] add per_head_scaling --- .../flash_attn_triton_amd/fwd_prefill.py | 64 ++++++++++++++----- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git 
a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 955e0fc2c..14dd5edda 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -411,6 +411,11 @@ def attn_fwd(Q, K, V, bias, v_scale = tl.load(V_SCALE + off_z * stride_kvscale_z + off_h_k) p_scale = tl.load(P_SCALE + off_z * stride_pscale_z + off_h_q) p_inv_scale = tl.load(P_INV_SCALE + off_z * stride_pinvscale_z + off_h_q) + # print("q_scale", q_scale) + # print("k_scale", k_scale) + # print("v_scale", v_scale) + # print("p_scale", p_scale) + # print("p_inv_scale", p_inv_scale) else: q_scale, k_scale, v_scale, p_scale, p_inv_scale = 1.0, 1.0, 1.0, 1.0, 1.0 @@ -588,6 +593,11 @@ def attention_prefill_forward_triton_impl( is_fp8 = q.dtype in FP8_TYPES if is_fp8: + # constants + eps = 1e-9 + type_max = torch.finfo(q.dtype).max + per_head_scaling = True + # Convert to float32 for scale computation q_float32 = q.detach().to(torch.float32) k_float32 = k.detach().to(torch.float32) @@ -598,24 +608,45 @@ def attention_prefill_forward_triton_impl( nheads_q = q.size(1) if layout == "bhsd" else q.size(2) nheads_k = k.size(1) if layout == "bhsd" else k.size(2) - # Compute global max values - eps = 1e-9 - q_max = max(q_float32.abs().max().item(), eps) - k_max = max(k_float32.abs().max().item(), eps) - v_max = max(v_float32.abs().max().item(), eps) - - # Create scale tensors with the global values - q_scale = torch.full((batch, nheads_q), q_max, dtype=torch.float32, device=q.device) - k_scale = torch.full((batch, nheads_k), k_max, dtype=torch.float32, device=k.device) - v_scale = torch.full((batch, nheads_k), v_max, dtype=torch.float32, device=v.device) - - # Simple p_scale for softmax computation - p_scale = torch.full((batch, nheads_q), 1.0, dtype=torch.float32, device=q.device) - p_inv_scale = torch.full((batch, nheads_q), 1.0, dtype=torch.float32, device=q.device) + if per_head_scaling: + # Set up layout-specific dimensions + if layout == "bhsd": + seqlen_loc = 2 + dim_loc = 3 + elif layout == "bshd": + seqlen_loc = 1 + dim_loc = 3 + + # Compute max for each batch-head pair across seqlen and dim + q_scale = torch.maximum(q_float32.abs().amax(dim=(seqlen_loc, dim_loc)), torch.tensor(eps)) + k_scale = torch.maximum(k_float32.abs().amax(dim=(seqlen_loc, dim_loc)), torch.tensor(eps)) + v_scale = torch.maximum(v_float32.abs().amax(dim=(seqlen_loc, dim_loc)), torch.tensor(eps)) + + # Divide by type max + q_scale = q_scale / type_max + k_scale = k_scale / type_max + v_scale = v_scale / type_max + + # Set p_scale according to reference + p_scale = torch.full((batch, nheads_q), 1.0/type_max, dtype=torch.float32, device=q.device) + p_inv_scale = 1.0 / p_scale + else: + q_max = max(q_float32.abs().max().item(), eps) + k_max = max(k_float32.abs().max().item(), eps) + v_max = max(v_float32.abs().max().item(), eps) + + # Create scale tensors with the global values + q_scale = torch.full((batch, nheads_q), q_max, dtype=torch.float32, device=q.device) + k_scale = torch.full((batch, nheads_k), k_max, dtype=torch.float32, device=k.device) + v_scale = torch.full((batch, nheads_k), v_max, dtype=torch.float32, device=v.device) + + # Simple p_scale for softmax computation + p_scale = torch.full((batch, nheads_q), 1.0, dtype=torch.float32, device=q.device) + p_inv_scale = torch.full((batch, nheads_q), 1.0, dtype=torch.float32, device=q.device) # Get strides for the kernel q_scale_stride_z = q_scale.stride(0) - kv_scale_stride_z = k_scale.stride(0) + 
kv_scale_stride_z = k_scale.stride(0) p_scale_stride_z = p_scale.stride(0) p_inv_scale_stride_z = p_inv_scale.stride(0) else: @@ -634,6 +665,9 @@ def attention_prefill_forward_triton_impl( print("kv_scale_stride_z:", kv_scale_stride_z) print("p_scale_stride_z:", p_scale_stride_z) print("p_inv_scale_stride_z:", p_inv_scale_stride_z) + if is_fp8: + print(f"type_max: {type_max}") + # check if varlen is_varlen = layout == "thd" From 1cef817e7c907b5bd8ccdf5668b3d1100740e97f Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Tue, 7 Jan 2025 10:58:44 -0600 Subject: [PATCH 07/11] dump qk --- .../flash_attn_triton_amd/fwd_prefill.py | 45 ++++++++++++------- flash_attn/flash_attn_triton_amd/test.py | 2 +- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 14dd5edda..30a5a6daa 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -62,7 +62,7 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo @triton.jit def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + qk_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, @@ -100,12 +100,16 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri size_n = start_n + OFFS_N[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) + + # compute mask for scores + p_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) # -- compute qk ---- qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE if IS_FP8: qk_scaled *= q_scale * k_scale # descale qk after matmul if quantized + tl.store(qk_fp8_ptrs, qk_scaled, mask=p_mask) if IS_CAUSAL: causal_boundary = start_n + offs_n_causal @@ -134,11 +138,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri p = tl.math.exp2(q_shifted * RCP_LN2) else: p = tl.math.exp(q_shifted) - - p_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) + if IS_FP8: + p *= p_inv_scale # CAVEAT: Must update l_ij before applying dropout - l_ij = tl.sum(p, 1) # p is fp32 at this point + l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: if tl_DROPOUT_USE_PYTORCH: dropout_mask = tl.load(dropout_mask_ptrs, mask=p_mask) @@ -173,10 +177,9 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij + acc += tl.dot(p.to(v.type.element_ty), v) if IS_FP8: - acc += tl.dot((p * p_inv_scale).to(v.type.element_ty), v) * p_scale * v_scale - else: - acc += tl.dot(p.to(v.type.element_ty), v) + acc *= p_scale * v_scale k_ptrs += BLOCK_N * stride_kn v_ptrs += BLOCK_N * stride_vk if bias_ptrs is not None: @@ -271,7 +274,7 @@ def attn_fwd(Q, K, V, bias, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, 
stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, + dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, qk_fp8, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, @@ -416,8 +419,11 @@ def attn_fwd(Q, K, V, bias, # print("v_scale", v_scale) # print("p_scale", p_scale) # print("p_inv_scale", p_inv_scale) + qk_fp8_offset = qk_fp8 + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm + qk_fp8_ptrs = qk_fp8_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn #+ cu_seqlens_q_start * stride_sm else: q_scale, k_scale, v_scale, p_scale, p_inv_scale = 1.0, 1.0, 1.0, 1.0, 1.0 + qk_fp8_ptrs = None # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 @@ -441,7 +447,7 @@ def attn_fwd(Q, K, V, bias, block_max = (n_blocks - masked_blocks) * BLOCK_N acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, - sd_mask_ptrs, dropout_mask_ptrs, + sd_mask_ptrs, dropout_mask_ptrs, qk_fp8_ptrs, # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ block_min, block_max, 0, 0, 0, alibi_slope, q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8, @@ -471,7 +477,7 @@ def attn_fwd(Q, K, V, bias, philox_ptrs += n_full_blocks * BLOCK_N * stride_sn acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, - sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks, + sd_mask_ptrs, dropout_mask_ptrs, qk_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, @@ -612,15 +618,13 @@ def attention_prefill_forward_triton_impl( # Set up layout-specific dimensions if layout == "bhsd": seqlen_loc = 2 - dim_loc = 3 elif layout == "bshd": seqlen_loc = 1 - dim_loc = 3 - + # Compute max for each batch-head pair across seqlen and dim - q_scale = torch.maximum(q_float32.abs().amax(dim=(seqlen_loc, dim_loc)), torch.tensor(eps)) - k_scale = torch.maximum(k_float32.abs().amax(dim=(seqlen_loc, dim_loc)), torch.tensor(eps)) - v_scale = torch.maximum(v_float32.abs().amax(dim=(seqlen_loc, dim_loc)), torch.tensor(eps)) + q_scale = torch.maximum(q_float32.abs().amax(dim=(seqlen_loc, 3)), torch.tensor(eps)) + k_scale = torch.maximum(k_float32.abs().amax(dim=(seqlen_loc, 3)), torch.tensor(eps)) + v_scale = torch.maximum(v_float32.abs().amax(dim=(seqlen_loc, 3)), torch.tensor(eps)) # Divide by type max q_scale = q_scale / type_max @@ -649,10 +653,15 @@ def attention_prefill_forward_triton_impl( kv_scale_stride_z = k_scale.stride(0) p_scale_stride_z = p_scale.stride(0) p_inv_scale_stride_z = p_inv_scale.stride(0) + + # dump intermedia results + qk_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), 
dtype=q.dtype, device=q.device) + acc_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=q.dtype, device=q.device) else: # For non-FP8 types, use dummy values (no scaling needed) q_scale = k_scale = v_scale = p_scale = p_inv_scale = 1 q_scale_stride_z = kv_scale_stride_z = p_scale_stride_z = p_inv_scale_stride_z = 0 + qk_fp8= None if DEBUG: print("is_fp8:", is_fp8) @@ -731,7 +740,7 @@ def attention_prefill_forward_triton_impl( sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, - HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, + qk_fp8=qk_fp8, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p @@ -747,5 +756,7 @@ def attention_prefill_forward_triton_impl( print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None) print("dropout_fraction fwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item()) write_dropout_mask(dropout_mask, "dropout_mask_fwd") + if is_fp8: + print("qk_fp8:", qk_fp8) return o, softmax_lse, sd_mask.to(o.dtype) if return_softmax else None diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index ff4d14269..9dcf63681 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -512,7 +512,7 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize('dropout_p', [0.0]) @pytest.mark.parametrize('layout', ["bshd"]) # expects bshd args -@pytest.mark.parametrize('DEBUG_INPUT', [False]) +@pytest.mark.parametrize('DEBUG_INPUT', [True]) def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, DEBUG_INPUT): device = "cuda" window_size = (-1, -1) From 390e9906c0823c930868c0507b2b34569a4f9f27 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Wed, 8 Jan 2025 11:26:00 -0600 Subject: [PATCH 08/11] save dumping q, k and qk to fp32 tensor --- .../flash_attn_triton_amd/fwd_prefill.py | 52 +++++++++++++++---- 1 file changed, 42 insertions(+), 10 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 30a5a6daa..9973cbb09 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -62,7 +62,7 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo @triton.jit def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - qk_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + q_fp8_ptrs, k_fp8_ptrs, qk_f32_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: 
tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, @@ -101,15 +101,22 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) - # compute mask for scores - p_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) + # compute masks + q_mask = (OFFS_M[:, None] < actual_seqlen_q) + k_mask = ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) + p_mask = q_mask & k_mask + + if IS_FP8: + tl.store(q_fp8_ptrs, q, mask=q_mask) + tl.store(k_fp8_ptrs, k, mask=k_mask) + tl.store(qk_f32_ptrs, tl.dot(q, k), mask=p_mask) + # -- compute qk ---- qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE if IS_FP8: qk_scaled *= q_scale * k_scale # descale qk after matmul if quantized - tl.store(qk_fp8_ptrs, qk_scaled, mask=p_mask) if IS_CAUSAL: causal_boundary = start_n + offs_n_causal @@ -274,7 +281,7 @@ def attn_fwd(Q, K, V, bias, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, qk_fp8, HQ: tl.constexpr, + dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, q_fp8, k_fp8, qk_fp8, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, @@ -406,6 +413,7 @@ def attn_fwd(Q, K, V, bias, if PADDED_HEAD: q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + print("q:", q) # Load scale factors if IS_FP8. if IS_FP8: @@ -419,10 +427,20 @@ def attn_fwd(Q, K, V, bias, # print("v_scale", v_scale) # print("p_scale", p_scale) # print("p_inv_scale", p_inv_scale) + q_fp8_offset = q_fp8 + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_fp8_ptrs = q_fp8_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + + + k_fp8_offset = k_fp8 + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_fp8_ptrs = k_fp8_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + + qk_fp8_offset = qk_fp8 + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm qk_fp8_ptrs = qk_fp8_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn #+ cu_seqlens_q_start * stride_sm else: q_scale, k_scale, v_scale, p_scale, p_inv_scale = 1.0, 1.0, 1.0, 1.0, 1.0 + q_fp8_ptrs = None + k_fp8_ptrs = None qk_fp8_ptrs = None # Here we compute how many full and masked blocks we have. 
@@ -447,7 +465,7 @@ def attn_fwd(Q, K, V, bias, block_max = (n_blocks - masked_blocks) * BLOCK_N acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, - sd_mask_ptrs, dropout_mask_ptrs, qk_fp8_ptrs, + sd_mask_ptrs, dropout_mask_ptrs, q_fp8_ptrs, k_fp8_ptrs, qk_fp8_ptrs, # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ block_min, block_max, 0, 0, 0, alibi_slope, q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8, @@ -477,7 +495,7 @@ def attn_fwd(Q, K, V, bias, philox_ptrs += n_full_blocks * BLOCK_N * stride_sn acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, - sd_mask_ptrs, dropout_mask_ptrs, qk_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks, + sd_mask_ptrs, dropout_mask_ptrs, q_fp8_ptrs, k_fp8_ptrs, qk_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, @@ -655,12 +673,17 @@ def attention_prefill_forward_triton_impl( p_inv_scale_stride_z = p_inv_scale.stride(0) # dump intermedia results - qk_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=q.dtype, device=q.device) - acc_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=q.dtype, device=q.device) + q_fp8 = torch.zeros_like(q) + k_fp8 = torch.zeros_like(k) + # NOTE: the result of fp8 dot is float32 + qk_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=torch.float32, device=q.device) + acc_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=torch.float32, device=q.device) else: # For non-FP8 types, use dummy values (no scaling needed) q_scale = k_scale = v_scale = p_scale = p_inv_scale = 1 q_scale_stride_z = kv_scale_stride_z = p_scale_stride_z = p_inv_scale_stride_z = 0 + q_fp8 = None + k_fp8 = None qk_fp8= None if DEBUG: @@ -735,12 +758,17 @@ def attention_prefill_forward_triton_impl( else: alibi_strides = (0, 0) + if DEBUG: + print("attn_fwd input") + print("q:", q) + print("k:", k) + attn_fwd[grid](q, k, v, bias, q_scale, k_scale, v_scale, p_scale, p_inv_scale, q_scale_stride_z, kv_scale_stride_z, p_scale_stride_z, p_inv_scale_stride_z, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, - qk_fp8=qk_fp8, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, + q_fp8=q_fp8, k_fp8 = k_fp8, qk_fp8=qk_fp8, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p @@ -757,6 +785,10 @@ def attention_prefill_forward_triton_impl( print("dropout_fraction fwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item()) write_dropout_mask(dropout_mask, "dropout_mask_fwd") if is_fp8: + print("") + print("q_fp8:", q_fp8) + print("k_fp8:", k_fp8) print("qk_fp8:", 
qk_fp8) + print("qk_fp8_ref:", torch.matmul(q.to(torch.float32).transpose(1, 2), k.to(torch.float32).transpose(1, 2).transpose(-2, -1))) return o, softmax_lse, sd_mask.to(o.dtype) if return_softmax else None From e834dd26b6a6d50f3aaa8f3b7afff80b6bbb0e00 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 9 Jan 2025 03:27:32 -0600 Subject: [PATCH 09/11] fix pointer bug --- flash_attn/flash_attn_triton_amd/fwd_prefill.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 9973cbb09..a0f39ea3a 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -62,7 +62,7 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo @triton.jit def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - q_fp8_ptrs, k_fp8_ptrs, qk_f32_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + q_fp8_ptrs, k_fp8_ptrs, qk_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, @@ -109,7 +109,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri if IS_FP8: tl.store(q_fp8_ptrs, q, mask=q_mask) tl.store(k_fp8_ptrs, k, mask=k_mask) - tl.store(qk_f32_ptrs, tl.dot(q, k), mask=p_mask) + tl.store(qk_fp8_ptrs, tl.dot(q, k), mask=p_mask) # -- compute qk ---- @@ -197,6 +197,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri if ENABLE_DROPOUT: dropout_mask_ptrs += BLOCK_N * stride_sn philox_ptrs += BLOCK_N * stride_sn + if IS_FP8: + qk_fp8_ptrs += BLOCK_N * stride_sn return acc, l_i, m_i @@ -493,6 +495,8 @@ def attn_fwd(Q, K, V, bias, if ENABLE_DROPOUT: dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn philox_ptrs += n_full_blocks * BLOCK_N * stride_sn + if IS_FP8: + qk_fp8_ptrs += n_full_blocks * BLOCK_N * stride_sn acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, q_fp8_ptrs, k_fp8_ptrs, qk_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks, @@ -736,7 +740,10 @@ def attention_prefill_forward_triton_impl( else: sd_mask = None dropout_mask = None - scores_strides = (0, 0, 0, 0) + if is_fp8: + scores_strides = (qk_fp8.stride(0), qk_fp8.stride(1), qk_fp8.stride(2), qk_fp8.stride(3)) + else: + scores_strides = (0, 0, 0, 0) # stores LSE the log of the normalization constant / sum of expoential score(unnormalzied probablities) if is_varlen: From 2a4899a11b69671458e759d76b5be9823ee0bdf1 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 9 Jan 2025 03:27:59 -0600 Subject: [PATCH 10/11] save reproducer --- flash_attn/flash_attn_triton_amd/fp8.py | 141 ++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 flash_attn/flash_attn_triton_amd/fp8.py diff --git a/flash_attn/flash_attn_triton_amd/fp8.py b/flash_attn/flash_attn_triton_amd/fp8.py new 
file mode 100644 index 000000000..e53c9f277 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/fp8.py @@ -0,0 +1,141 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_fp8_kernel_no_loop( + A_ptr, # [M, K] in FP8 + B_ptr, # [K, N] in FP8 + C_ptr, # [M, N] in float32 (for storing the result) + M, N, K, + stride_am, stride_ak, # strides for A + stride_bk, stride_bn, # strides for B + stride_cm, stride_cn, # strides for C + BLOCK_M: tl.constexpr, # tile size along M + BLOCK_N: tl.constexpr, # tile size along N + BLOCK_K: tl.constexpr # tile size along K (no loop: must cover entire K or partial only) +): + """ + Simple matmul kernel that takes: + - Two FP8 matrices A and B + - Writes a float32 result C + WITHOUT looping over K. Only one chunk of size BLOCK_K is processed. + + This kernel is for demonstration and testing only. + If K > BLOCK_K, it accumulates only part of the product. + """ + # 2D block indices along M and N + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # Each program instance computes a [BLOCK_M x BLOCK_N] tile in C + row_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # ----------------------- + # 1) Create an accumulator + # ----------------------- + c_tile = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + # ----------------------- + # 2) Load one slice of A and B + # ----------------------- + # We skip the usual loop over K so we assume K <= BLOCK_K + # or we only compute partial coverage for K if K < BLOCK_K. + k_offsets = tl.arange(0, BLOCK_K) + + # Addressing for A: A[row, k] + a_ptrs = A_ptr + (row_offsets[:, None] * stride_am + k_offsets[None, :] * stride_ak) + # Addressing for B: B[k, col] + b_ptrs = B_ptr + (k_offsets[:, None] * stride_bk + col_offsets[None, :] * stride_bn) + + # Load from FP8 into float32 + # Here we do trivial boundary checks: + a_mask = (row_offsets[:, None] < M) & (k_offsets[None, :] < K) + b_mask = (k_offsets[:, None] < K) & (col_offsets[None, :] < N) + + A_tile_fp8 = tl.load(a_ptrs, mask=a_mask, other=0.0) + B_tile_fp8 = tl.load(b_ptrs, mask=b_mask, other=0.0) + + print("A_tile_fp8:", A_tile_fp8) + print("B_tile_fp8:", B_tile_fp8) + + # ----------------------- + # 3) Compute the dot-product + # ----------------------- + c_tile += tl.dot(A_tile_fp8, B_tile_fp8) + print("c_tile:", c_tile) + + # ----------------------- + # 4) Write results to C + # ----------------------- + c_ptrs = C_ptr + (row_offsets[:, None] * stride_cm + col_offsets[None, :] * stride_cn) + out_of_bounds = (row_offsets[:, None] >= M) | (col_offsets[None, :] >= N) + tl.store(c_ptrs, c_tile, mask=~out_of_bounds) + + +def matmul_fp8_no_loop(A_fp8: torch.Tensor, B_fp8: torch.Tensor): + """ + Minimal test function: + - A_fp8: [M, K] in FP8 + - B_fp8: [K, N] in FP8 + Returns C in float32, ignoring any leftover if K < BLOCK_K. + """ + + M, K = A_fp8.shape + K2, N = B_fp8.shape + assert K == K2, "Incompatible shapes for matmul!" + + # Pick block sizes. We want BLOCK_K >= K for a single slice coverage. + BLOCK_M = 64 + BLOCK_N = 64 + BLOCK_K = 64 # or something >= K + + # Allocate output + C = torch.zeros((M, N), device=A_fp8.device, dtype=torch.float32) + + # Launch grid + grid = ( + ( (M + BLOCK_M - 1) // BLOCK_M ), # how many blocks in M + ( (N + BLOCK_N - 1) // BLOCK_N ), # how many blocks in N + ) + + # Grab strides (row-major). + # (For FP8, these are still just standard strides in terms of # of elements.) 
+ stride_am = A_fp8.stride(0) + stride_ak = A_fp8.stride(1) + stride_bk = B_fp8.stride(0) + stride_bn = B_fp8.stride(1) + stride_cm = C.stride(0) + stride_cn = C.stride(1) + + # If K > BLOCK_K, the result is only partial. + # For full correctness, K must be <= BLOCK_K (or block multiple). + matmul_fp8_kernel_no_loop[grid]( + A_fp8, B_fp8, C, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K + ) + return C + + +# Suppose we have small M, N, K for demonstration +M, N, K = 2, 4, 16 + +# Create random FP8 input data +if True: + A_fp8 = torch.arange(M, dtype=torch.float32, device='cuda').view(-1, 1).expand(-1, K).to(torch.float8_e4m3fnuz) + B_fp8 = torch.arange(N, dtype=torch.float32, device='cuda').view(-1, 1).expand(-1, K).to(torch.float8_e4m3fnuz) +else: + A_fp8 = torch.randn((M, K), device='cuda', dtype=torch.float32).to(torch.float8_e4m3fnuz) + B_fp8 = torch.randn((K, N), device='cuda', dtype=torch.float32).to(torch.float8_e4m3fnuz) +print("A_fp8:", A_fp8, A_fp8.shape) +print("B_fp8:", B_fp8, B_fp8.shape) +C_out = matmul_fp8_no_loop(A_fp8, B_fp8.T) +print("C:", C_out, C_out.shape) \ No newline at end of file From 65ad5f2d3425c3c207a6797d4d5fcef75770b5dd Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 9 Jan 2025 06:13:53 -0600 Subject: [PATCH 11/11] dump p and acc --- .../flash_attn_triton_amd/fwd_prefill.py | 40 +++++++++++++------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index a0f39ea3a..fde000b2c 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -62,7 +62,7 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo @triton.jit def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - q_fp8_ptrs, k_fp8_ptrs, qk_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + q_fp8_ptrs, k_fp8_ptrs, qk_fp8_ptrs, p_fp8_ptrs, acc_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, @@ -109,14 +109,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri if IS_FP8: tl.store(q_fp8_ptrs, q, mask=q_mask) tl.store(k_fp8_ptrs, k, mask=k_mask) - tl.store(qk_fp8_ptrs, tl.dot(q, k), mask=p_mask) - # -- compute qk ---- qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE if IS_FP8: qk_scaled *= q_scale * k_scale # descale qk after matmul if quantized + tl.store(qk_fp8_ptrs, qk_scaled, mask=p_mask) if IS_CAUSAL: causal_boundary = start_n + offs_n_causal @@ -147,6 +146,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri p = tl.math.exp(q_shifted) if IS_FP8: p *= p_inv_scale + tl.store(p_fp8_ptrs, p, mask=p_mask) # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) @@ -199,6 +199,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri philox_ptrs += BLOCK_N * stride_sn if IS_FP8: qk_fp8_ptrs += BLOCK_N * 
stride_sn + p_fp8_ptrs += BLOCK_N * stride_sn + return acc, l_i, m_i @@ -283,7 +285,7 @@ def attn_fwd(Q, K, V, bias, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, q_fp8, k_fp8, qk_fp8, HQ: tl.constexpr, + dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, q_fp8, k_fp8, qk_fp8, p_fp8, acc_fp8, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, @@ -415,7 +417,6 @@ def attn_fwd(Q, K, V, bias, if PADDED_HEAD: q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) - print("q:", q) # Load scale factors if IS_FP8. if IS_FP8: @@ -429,6 +430,7 @@ def attn_fwd(Q, K, V, bias, # print("v_scale", v_scale) # print("p_scale", p_scale) # print("p_inv_scale", p_inv_scale) + q_fp8_offset = q_fp8 + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm q_fp8_ptrs = q_fp8_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk @@ -439,11 +441,19 @@ def attn_fwd(Q, K, V, bias, qk_fp8_offset = qk_fp8 + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm qk_fp8_ptrs = qk_fp8_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn #+ cu_seqlens_q_start * stride_sm + + p_fp8_offset = p_fp8 + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm + p_fp8_ptrs = p_fp8_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn #+ cu_seqlens_q_start * stride_sm + + acc_fp8_offset = acc_fp8 + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + acc_fp8_ptrs = acc_fp8_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk else: q_scale, k_scale, v_scale, p_scale, p_inv_scale = 1.0, 1.0, 1.0, 1.0, 1.0 q_fp8_ptrs = None k_fp8_ptrs = None qk_fp8_ptrs = None + p_fp8_ptrs = None + acc_fp8_ptrs = None # Here we compute how many full and masked blocks we have. 
padded_block_k = n_extra_tokens != 0 @@ -467,7 +477,7 @@ def attn_fwd(Q, K, V, bias, block_max = (n_blocks - masked_blocks) * BLOCK_N acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, - sd_mask_ptrs, dropout_mask_ptrs, q_fp8_ptrs, k_fp8_ptrs, qk_fp8_ptrs, + sd_mask_ptrs, dropout_mask_ptrs, q_fp8_ptrs, k_fp8_ptrs, qk_fp8_ptrs, p_fp8_ptrs, acc_fp8_ptrs, # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ block_min, block_max, 0, 0, 0, alibi_slope, q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8, @@ -497,10 +507,11 @@ def attn_fwd(Q, K, V, bias, philox_ptrs += n_full_blocks * BLOCK_N * stride_sn if IS_FP8: qk_fp8_ptrs += n_full_blocks * BLOCK_N * stride_sn + p_fp8_ptrs += n_full_blocks * BLOCK_N * stride_sn acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, - sd_mask_ptrs, dropout_mask_ptrs, q_fp8_ptrs, k_fp8_ptrs, qk_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks, - n_extra_tokens, alibi_slope, + sd_mask_ptrs, dropout_mask_ptrs, q_fp8_ptrs, k_fp8_ptrs, qk_fp8_ptrs, p_fp8_ptrs, acc_fp8_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... @@ -510,6 +521,8 @@ def attn_fwd(Q, K, V, bias, # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. l_recip = 1 / l_i[:, None] acc = acc * l_recip + if IS_FP8: + tl.store(acc_fp8_ptrs, acc) if ENABLE_DROPOUT: dropout_scale = 1 / (1 - dropout_p) acc = acc * dropout_scale @@ -520,7 +533,6 @@ def attn_fwd(Q, K, V, bias, end_m_idx = (start_m + 1) * BLOCK_M start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k - acc = acc.to(Out.type.element_ty) if IS_CAUSAL: if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) @@ -681,7 +693,8 @@ def attention_prefill_forward_triton_impl( k_fp8 = torch.zeros_like(k) # NOTE: the result of fp8 dot is float32 qk_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=torch.float32, device=q.device) - acc_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=torch.float32, device=q.device) + p_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=torch.float32, device=q.device) + acc_fp8 = torch.zeros(o.shape, dtype=torch.float32, device=q.device) else: # For non-FP8 types, use dummy values (no scaling needed) q_scale = k_scale = v_scale = p_scale = p_inv_scale = 1 @@ -689,6 +702,8 @@ def attention_prefill_forward_triton_impl( q_fp8 = None k_fp8 = None qk_fp8= None + p_fp8 = None + acc_fp8 = None if DEBUG: print("is_fp8:", is_fp8) @@ -775,7 +790,7 @@ def attention_prefill_forward_triton_impl( sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, - q_fp8=q_fp8, k_fp8 = k_fp8, qk_fp8=qk_fp8, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, + q_fp8=q_fp8, 
k_fp8 = k_fp8, qk_fp8=qk_fp8, p_fp8=p_fp8, acc_fp8=acc_fp8, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p @@ -796,6 +811,7 @@ def attention_prefill_forward_triton_impl( print("q_fp8:", q_fp8) print("k_fp8:", k_fp8) print("qk_fp8:", qk_fp8) - print("qk_fp8_ref:", torch.matmul(q.to(torch.float32).transpose(1, 2), k.to(torch.float32).transpose(1, 2).transpose(-2, -1))) + print("p_fp8:", p_fp8) + print("acc_fp8:", acc_fp8) return o, softmax_lse, sd_mask.to(o.dtype) if return_softmax else None
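
A minimal host-side sketch of the per-head FP8 scaling scheme these patches are building toward (scale = amax / dtype_max as in PATCH 06, descale qk by q_scale * k_scale after the matmul as in PATCH 04). This is an illustrative reference only, assuming bhsd tensors, matching q/k head counts, and the torch.float8_e4m3fnuz dtype used in the tests; the helper names (per_head_scale, fp8_qk) are not part of the repository.

import torch

def per_head_scale(x, fp8_dtype=torch.float8_e4m3fnuz, eps=1e-9):
    # bhsd layout: reduce over seqlen (dim 2) and head_dim (dim 3),
    # giving one scale per (batch, head), as in the per_head_scaling branch above.
    type_max = torch.finfo(fp8_dtype).max
    return x.abs().amax(dim=(2, 3)).clamp(min=eps) / type_max

def fp8_qk(q, k, fp8_dtype=torch.float8_e4m3fnuz):
    q_scale = per_head_scale(q, fp8_dtype)                       # (batch, heads)
    k_scale = per_head_scale(k, fp8_dtype)
    # Quantize: divide by the scale so values fit the FP8 range, then cast down.
    q_fp8 = (q / q_scale[:, :, None, None]).to(fp8_dtype)
    k_fp8 = (k / k_scale[:, :, None, None]).to(fp8_dtype)
    # Emulate the kernel's FP8 dot (accumulated in float32), then descale the product,
    # mirroring the `qk_scaled *= q_scale * k_scale` step in the Triton kernel.
    qk = torch.matmul(q_fp8.to(torch.float32),
                      k_fp8.to(torch.float32).transpose(-2, -1))
    return qk * (q_scale * k_scale)[:, :, None, None]

if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(1, 2, 4, 16, device="cuda")
    k = torch.randn(1, 2, 4, 16, device="cuda")
    err = (fp8_qk(q, k) - torch.matmul(q, k.transpose(-2, -1))).abs().max()
    print("max abs error vs float32 matmul:", err.item())        # small, dominated by FP8 rounding

Note that this sketch divides q and k by their scales before casting to FP8, whereas test_op_prefill_fp8 above currently casts with .to(torch.float8_e4m3fnuz) directly; pairing the quantize step with the descale step is what keeps qk consistent with the float32 reference.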