diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index b37308be4..4c410b35a 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -540,7 +540,7 @@ def get_split_k(B: int, G: int, H: int, Mk: int) -> int: split_k = max(split_k, 1) return split_k -def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new): +def attention_decode_forward_triton_impl(q, k, v, k_new, v_new, sm_scale, causal, layout, alibi_slopes, cache_seqlens, cache_batch_idx): # kernel config BLOCK_M = 16 BLOCK_N = 64 @@ -553,16 +553,18 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes q=q.unsqueeze(2) k=k.unsqueeze(2) v=v.unsqueeze(2) - if new_kv: + if k_new is not None: k_new = k_new.unsqueeze(2) + if v_new is not None: v_new = v_new.unsqueeze(2) layout = "bsghd" elif layout == "bhsd": q=q.permute(0, 2, 1, 3).unsqueeze(2) k=k.permute(0, 2, 1, 3).unsqueeze(2) v=v.permute(0, 2, 1, 3).unsqueeze(2) - if new_kv: + if k_new is not None: k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2) + if v_new is not None: v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2) layout = "bsghd" elif layout == "bsghd": @@ -571,6 +573,9 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes raise ValueError("Layout not given") assert layout == "bsghd" + # check that both are provided or both are none + assert ((k_new is None) and (v_new is None)) or ((k_new is not None) and (v_new is not None)) + # get dims batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape _, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py index 909996654..2054113c4 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/fwd_ref.py @@ -140,6 +140,93 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox return o, softmax_lse, sd_mask +def attention_decode_forward_pytorch_ref_impl( + q, + k_cache, + v_cache, + k_new, + v_new, + cache_seqlens, + cache_batch_idx, + sm_scale, + causal, + layout, + alibi_slopes, + rotary_cos, + rotary_sin, + rotary_interleaved, + use_exp2 +): + if DEBUG: + print() + print("attention_forward_pytorch_ref_impl") + print("q:", q, q.shape) + print("k:", k_cache, k_cache.shape) + print("v:", v_cache, v_cache.shape) + print("k_new:", k_new, k_new.shape if k_new is not None else None) + print("v_new:", v_new, v_new.shape if v_new is not None else None) + print("cache_seqlens:", cache_seqlens) + print("cache_batch_idx:", cache_batch_idx) + print("sm_scale:", sm_scale) + print("causal:", causal) + print("alibi_slopes:", alibi_slopes) + print("layout:", layout) + print("rotary_cos:", rotary_cos) + print("rotary_sin:", rotary_sin) + print("rotary_interleaved:", rotary_interleaved) + print("use_exp2:", use_exp2) + + # Ensure the layout is 'bhsd' + if layout == "bshd": + q = q.transpose(1, 2).contiguous() + k_cache = k_cache.transpose(1, 2).contiguous() + v_cache = v_cache.transpose(1, 2).contiguous() + if k_new is not None: + k_new = k_new.transpose(1, 2).contiguous() + if v_new is not None: + v_new = v_new.transpose(1, 2).contiguous() + elif layout != "bhsd": + raise ValueError(f"Unknown layout {layout}") + + # check that both are provided or both are none + assert ((k_new is None) and (v_new is None)) or ((k_new is not None) and (v_new is not None)) + + # Prepare tensors + batch_size, nheads_q, seq_len_q, head_dim = q.shape + batch_size, nheads_k_cache, seq_len_k_cache, head_dim = k_cache.shape + if k_new: + batch_size, nheads_k_new, seq_len_k_new, head_dim = k_new.shape + + # insert new tensors in cache + # TODO + + # convert to 3d tensors for core impl + q = q.reshape(batch_size * nheads_q, seq_len_q, head_dim) + k_cache = k_cache.reshape(batch_size * nheads_k_cache, seq_len_k_cache, head_dim) + v_cache = v_cache.reshape(batch_size * nheads_k_cache, seq_len_k_cache, head_dim) + # if k_new is not None: + # k_new = k_new.reshape(batch_size * nheads_k_new, seq_len_k_new, head_dim) + # if v_new is not None: + # v_new = v_new.reshape(batch_size * nheads_k_new, seq_len_k_new, head_dim) + + + # launch core impl + output, softmax_lse, sd_mask = attention_forward_core_ref_impl( + q, k_cache, v_cache, sm_scale, causal, 0.0, None, None, use_exp2 + ) + + output = output.reshape(batch_size, nheads_q, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size, nheads_q, seq_len_q) + sd_mask = sd_mask.reshape(batch_size, nheads_q, seq_len_q, seq_len_k_cache) + + if layout == "bshd": + output = output.transpose(1, 2) + + return output, softmax_lse + + + + def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2): """Compute reference output and softmax_lse using PyTorch's built-in function""" diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index 51037f236..797740289 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -3,14 +3,12 @@ from .fwd_prefill import attention_prefill_forward_triton_impl from .bwd_prefill import attention_prefill_backward_triton_impl from .fwd_decode import attention_decode_forward_triton_impl -from .fwd_ref import attention_forward_pytorch_ref_impl +from .fwd_ref import attention_forward_pytorch_ref_impl, attention_decode_forward_pytorch_ref_impl from .bwd_ref import attention_backward_pytorch_ref_impl -from .utils import MetaData, get_shape_from_layout, DEBUG +from .utils import DEBUG, USE_REF, MetaData, get_shape_from_layout from einops import rearrange, repeat from flash_attn.layers.rotary import apply_rotary_emb -USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes') - def fwd(q, k, v, @@ -502,6 +500,20 @@ def fwd_kvcache( rotary_interleaved, num_splits): + if DEBUG: + print() + print("flash_attn_triton_amd.py::fwd_kvcache") + print("q:", q, q.shape) + print("k:", k, k.shape if k is not None else None) + print("v:", v, v.shape if v is not None else None) + print("alibi_slopes:", alibi_slopes) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("out:", out) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + if out is None: out = torch.empty_like(q) @@ -513,11 +525,6 @@ def fwd_kvcache( metadata.cache_seqlens = cache_seqlens metadata.cache_batch_idx = cache_batch_idx - if k is not None and v is not None: - metadata.new_kv = True - metadata.seqlen_new = k.shape[1] - metadata.k_new = k - metadata.v_new = v if causal: metadata.need_causal() @@ -563,20 +570,45 @@ def fwd_kvcache( q, metadata.k_new = q_ro.to(q.dtype), k_ro.to(q.dtype) - # launch kernel - # TODO: pass output as an arg. Maybe we are copying output which is causing slow down - output, softmax_lse = attention_decode_forward_triton_impl( - q, - k_cache, - v_cache, - metadata.sm_scale, - metadata.causal, - metadata.alibi_slopes, - metadata.layout, - metadata.cache_seqlens, - metadata.cache_batch_idx, - metadata.new_kv, - metadata.k_new, - metadata.v_new, - ) + if USE_REF: + if DEBUG: + print("Using reference implementation") + output, softmax_lse = attention_decode_forward_pytorch_ref_impl( + q, + k_cache, + v_cache, + k, + v, + cache_seqlens, + cache_batch_idx, + metadata.sm_scale, + metadata.causal, + metadata.layout, + metadata.alibi_slopes, + metadata.rotary_cos, + metadata.rotary_sin, + metadata.rotary_interleaved, + False + ) + out.copy_(output) + else: + if DEBUG: + print("Using Triton implementation") + + # launch kernel + # TODO: pass output as an arg. Maybe we are copying output which is causing slow down + output, softmax_lse = attention_decode_forward_triton_impl( + q, + k_cache, + v_cache, + metadata.sm_scale, + metadata.causal, + metadata.alibi_slopes, + metadata.layout, + metadata.cache_seqlens, + metadata.cache_batch_idx, + metadata.new_kv, + metadata.k_new, + metadata.v_new, + ) return output, softmax_lse diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 343425788..c48635f1c 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -8,15 +8,18 @@ import triton import triton.language as tl +# global variables AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes') DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes') PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes') +USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes') USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" if USE_TRITON_ROCM: # TODO remove this random.seed(42) DROPOUT_USE_PYTORCH = False DROPOUT_DUMP = False +# Flash Attention Metadata class MetaData(): cu_seqlens_q = None cu_seqlens_k = None @@ -30,10 +33,6 @@ class MetaData(): layout = None cache_seqlens = None cache_batch_idx = None - new_kv = False - seqlen_new = None - k_new = None - v_new = None return_scores= False dropout_p= 0.0 philox_seed, philox_offset = None, None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. @@ -43,7 +42,7 @@ class MetaData(): rotary_cos = None rotary_interleaved = False rotary_conjunction = False - + is_decode = False def __repr__(self) -> str: return (f"MetaData(\n" @@ -60,10 +59,6 @@ def __repr__(self) -> str: f" layout={self.layout},\n" f" cache_seqlens={self.cache_seqlens},\n" f" cache_batch_idx={self.cache_batch_idx},\n" - f" new_kv={self.new_kv},\n" - f" seqlen_new={self.seqlen_new},\n" - f" k_new={self.k_new},\n" - f" v_new={self.v_new},\n" f" dropout_p={self.dropout_p},\n" f" return_scores={self.return_scores}\n" f")") @@ -71,6 +66,9 @@ def __repr__(self) -> str: def __init__(self, sm_scale=1.0): self.sm_scale = sm_scale + def is_decode(self): + is_decode = True + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): self.varlen = True self.layout = 'thd' diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index 7e300687c..ba46c57e7 100755 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -1842,24 +1842,24 @@ def test_flash_attn_varlen_causal( # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("num_splits", [1, 0]) -# @pytest.mark.parametrize("num_splits", [1]) -@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -# @pytest.mark.parametrize("mha_type", ["mha"]) -@pytest.mark.parametrize("new_kv", [False, True]) -# @pytest.mark.parametrize("new_kv", [False]) -@pytest.mark.parametrize("alibi", [False, True]) -# @pytest.mark.parametrize("alibi", [False]) +# @pytest.mark.parametrize("num_splits", [1, 0]) +@pytest.mark.parametrize("num_splits", [1]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("new_kv", [False, True]) +@pytest.mark.parametrize("new_kv", [False]) +# @pytest.mark.parametrize("alibi", [False, True]) +@pytest.mark.parametrize("alibi", [False]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [False]) -@pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) -@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) -# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) -@pytest.mark.parametrize("rotary_interleaved", [False, True]) -# @pytest.mark.parametrize("rotary_interleaved", [False]) -@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) -# @pytest.mark.parametrize("rotary_fraction", [0.0]) +# @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +# @pytest.mark.parametrize("rotary_interleaved", [False, True]) +@pytest.mark.parametrize("rotary_interleaved", [False]) +# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) +@pytest.mark.parametrize("rotary_fraction", [0.0]) # @pytest.mark.parametrize("paged_kv_block_size", [None, 256]) # @pytest.mark.parametrize("paged_kv_block_size", [256, 512]) @pytest.mark.parametrize("paged_kv_block_size", [None]) @@ -1867,25 +1867,27 @@ def test_flash_attn_varlen_causal( @pytest.mark.parametrize("has_leftpad", [False]) # @pytest.mark.parametrize("has_batch_idx", [False, True]) @pytest.mark.parametrize("has_batch_idx", [False]) -@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [32]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ - (1, 128), - (1, 339), - (3, 1024), - (64, 800), - (64, 256), - (3, 799), - (64, 2048), - (16, 20000), - (1, 128 * 1024), - (16, 128 * 1024), - (128, 128), + (4, 4) + # (1, 128), + # (1, 339), + # (3, 1024), + # (64, 800), + # (64, 256), + # (3, 799), + # (64, 2048), + # (16, 20000), + # (1, 128 * 1024), + # (16, 128 * 1024), + # (128, 128), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])