Commit

good output but kvcache is not updated properly
micmelesse committed Aug 21, 2024
1 parent 305d30d commit 8a488ad
Showing 2 changed files with 73 additions and 46 deletions.
65 changes: 43 additions & 22 deletions flash_attn/flash_attn_triton_kernel_decode_amd.py
@@ -8,7 +8,7 @@
import triton.language as tl
from flash_attn.flash_attn_triton_kernel_prefill_amd import MetaData

-DEBUG = False
+DEBUG = True

def _strides(x: torch.Tensor, *stride_names: str):
if x is None:
@@ -639,7 +639,6 @@ def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
split_k = max(split_k, 1)
return split_k
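
Note: get_split_k's body is truncated in this view. The split-K decode idea is to partition the cached key/value sequence into split_k chunks, run attention on each chunk in its own workgroup, and then combine the partial softmax results; the per-chunk size is the ceil-division used later in forward. A tiny standalone illustration with arbitrary values:

    seqlen_k, split_k = 1000, 8
    split_size = (seqlen_k + split_k - 1) // split_k   # ceil(1000 / 8) = 125 keys per split
    assert split_size * split_k >= seqlen_k            # the splits cover the whole key sequence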


def pad_to_power_of_2(tensor, dim=-1):
"""Pad the last dimension of the tensor to the next power of 2."""
current_size = tensor.size(dim)
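
Note: pad_to_power_of_2 (and unpad_from_power_of_2, used near the end of forward) are truncated in this view. A minimal sketch of the behavior they appear to provide, zero-padding the last dimension up to the next power of two; the bodies below are assumptions, only the names come from the diff:

    import torch
    import torch.nn.functional as F

    def pad_to_power_of_2_sketch(tensor: torch.Tensor) -> torch.Tensor:
        # zero-pad the last dimension up to the next power of 2 (e.g. head dim 17 -> 32)
        current_size = tensor.size(-1)
        padded_size = 1 << (current_size - 1).bit_length()
        return F.pad(tensor, (0, padded_size - current_size))

    def unpad_from_power_of_2_sketch(tensor: torch.Tensor, original_size: int) -> torch.Tensor:
        # drop the padding again, keeping only the original head dimension
        return tensor[..., :original_size]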
@@ -678,22 +677,22 @@ class _attention(torch.autograd.Function):
NAME = "triton_splitKF"

@staticmethod
-def forward(cls, q, k, v, input_metadata):
+def forward(cls, q, k_cache, v_cache, input_metadata):
if DEBUG:
print()
print("attention_decode.forward")
print("q:", q, q.shape)
print("k:", k, k.shape)
print("v:", v, v.shape)
print("k:", k_cache, k_cache.shape)
print("v:", v_cache, v_cache.shape)
print("input_metadata:", input_metadata)

original_layout = input_metadata.layout

# kernels expects "bsghd"
if input_metadata.layout == "bshd":
q=q.unsqueeze(2)
-k=k.unsqueeze(2)
-v=v.unsqueeze(2)
+k_cache=k_cache.unsqueeze(2)
+v_cache=v_cache.unsqueeze(2)

if input_metadata.new_kv:
input_metadata.k_new = input_metadata.k_new.unsqueeze(2)
@@ -702,8 +701,8 @@ def forward(cls, q, k, v, input_metadata):
input_metadata.layout = "bsghd"
elif input_metadata.layout == "bhsd":
q=q.permute(0, 2, 1, 3).unsqueeze(2)
-k=k.permute(0, 2, 1, 3).unsqueeze(2)
-v=v.permute(0, 2, 1, 3).unsqueeze(2)
+k_cache=k_cache.permute(0, 2, 1, 3).unsqueeze(2)
+v_cache=v_cache.permute(0, 2, 1, 3).unsqueeze(2)
if input_metadata.new_kv:
input_metadata.k_new = input_metadata.k_new.permute(0, 2, 1, 3).unsqueeze(2)
input_metadata.v_new = input_metadata.v_new.permute(0, 2, 1, 3).unsqueeze(2)
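
Note: the layout handling above only inserts a singleton group axis so that both "bshd" and "bhsd" inputs reach the kernel as "bsghd". A minimal shape check, with arbitrary example sizes:

    import torch
    x = torch.randn(2, 16, 6, 64)                     # "bshd": (batch, seq, heads, dim)
    assert x.unsqueeze(2).shape == (2, 16, 1, 6, 64)  # "bsghd" with a single group
    y = torch.randn(2, 6, 16, 64)                     # "bhsd": (batch, heads, seq, dim)
    assert y.permute(0, 2, 1, 3).unsqueeze(2).shape == (2, 16, 1, 6, 64)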
@@ -724,18 +723,33 @@ def forward(cls, q, k, v, input_metadata):
if needs_padding(original_dmodel):
# Pad q, k, and v to the next power of 2
q = pad_to_power_of_2(q)
-k = pad_to_power_of_2(k)
-v = pad_to_power_of_2(v)
+k = pad_to_power_of_2(k_cache)
+v = pad_to_power_of_2(v_cache)
if input_metadata.new_kv:
input_metadata.k_new = pad_to_power_of_2(input_metadata.k_new)
input_metadata.v_new = pad_to_power_of_2(input_metadata.v_new)

input_metadata.dmodel = q.shape[-1]

+else:
+k = k_cache
+v = v_cache
+
+print("after padding")
+print("q:", q, q.shape)
+print("k:", k, k.shape)
+print("v:", v, v.shape)
+print("input_metadata:", input_metadata)
+
+
+
# get dims
batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape
_, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape
_, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape

print(f"batch_size = {batch_size}, seqlen_q = {seqlen_q}, group_q={n_group_q} heads_per_group_q = {heads_per_group_q}, dim_q = {dim_q}")
print(f"batch_size = {batch_size}, seqlen_k = {seqlen_k}, group_k={n_group_k} heads_per_group_k = {heads_per_group_k}, dim_k = {dim_k}")
if DEBUG:
print(f"batch_size = {batch_size}, seqlen_q = {seqlen_q}, group_q={n_group_q} heads_per_group_q = {heads_per_group_q}, dim_q = {dim_q}")
print(f"batch_size = {batch_size}, seqlen_k = {seqlen_k}, group_k={n_group_k} heads_per_group_k = {heads_per_group_k}, dim_k = {dim_k}")

# Handle MQA/GQA case
if heads_per_group_q > heads_per_group_k:
@@ -745,10 +759,11 @@ def forward(cls, q, k, v, input_metadata):
else:
input_metadata.is_gqa = False

print("input_metadata.is_gqa:", input_metadata.is_gqa)
print("After MQA/GQA check")
print(f"batch_size = {batch_size}, seqlen_q = {seqlen_q}, group_q={n_group_q} heads_per_group_q = {heads_per_group_q}, dim_q = {dim_q}")
print(f"batch_size = {batch_size}, seqlen_k = {seqlen_k}, group_k={n_group_k} heads_per_group_k = {heads_per_group_k}, dim_k = {dim_k}")
if DEBUG:
print("input_metadata.is_gqa:", input_metadata.is_gqa)
print("After MQA/GQA check")
print(f"batch_size = {batch_size}, seqlen_q = {seqlen_q}, group_q={n_group_q} heads_per_group_q = {heads_per_group_q}, dim_q = {dim_q}")
print(f"batch_size = {batch_size}, seqlen_k = {seqlen_k}, group_k={n_group_k} heads_per_group_k = {heads_per_group_k}, dim_k = {dim_k}")

# context
cls.SPLIT_K: Optional[int] = None
@@ -771,7 +786,8 @@ def forward(cls, q, k, v, input_metadata):
# q = q.transpose(1, 3)
# k = k[:, :, :, :1]
# v = v[:, :, :, :1]
print("mqa_swap_seqlen_head:", mqa_swap_seqlen_head)
if DEBUG:
print("mqa_swap_seqlen_head:", mqa_swap_seqlen_head)
# assert mqa_swap_seqlen_head == False

# Update dim_k if Quantized
@@ -803,8 +819,9 @@ def forward(cls, q, k, v, input_metadata):
num_warps = 1
split_size = (seqlen_k + split_k - 1) // split_k
use_cache_seqlens = cache_seqlens is not None

print(f"batch_size = {batch_size}, group_q = {n_group_q}, heads_per_group_q = {heads_per_group_q}, split_k = {split_k}, seqlen_q_ceil = {seqlen_q_ceil}, dim_q = {dim_q}, num_of_wgs = {n_group_q * n_group_q * heads_per_group_q * split_k}")

if DEBUG:
print(f"batch_size = {batch_size}, group_q = {n_group_q}, heads_per_group_q = {heads_per_group_q}, split_k = {split_k}, seqlen_q_ceil = {seqlen_q_ceil}, dim_q = {dim_q}, num_of_wgs = {n_group_q * n_group_q * heads_per_group_q * split_k}")

if DEBUG:
print("q:", q, q.shape)
@@ -879,7 +896,8 @@ def forward(cls, q, k, v, input_metadata):
assert out.shape[-1] % k_block_num == 0
k_block_size = out.shape[-1] // k_block_num
grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num)
print("grid:", grid)
if DEBUG:
print("grid:", grid)


if DEBUG:
@@ -935,6 +953,9 @@ def forward(cls, q, k, v, input_metadata):
out = out.reshape(batch_size, seqlen_q, -1, dim_q)

if needs_padding(original_dmodel):
+k_cache.set_(unpad_from_power_of_2(k, original_dmodel))
+v_cache.set_(unpad_from_power_of_2(v, original_dmodel))
+
out = unpad_from_power_of_2(out, original_dmodel)

return out
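
Note on the cache write-back added above: Tensor.set_ rebinds a tensor to the source's storage and metadata, while copy_ writes element-wise into the storage the tensor already has, which any existing views keep referencing. A self-contained comparison, not part of the diff:

    import torch

    cache = torch.zeros(4)
    view = cache[:]                               # second tensor sharing cache's original storage

    cache.copy_(torch.arange(4.0))                # in-place write: the shared storage is updated
    assert torch.equal(view, torch.arange(4.0))

    cache.set_(torch.full((4,), 9.0))             # rebind: cache now points at brand-new storage
    assert torch.equal(cache, torch.full((4,), 9.0))
    assert torch.equal(view, torch.arange(4.0))   # the old storage, and any views of it, are untouched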
54 changes: 30 additions & 24 deletions tests/test_flash_attn.py
@@ -19,7 +19,7 @@
from flash_attn.flash_attn_interface import _get_block_size_n
from flash_attn.layers.rotary import apply_rotary_emb

-DEBUG = False
+DEBUG = True

MAX_HEADDIM_SM8x = 192

@@ -672,7 +672,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
if False:
qkv = torch.zeros(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype)
for i in range(seqlen):
-qkv[:, i, :, :, :] = torch.full((batch_size, 3, nheads, d), i + 1, device=device, dtype=dtype)
+qkv[:, i, :, :, :] = torch.full((batch_size, 3, nheads, d), i, device=device, dtype=dtype)
qkv.requires_grad_(True)
else:
qkv = torch.randn(
@@ -2161,24 +2161,26 @@ def test_flash_attn_splitkv(
# @pytest.mark.parametrize("has_leftpad", [True])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("d", [17, 80])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
-(1, 128),
-(1, 339),
-(3, 1024),
-(64, 800),
-(64, 256),
-(3, 799),
-(64, 2048),
-(16, 20000),
-(1, 128 * 1024),
-(16, 128 * 1024),
+(4, 4),
+# (1, 128),
+# (1, 339),
+# (3, 1024),
+# (64, 800),
+# (64, 256),
+# (3, 799),
+# (64, 2048),
+# (16, 20000),
+# (1, 128 * 1024),
+# (16, 128 * 1024),
+(128, 128),
],
)
@@ -2233,8 +2235,8 @@ def test_flash_attn_kvcache(
if has_leftpad == True:
pytest.skip("cache_leftpad not supported on AMD yet")

-if skip_config(seqlen_q, seqlen_k, d):
-pytest.skip("Randomly skipping this configuration to limited test time")
+# if skip_config(seqlen_q, seqlen_k, d):
+#     pytest.skip("Randomly skipping this configuration to limited test time")

if seqlen_q > seqlen_k and new_kv:
pytest.skip()
@@ -2247,9 +2249,13 @@ def test_flash_attn_kvcache(
device = "cuda"
# set seed
torch.random.manual_seed(0)
-batch_size = 2
+if True:
+nheads = 1
+batch_size = 1
+else:
+nheads = 6
+batch_size = 2
batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
-nheads = 6

if DEBUG:
print("nheads_q:", nheads)
@@ -2260,10 +2266,10 @@ def test_flash_attn_kvcache(
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
-if False:
+if True:
q = torch.zeros(batch_size_cache, seqlen_q, nheads_k, d, device=device, dtype=dtype)
for i in range(seqlen_q):
-q[:, i, :, :] = torch.full((batch_size_cache, nheads_k, d), i + 1, device=device, dtype=dtype)
+q[:, i, :, :] = torch.full((batch_size_cache, nheads_k, d), i, device=device, dtype=dtype)
else:
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
@@ -2279,15 +2285,15 @@ def test_flash_attn_kvcache(
v_cache = torch.zeros(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)

for i in range(nheads_k):
-k_cache[:, :, i, :] = torch.full((batch_size_cache, seqlen_k, d), i + 1, device=device, dtype=dtype)
-v_cache[:, :, i, :] = torch.full((batch_size_cache, seqlen_k, d), i + 1, device=device, dtype=dtype)
-elif False:
+k_cache[:, :, i, :] = torch.full((batch_size_cache, seqlen_k, d), i, device=device, dtype=dtype)
+v_cache[:, :, i, :] = torch.full((batch_size_cache, seqlen_k, d), i, device=device, dtype=dtype)
+elif True:
k_cache = torch.zeros(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
v_cache = torch.zeros(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)

for i in range(seqlen_k):
-k_cache[:, i, :, :] = torch.full((batch_size_cache, nheads_k, d), i + 1, device=device, dtype=dtype)
-v_cache[:, i, :, :] = torch.full((batch_size_cache, nheads_k, d), i + 1, device=device, dtype=dtype)
+k_cache[:, i, :, :] = torch.full((batch_size_cache, nheads_k, d), i, device=device, dtype=dtype)
+v_cache[:, i, :, :] = torch.full((batch_size_cache, nheads_k, d), i, device=device, dtype=dtype)

elif False:
k_cache = torch.zeros(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)