Commit

fix: fp8 ref matches kernel
alexkranias-amd committed Dec 9, 2024
1 parent 89d3d7d commit c65af82
Showing 3 changed files with 34 additions and 10 deletions.
flash_attn/flash_attn_triton_amd/fwd_prefill.py (6 changes: 3 additions & 3 deletions)
@@ -183,10 +183,10 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
if IS_FP8:
p_scale = 1
p_scaled = (p / p_scale)
-           acc += tl.dot(p.to(v.type.element_ty), (v*v_scale).to(v.type.element_ty)) # if you want to use p_scaled: tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)) * v_scale * p_scale
+           acc += tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)).to(tl.float32) * v_scale * p_scale # if you want to use p_scaled: tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)) * v_scale * p_scale
else:
-           acc += tl.dot(p.to(v.type.element_ty), v).to(tl.float32) # NOTE: acc += tl.dot(p.to(tl.float16), v.to(tl.float16)) PASSES

+           # NOTE: if you make the below operation tl.float16 + set FLASH_ATTENTION_TRITON_AMD_REMOVE_QUANT_SCALE=1. It passes. --> acc += tl.dot(p.to(tl.float16), v.to(tl.float16)) PASSES
+           acc += tl.dot(p.to(v.type.element_ty), v).to(tl.float32)

k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
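The fp8 branch keeps the accumulator in float32 and folds the quantization scales back in after the dot product. A minimal PyTorch sketch of that rescaling identity (not the Triton kernel itself), assuming v was divided by v_scale before being cast to fp8 and p_scale == 1 as above:

import torch

torch.manual_seed(0)
p = torch.rand(128, 64)        # softmax tile (stands in for p / p_scale with p_scale = 1)
v = torch.randn(64, 64)        # unquantized V tile
v_scale = v.abs().max()
p_scale = 1.0

v_quant = v / v_scale          # the value that would be stored in fp8 (the cast is omitted here)
acc_fp8_style = ((p / p_scale) @ v_quant) * v_scale * p_scale
acc_reference = p @ v          # what the non-fp8 branch accumulates

# Equal up to float32 rounding; the real fp8 path adds quantization error on top.
torch.testing.assert_close(acc_fp8_style, acc_reference, rtol=1e-4, atol=1e-5)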
flash_attn/flash_attn_triton_amd/fwd_ref.py (36 changes: 30 additions & 6 deletions)
@@ -1,10 +1,23 @@
import torch
import math
-from .utils import DEBUG
+from .utils import create_scale_tensors, check_is_fp8, DEBUG

DEBUG_CORE = False

-def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox_seed, philox_offset, use_exp2):
+def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2):
+    is_fp8 = check_is_fp8(q)
+    if is_fp8:
+        # if qkv are fp8, then find scaling factor for quantization
+        q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout) # TODO: if SCALE_PER_HEAD: within the kernel itself just compute qkv_scale = tl.max(q or k or v)
+        q_scale_stride_z = q_scale.stride(0)
+        kv_scale_stride_z = k_scale.stride(0)
+
+        # scale qkv tensors if FP8
+        q = q / q_scale
+        k = k / k_scale
+        v = v / v_scale
+    else:
+        q_scale = k_scale = v_scale = 1
if DEBUG_CORE:
print()
print("attention_forward_core_ref_impl")
@@ -17,14 +30,22 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox
print("philox_seed:", philox_seed)
print("philox_offset:", philox_offset)
print("use_exp2:", use_exp2)
+        print('layout:', layout)
+        print('is_fp8:', is_fp8)
+        print('q_scale:', q_scale)
+        print('k_scale:', k_scale)
+        print('v_scale:', v_scale)

# cast to float32
q = q.to(torch.float32)
k = k.to(torch.float32)
v = v.to(torch.float32)

# Compute attention scores
-    attention_scores = torch.matmul(q, k.transpose(-2, -1))
+    if is_fp8:
+        attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32)) * q_scale * v_scale
+    else:
+        attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
if DEBUG_CORE:
print("attention_scores:", attention_scores, attention_scores.shape)

@@ -129,7 +150,10 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox
print("softmax_lse:", softmax_lse, softmax_lse.shape)

# Compute output
-    o = torch.matmul(p, v)
+    if is_fp8:
+        o = torch.matmul(p, v.to(torch.float32)) * v_scale
+    else:
+        o = torch.matmul(p, v)
if DEBUG_CORE:
print("o:", o, o.shape)

@@ -176,7 +200,7 @@ def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout

# Call the core attention function
o, softmax_lse, sd_mask = attention_forward_core_ref_impl(
-        q, k, v, sm_scale, causal, dropout_p, philox_seed, philox_offset, use_exp2
+        q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2
)

if group_size != 1:
@@ -278,7 +302,7 @@ def attention_varlen_forward_pytorch_ref_impl(
v_i = v_i.reshape(nheads_k, seqlen_k, head_dim)

# Call the core attention function for this sequence
-        o_i, softmax_lse_i, sd_mask_i = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, dropout_p, philox_seed, philox_offset, use_exp2)
+        o_i, softmax_lse_i, sd_mask_i = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2)

# Reshape outputs back to original dimensions
if group_size != 1:
flash_attn/flash_attn_triton_amd/test.py (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@
from .fwd_decode import dequantize_kv_fp16, quantize_kv_int4

# defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html
-ATOL, RTOL = 1e-2, 0 # old standard. maybe to lose.
+ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose.
# ATOL, RTOL = 1e-3, 1e-3 # catchs fa mismatch issues
# ATOL, RTOL = 1e-4, 1e-3 # to strict. there will be small diffs
# ATOL, RTOL = 1e-5, 1e-3 # # default fp16. there will be small diffs
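For context on the tolerance pair: torch.testing treats a comparison as passing when |actual - expected| <= atol + rtol * |expected|, so the old rtol of 0 made the bound purely absolute regardless of magnitude. A small illustration (not part of the repo's tests):

import torch

expected = torch.tensor([1.000, 100.0])
actual   = torch.tensor([1.005, 100.5])   # absolute errors of 5e-3 and 5e-1

# New pair: both entries pass, since the allowed error grows with |expected|.
torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)

# Old pair (rtol = 0): the large-magnitude entry fails the purely absolute bound.
try:
    torch.testing.assert_close(actual, expected, atol=1e-2, rtol=0.0)
except AssertionError:
    print("0.5 > 1e-2, so the second entry is rejected")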
