fp8 #116

Draft pull request: wants to merge 7 commits into main_perf.

Changes from all commits
2 changes: 1 addition & 1 deletion .github/workflows/amd_tests.yml
@@ -27,7 +27,7 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then
echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx1100"]]'
echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]'
else
echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
fi
146 changes: 133 additions & 13 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -62,7 +62,8 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m,
actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs,
block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope,
qk_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope,
q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8: tl.constexpr,
IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr,
@@ -99,10 +100,16 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
size_n = start_n + OFFS_N[None, :]
mask = size_n < boundary_m[:, None]
qk = tl.where(mask, qk, float("-inf"))

# compute mask for scores
p_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)

# -- compute qk ----
qk += tl.dot(q, k)
qk_scaled = qk * SM_SCALE
if IS_FP8:
qk_scaled *= q_scale * k_scale # descale qk after matmul if quantized
tl.store(qk_fp8_ptrs, qk_scaled, mask=p_mask)

if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
@@ -131,8 +138,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
p = tl.math.exp2(q_shifted * RCP_LN2)
else:
p = tl.math.exp(q_shifted)

p_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
if IS_FP8:
p *= p_inv_scale

# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
@@ -171,6 +178,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(v.type.element_ty), v)
if IS_FP8:
acc *= p_scale * v_scale
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
if bias_ptrs is not None:
@@ -259,15 +268,17 @@ def get_autotune_configs():
use_cuda_graph=True,
)
@triton.jit
def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
def attn_fwd(Q, K, V, bias,
Q_SCALE, K_SCALE, V_SCALE, P_SCALE, P_INV_SCALE, stride_qscale_z, stride_kvscale_z, stride_pscale_z, stride_pinvscale_z,
SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah,
stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr,
dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, qk_fp8, HQ: tl.constexpr,
HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr):
ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
@@ -396,6 +407,24 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)

# Load scale factors if IS_FP8.
if IS_FP8:
q_scale = tl.load(Q_SCALE + off_z * stride_qscale_z + off_h_q)
k_scale = tl.load(K_SCALE + off_z * stride_kvscale_z + off_h_k)
v_scale = tl.load(V_SCALE + off_z * stride_kvscale_z + off_h_k)
p_scale = tl.load(P_SCALE + off_z * stride_pscale_z + off_h_q)
p_inv_scale = tl.load(P_INV_SCALE + off_z * stride_pinvscale_z + off_h_q)
# print("q_scale", q_scale)
# print("k_scale", k_scale)
# print("v_scale", v_scale)
# print("p_scale", p_scale)
# print("p_inv_scale", p_inv_scale)
qk_fp8_offset = qk_fp8 + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm
qk_fp8_ptrs = qk_fp8_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn #+ cu_seqlens_q_start * stride_sm
else:
q_scale, k_scale, v_scale, p_scale, p_inv_scale = 1.0, 1.0, 1.0, 1.0, 1.0
qk_fp8_ptrs = None

# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
@@ -418,9 +447,10 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs,
sd_mask_ptrs, dropout_mask_ptrs,
sd_mask_ptrs, dropout_mask_ptrs, qk_fp8_ptrs,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min, block_max, 0, 0, 0, alibi_slope,
q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8,
# IS_CAUSAL, ....
False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
@@ -447,8 +477,9 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
philox_ptrs += n_full_blocks * BLOCK_N * stride_sn
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs,
sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
sd_mask_ptrs, dropout_mask_ptrs, qk_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
n_extra_tokens, alibi_slope,
q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8,
IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD,
@@ -561,6 +592,92 @@ def attention_prefill_forward_triton_impl(
print("return_scores:", return_softmax)
print("use_exp2:", use_exp2)

# Define FP8 types we support
FP8_TYPES = {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}

# Simple check if tensors are FP8
is_fp8 = q.dtype in FP8_TYPES

if is_fp8:
# constants
eps = 1e-9
type_max = torch.finfo(q.dtype).max
per_head_scaling = True

# Convert to float32 for scale computation
q_float32 = q.detach().to(torch.float32)
k_float32 = k.detach().to(torch.float32)
v_float32 = v.detach().to(torch.float32)

# Get shapes for scaling
batch = q.size(0) if layout != "thd" else len(cu_seqlens_q) - 1
nheads_q = q.size(1) if layout == "bhsd" else q.size(2)
nheads_k = k.size(1) if layout == "bhsd" else k.size(2)

if per_head_scaling:
# Set up layout-specific dimensions
if layout == "bhsd":
seqlen_loc = 2
elif layout == "bshd":
seqlen_loc = 1

# Compute max for each batch-head pair across seqlen and dim
q_scale = torch.maximum(q_float32.abs().amax(dim=(seqlen_loc, 3)), torch.tensor(eps))
k_scale = torch.maximum(k_float32.abs().amax(dim=(seqlen_loc, 3)), torch.tensor(eps))
v_scale = torch.maximum(v_float32.abs().amax(dim=(seqlen_loc, 3)), torch.tensor(eps))

# Divide by type max
q_scale = q_scale / type_max
k_scale = k_scale / type_max
v_scale = v_scale / type_max

# Set p_scale according to reference
p_scale = torch.full((batch, nheads_q), 1.0/type_max, dtype=torch.float32, device=q.device)
p_inv_scale = 1.0 / p_scale
else:
q_max = max(q_float32.abs().max().item(), eps)
k_max = max(k_float32.abs().max().item(), eps)
v_max = max(v_float32.abs().max().item(), eps)

# Create scale tensors with the global values
q_scale = torch.full((batch, nheads_q), q_max, dtype=torch.float32, device=q.device)
k_scale = torch.full((batch, nheads_k), k_max, dtype=torch.float32, device=k.device)
v_scale = torch.full((batch, nheads_k), v_max, dtype=torch.float32, device=v.device)

# Simple p_scale for softmax computation
p_scale = torch.full((batch, nheads_q), 1.0, dtype=torch.float32, device=q.device)
p_inv_scale = torch.full((batch, nheads_q), 1.0, dtype=torch.float32, device=q.device)

# Get strides for the kernel
q_scale_stride_z = q_scale.stride(0)
kv_scale_stride_z = k_scale.stride(0)
p_scale_stride_z = p_scale.stride(0)
p_inv_scale_stride_z = p_inv_scale.stride(0)

# dump intermediate results
qk_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=q.dtype, device=q.device)
acc_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=q.dtype, device=q.device)
else:
# For non-FP8 types, use dummy values (no scaling needed)
q_scale = k_scale = v_scale = p_scale = p_inv_scale = 1
q_scale_stride_z = kv_scale_stride_z = p_scale_stride_z = p_inv_scale_stride_z = 0
qk_fp8 = None

if DEBUG:
print("is_fp8:", is_fp8)
print("q_scale:", q_scale)
print("k_scale:", k_scale)
print("v_scale:", v_scale)
print("p_scale:", p_scale)
print("p_inv_scale:", p_inv_scale)
print("q_scale_stride_z:", q_scale_stride_z)
print("kv_scale_stride_z:", kv_scale_stride_z)
print("p_scale_stride_z:", p_scale_stride_z)
print("p_inv_scale_stride_z:", p_inv_scale_stride_z)
if is_fp8:
print(f"type_max: {type_max}")


# check if varlen
is_varlen = layout == "thd"

@@ -618,15 +735,16 @@ def attention_prefill_forward_triton_impl(
else:
alibi_strides = (0, 0)


attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
attn_fwd[grid](q, k, v, bias,
q_scale, k_scale, v_scale, p_scale, p_inv_scale, q_scale_stride_z, kv_scale_stride_z, p_scale_stride_z, p_inv_scale_stride_z,
sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
qk_fp8=qk_fp8, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen,
BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax)
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=is_fp8)

if DEBUG:
print()
@@ -638,5 +756,7 @@ def attention_prefill_forward_triton_impl(
print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None)
print("dropout_fraction fwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item())
write_dropout_mask(dropout_mask, "dropout_mask_fwd")
if is_fp8:
print("qk_fp8:", qk_fp8)

return o, softmax_lse, sd_mask.to(o.dtype) if return_softmax else None
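
To make the descaling convention in the kernel changes above easier to follow, here is a minimal single-head PyTorch sketch of the math the FP8 path appears to implement. The helper name fp8_attention_ref and the per-tensor scaling are illustrative assumptions, not part of this PR; the sketch also assumes a PyTorch build that exposes torch.float8_e4m3fnuz. Tensors are treated as x ≈ x_fp8 * x_scale with x_scale = amax(|x|) / type_max, QK^T is descaled by q_scale * k_scale after the matmul, the softmax probabilities are requantized with p_inv_scale = 1 / p_scale, and the PV product is descaled by p_scale * v_scale.

import torch

def fp8_attention_ref(q, k, v, sm_scale, fp8_dtype=torch.float8_e4m3fnuz):
    # Single-head reference: q is (seqlen_q, d), k and v are (seqlen_k, d), all float32.
    type_max = torch.finfo(fp8_dtype).max

    # Per-tensor scales: x is reconstructed as x_fp8 * x_scale, with x_scale = amax(|x|) / type_max.
    q_scale = q.abs().amax() / type_max
    k_scale = k.abs().amax() / type_max
    v_scale = v.abs().amax() / type_max
    p_scale = 1.0 / type_max       # softmax probabilities lie in [0, 1]
    p_inv_scale = 1.0 / p_scale

    q_fp8 = (q / q_scale).to(fp8_dtype)
    k_fp8 = (k / k_scale).to(fp8_dtype)
    v_fp8 = (v / v_scale).to(fp8_dtype)

    # QK^T in low precision, descaled after the matmul (mirrors the IS_FP8 branch in _attn_fwd_inner).
    qk = (q_fp8.to(torch.float32) @ k_fp8.to(torch.float32).T) * q_scale * k_scale * sm_scale
    p = torch.softmax(qk, dim=-1)

    # Requantize the probabilities, multiply by V in low precision, then descale the result.
    p_fp8 = (p * p_inv_scale).to(fp8_dtype)
    return (p_fp8.to(torch.float32) @ v_fp8.to(torch.float32)) * p_scale * v_scale

In the kernel itself the p_scale * v_scale descale is applied inside the online-softmax loop rather than once at the end, and the scales are per batch-head (loaded from Q_SCALE, K_SCALE, V_SCALE) rather than per tensor.
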
91 changes: 91 additions & 0 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -8,6 +8,7 @@
from .bwd_prefill import attention_prefill_backward_triton_impl
from .bwd_ref import attention_backward_pytorch_ref_impl
from .fwd_decode import dequantize_kv_fp16, quantize_kv_int4
from flash_attn import flash_attn_func

# default fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html
ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe too loose.
@@ -471,6 +472,96 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou
print("output_ref:", output_ref, output_ref.shape)
torch.testing.assert_close(output_triton, output_ref, atol=ATOL, rtol=RTOL)


@pytest.mark.parametrize(
"Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD",
[
# (1, 1, 1, 1, 1, 1),
(1, 1, 1, 2, 4, 16),
# (1, 2, 2, 2, 4, 16),
# (1, 4, 1, 2, 4, 16),
# (1, 4, 2, 2, 4, 16),
# # (1, 1, 1, 4, 2, 16),
# (1, 1, 1, 4, 4, 16),
# (1, 2, 2, 4, 4, 16),
# (2, 1, 1, 4, 4, 16),
# (2, 2, 2, 4, 4, 16),
# (1, 1, 1, 128, 64, 16),
# (2, 2, 2, 2, 128, 1),
# (2, 3, 3, 2, 128, 16),
# (3, 2, 2, 256, 512, 16),
# (3, 3, 3, 128, 128, 64),
# (2, 4, 4, 1024, 1024, 64),
# (4, 6, 6, 108, 256, 224),
# (4, 8, 8, 2048, 2048, 128),
# (4, 16, 16, 4096, 4096, 64),
# (2, 4, 4, 8192, 8192, 32),
# # fa configs
# (4, 6, 1, 113, 203, 256),
# (4, 6, 1, 128, 217, 256),
# (4, 6, 2, 113, 211, 128),
# (4, 6, 2, 108, 256, 128),
# (4, 6, 1, 256, 512, 64),
# (4, 6, 1, 512, 256, 64),
# (4, 6, 2, 1024, 1024, 32),
# (4, 6, 2, 1023, 1024, 32),
# (4, 6, 6, 1024, 1023, 32),
# (4, 6, 6, 2048, 2048, 32),
],
)
@pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('layout', ["bshd"]) # expects bshd args
@pytest.mark.parametrize('DEBUG_INPUT', [True])
def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, DEBUG_INPUT):
device = "cuda"
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False

q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, torch.float32, layout, device=device, DEBUG_INPUT=DEBUG_INPUT)


out_fp16, lse_fp16, S_dmask_fp16 = flash_attn_func(
q.clone().to(torch.float16),
k.clone().to(torch.float16),
v.clone().to(torch.float16),
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
if DEBUG:
print("out_fp16", out_fp16)
print("lse_fp16", lse_fp16)
print("S_dmask_fp16", S_dmask_fp16)

out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_func(
q.clone().to(torch.float8_e4m3fnuz),
k.clone().to(torch.float8_e4m3fnuz),
v.clone().to(torch.float8_e4m3fnuz),
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
if DEBUG:
print("out_fp8", out_fp8)
print("lse_fp8", lse_fp8)
print("S_dmask_fp8", S_dmask_fp8)

if DEBUG:
print("out_fp16:", out_fp16, out_fp16.shape)
print("out_fp8:", out_fp8, out_fp8.shape)
torch.testing.assert_close(out_fp16.to(torch.float32), out_fp8.to(torch.float32), atol=ATOL, rtol=RTOL)

@pytest.mark.parametrize(
"Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", [
(1, 1, 1, 1, 1, 1),
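
As a rough sanity check on the tolerances used in the new test (ATOL, RTOL = 1e-2), the following standalone snippet, which is illustrative and not part of the PR, measures the round-trip error of e4m3fnuz quantization with the same amax / type_max scale convention the host wrapper uses in fwd_prefill.py (per-tensor here for simplicity; the PR computes the scale per batch-head).

import torch

torch.manual_seed(0)
x = torch.randn(128, 64)

fp8_dtype = torch.float8_e4m3fnuz
scale = x.abs().amax() / torch.finfo(fp8_dtype).max   # amax / type_max, as in the PR
x_fp8 = (x / scale).to(fp8_dtype)                     # quantize
x_deq = x_fp8.to(torch.float32) * scale               # dequantize

rel_err = (x - x_deq).abs().max() / x.abs().max()
print(f"max round-trip error relative to amax: {rel_err.item():.4f}")
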