From 740da396c4b1dd1235dca446e39dd68aad487ea3 Mon Sep 17 00:00:00 2001
From: rocking
Date: Thu, 21 Nov 2024 17:37:43 -0500
Subject: [PATCH] Support variable-length paged attention

---
 csrc/flash_attn_ck/mha_varlen_fwd.cpp | 303 ++++++++++++++++++++++----
 1 file changed, 259 insertions(+), 44 deletions(-)

diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
index 7e8a347d4..037dec8cb 100644
--- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -26,6 +26,23 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
                            false}; // do_fp8_static_quant
 }
 
+fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &mask,
+                                                              std::string dtype,
+                                                              int head_size,
+                                                              bool has_lse,
+                                                              bool enable_alibi)
+{
+    return fmha_fwd_splitkv_traits{head_size,
+                                   head_size,
+                                   dtype,
+                                   true,  // is_group_mode
+                                   true,  // is_v_rowmajor
+                                   mask.type,
+                                   enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
+                                   has_lse,
+                                   false}; // do_fp8_static_quant
+}
+
 fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                                           bool has_dropout_randval,
                                           const mask_info &mask,
@@ -142,6 +159,140 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                         drop_seed_offset};
 }
 
+fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
+                                                          const mask_info &mask,
+                                                          const int b,
+                                                          const int max_seqlen_q,
+                                                          const int h,
+                                                          const int h_k,
+                                                          const int d,
+                                                          const int page_block_size,
+                                                          const int num_splits,
+                                                          float softmax_scale,
+                                                          // device pointers
+                                                          const at::Tensor q,
+                                                          const at::Tensor k,
+                                                          const at::Tensor v,
+                                                          const at::Tensor seqlens_q,
+                                                          const at::Tensor seqlens_k,
+                                                          c10::optional<at::Tensor> &block_table_,
+                                                          c10::optional<at::Tensor> &alibi_slopes_,
+                                                          at::Tensor out,
+                                                          at::Tensor lse,
+                                                          at::Tensor lse_acc,
+                                                          at::Tensor out_acc)
+{
+    // q: (total_q, nheads, d)
+    // k: (total_k, nheads_k, d)
+    // v: (total_k, nheads_k, d)
+    // o: (total_q, nheads, d)
+
+    // alibi_slopes: (batch_size, nheads) or (nheads)
+    // lse: (nheads, total_q)
+    // lse_acc: (nheads, split, total_q)
+    // o_acc: (nheads, split, total_q, d)
+
+    ck_tile::index_t total_q = q.size(0);
+    ck_tile::index_t total_k = k.size(0);
+
+    fmha_fwd_splitkv_args args;
+    args.q_ptr = q.data_ptr();
+    args.k_ptr = k.data_ptr();
+    args.v_ptr = v.data_ptr();
+    args.bias_ptr = nullptr;
+    args.lse_acc_ptr = lse_acc.data_ptr();
+    args.o_acc_ptr = out_acc.data_ptr();
+    args.lse_ptr = nullptr;
+    args.o_ptr = out.data_ptr();
+
+    if (block_table_.has_value())
+    {
+        auto block_table = block_table_.value();
+        args.block_table_ptr = block_table.data_ptr();
+        args.batch_stride_block_table = block_table.stride(0);
+        args.page_block_size = page_block_size;
+    }
+    else
+    {
+        args.block_table_ptr = nullptr;
+        args.batch_stride_block_table = 0;
+        args.page_block_size = 0;
+    }
+
+    args.cache_batch_idx = nullptr;
+
+    args.seqstart_q_ptr = seqlens_q.data_ptr();
+    args.seqstart_k_ptr = seqlens_k.data_ptr();
+    args.seqlen_k_ptr = nullptr;
+
+    args.seqlen_q = total_q;
+    args.seqlen_k = total_k;
+    args.batch = b;
+    args.max_seqlen_q = max_seqlen_q;
+    args.hdim_q = d;
+    args.hdim_v = d;
+    args.nhead_q = h;
+    args.nhead_k = h_k;
+    args.num_splits = num_splits;
+
+    args.scale_s = softmax_scale;
+    args.scale_p = 1;
+    args.scale_o = 1;
+
+    args.batch_stride_q = 0;
+    args.stride_q = q.stride(0);
+    args.nhead_stride_q = q.stride(1);
+
+    args.batch_stride_k = 0;
+    args.stride_k = k.stride(0);
+    args.nhead_stride_k = k.stride(1);
+
+    args.batch_stride_v = 0;
+    args.stride_v = v.stride(0);
+    args.nhead_stride_v = v.stride(1);
+
+    args.batch_stride_o = 0;
+    args.stride_o = out.stride(0);
+    args.nhead_stride_o = out.stride(1);
+
+    args.batch_stride_bias = 0;
+    args.stride_bias = 0;
+    args.nhead_stride_bias = 0;
+
+    args.batch_stride_lse = 0;
+    args.nhead_stride_lse = 0;
+
+    args.batch_stride_lse_acc = 0;
+    args.nhead_stride_lse_acc = lse_acc.stride(0);
+    args.split_stride_lse_acc = lse_acc.stride(1);
+
+    args.batch_stride_o_acc = 0;
+    args.nhead_stride_o_acc = out_acc.stride(0);
+    args.split_stride_o_acc = out_acc.stride(1);
+    args.stride_o_acc = out_acc.stride(2);
+
+    if (has_lse) {
+        args.lse_ptr = lse.data_ptr();
+        args.batch_stride_lse = 0;
+        args.nhead_stride_lse = lse.stride(0);
+    }
+
+    if (alibi_slopes_.has_value()) {
+        auto alibi_slopes = alibi_slopes_.value();
+        CHECK_DEVICE(alibi_slopes);
+        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
+        args.bias_ptr = alibi_slopes.data_ptr();
+        args.stride_bias = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
+    }
+
+    args.window_size_left = mask.left;
+    args.window_size_right = mask.right;
+    args.mask_type = static_cast<ck_tile::index_t>(mask.type);
+
+    return args;
+}
+
 std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
@@ -180,9 +331,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
     CHECK_DEVICE(cu_seqlens_q);
     CHECK_DEVICE(cu_seqlens_k);
 
-    // TODO - Support paged_KV
+    at::Tensor block_table;
     const bool paged_KV = block_table_.has_value();
-    TORCH_CHECK(!paged_KV, "CK does not support paged_KV yet");
+    if (paged_KV) {
+        block_table = block_table_.value();
+        CHECK_DEVICE(block_table);
+        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+    }
 
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
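For orientation before the shape logic in the next hunk: the block_table validated above maps each sequence's logical KV positions onto physical blocks of a paged cache. Below is a minimal standalone sketch of that indexing arithmetic (the helper and its names are illustrative, not code from this patch or the CK kernels), assuming the (num_blocks, page_block_size, nheads_k, head_size) cache layout this file checks later:

    // Hypothetical helper: locate the element offset of logical token t of batch b
    // inside a paged KV cache. All names here are illustrative assumptions.
    #include <cstdint>

    int64_t paged_kv_offset(const int32_t *block_table,        // (batch, max_num_blocks_per_seq)
                            int64_t batch_stride_block_table,  // elements between rows of block_table
                            int b, int t,                      // batch index, logical token position
                            int page_block_size,               // tokens per physical block
                            int64_t block_stride,              // elements between blocks (k.stride(0))
                            int64_t token_stride)              // elements between tokens (k.stride(1))
    {
        const int logical_block   = t / page_block_size;  // which block_table entry
        const int offset_in_block = t % page_block_size;  // position within that block
        const int32_t physical_block =
            block_table[b * batch_stride_block_table + logical_block];
        return physical_block * block_stride + offset_in_block * token_stride;
    }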
@@ -195,10 +351,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size = sizes[2];
-    const int num_heads_k = k.size(1);
+    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
 
-    const int max_num_blocks_per_seq = 0;
-    const int num_blocks = 0;
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int num_blocks = !paged_KV ? 0 : k.size(0);
+    const int page_block_size = !paged_KV ? 1 : k.size(1);
+    TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, "Paged KV cache block size must be divisible by 128");
 
     if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
@@ -207,7 +365,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
     // H/t Daniel Haziza
 
     const int total_q = q.size(0);
-    const int total_k = k.size(0);
 
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size <= 256, "CK only supports head dimension at most 256");
@@ -235,11 +392,18 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
     }
 
     CHECK_SHAPE(q, total_q, num_heads, head_size);
-    CHECK_SHAPE(k, total_k, num_heads_k, head_size);
-    CHECK_SHAPE(v, total_k, num_heads_k, head_size);
+    if (!paged_KV) {
+        const int total_k = k.size(0);
+        CHECK_SHAPE(k, total_k, num_heads_k, head_size);
+        CHECK_SHAPE(v, total_k, num_heads_k, head_size);
+    } else {
+        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
+        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
+        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+    }
+
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
-
     at::Tensor out;
     if (out_.has_value()) {
         out = out_.value();
@@ -259,6 +423,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
     auto opts = q.options();
     bool has_lse = true;
     bool has_dropout = p_dropout > 0.0f;
+    if (has_dropout)
+        TORCH_CHECK(!paged_KV, "Paged KV does not support dropout");
 
     at::Tensor softmax_lse;
     // TODO - check gradient, only training requires lse
@@ -280,6 +446,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
         if (return_dropout_randval) {p.zero_();}
     }
 
+    int num_splits = 1;
+    num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, 0, num_splits);
+    TORCH_CHECK(num_splits > 0, "num_splits should be greater than 0");
+    TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");
+
+    auto softmax_lse_accum = torch::empty({num_heads, num_splits, total_q}, opts.dtype(at::kFloat));
+    auto out_accum = torch::empty({num_heads, num_splits, total_q, head_size}, opts.dtype(at::kFloat));
+
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
     auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
     auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
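The two accumulators allocated above hold one partial result per split; fmha_fwd_splitkv then reduces them into the final out and softmax_lse. A rough scalar sketch of that log-sum-exp combine for a single (head, query) pair follows (illustrative only; the actual CK combine step works on tiles on the GPU):

    // Hypothetical scalar combine of num_splits partial attention results.
    // o_acc[i] is the d-dimensional partial output of split i, lse_acc[i] its log-sum-exp.
    #include <algorithm>
    #include <cmath>
    #include <vector>

    void combine_splits(const std::vector<std::vector<float>> &o_acc, // (num_splits, d)
                        const std::vector<float> &lse_acc,            // (num_splits)
                        std::vector<float> &o,                        // (d) combined output
                        float &lse)                                   // combined log-sum-exp
    {
        const int num_splits = static_cast<int>(lse_acc.size());
        const int d = static_cast<int>(o.size());

        // lse = log(sum_i exp(lse_i)), computed stably around the maximum.
        const float m = *std::max_element(lse_acc.begin(), lse_acc.end());
        float sum = 0.f;
        for (int i = 0; i < num_splits; ++i) sum += std::exp(lse_acc[i] - m);
        lse = m + std::log(sum);

        // o = sum_i exp(lse_i - lse) * o_i: each split is reweighted by its share of the softmax mass.
        std::fill(o.begin(), o.end(), 0.f);
        for (int i = 0; i < num_splits; ++i) {
            const float w = std::exp(lse_acc[i] - lse);
            for (int j = 0; j < d; ++j) o[j] += w * o_acc[i][j];
        }
    }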
@@ -295,44 +469,85 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
     }
 
     if (max_seqlen_k > 0) {
-        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};
 
-        auto traits =
-            get_ck_fmha_varlen_fwd_traits(
-                mask,
-                q_dtype_str,
-                head_size,
-                has_dropout,
-                has_lse,
-                alibi_slopes_.has_value());
-
-        auto args =
-            get_ck_fmha_varlen_fwd_args(
-                has_lse,
-                return_dropout_randval,
-                mask,
-                batch_size,
-                max_seqlen_q,
-                num_heads,
-                num_heads_k,
-                head_size,
-                q,
-                k,
-                v,
-                cu_seqlens_q,
-                cu_seqlens_k,
-                alibi_slopes_,
-                out,
-                softmax_lse,
-                p,
-                softmax_scale,
-                p_dropout,
-                drop_seed_offset);
-
-        float t = fmha_fwd(traits, args, stream_config);
-        TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
+        if (paged_KV)
+        {
+            auto traits =
+                get_ck_fmha_varlen_fwd_splitkv_traits(
+                    mask,
+                    q_dtype_str,
+                    head_size,
+                    has_lse,
+                    alibi_slopes_.has_value());
+
+            auto args =
+                get_ck_fmha_varlen_fwd_splitkv_args(
+                    has_lse,
+                    mask,
+                    batch_size,
+                    max_seqlen_q,
+                    num_heads,
+                    num_heads_k,
+                    head_size,
+                    page_block_size,
+                    num_splits,
+                    softmax_scale,
+                    q,
+                    k,
+                    v,
+                    cu_seqlens_q,
+                    cu_seqlens_k,
+                    block_table_,
+                    alibi_slopes_,
+                    out,
+                    softmax_lse,
+                    softmax_lse_accum,
+                    out_accum);
+
+            float t = fmha_fwd_splitkv(traits, args, stream_config);
+            TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd_splitkv");
+        }
+        else
+        {
+            auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
+
+            auto traits =
+                get_ck_fmha_varlen_fwd_traits(
+                    mask,
+                    q_dtype_str,
+                    head_size,
+                    has_dropout,
+                    has_lse,
+                    alibi_slopes_.has_value());
+
+            auto args =
+                get_ck_fmha_varlen_fwd_args(
+                    has_lse,
+                    return_dropout_randval,
+                    mask,
+                    batch_size,
+                    max_seqlen_q,
+                    num_heads,
+                    num_heads_k,
+                    head_size,
+                    q,
+                    k,
+                    v,
+                    cu_seqlens_q,
+                    cu_seqlens_k,
+                    alibi_slopes_,
+                    out,
+                    softmax_lse,
+                    p,
+                    softmax_scale,
+                    p_dropout,
+                    drop_seed_offset);
+
+            float t = fmha_fwd(traits, args, stream_config);
+            TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
+        }
     }
     else {
         // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
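For callers, the shape checks added in this patch imply a paged cache laid out as (num_blocks, page_block_size, num_heads_k, head_size), an int32 block_table of shape (batch_size, max_num_blocks_per_seq), and a block size that is a multiple of 128. A minimal host-side allocation sketch (all sizes and the identity block mapping are illustrative assumptions, not part of the patch):

    // Hypothetical setup of a paged KV cache compatible with the checks above.
    #include <torch/torch.h>

    void make_paged_kv_cache() {
        const int batch_size = 2, num_heads_k = 8, head_size = 128;
        const int page_block_size = 128;       // must satisfy page_block_size % 128 == 0
        const int max_num_blocks_per_seq = 4;
        const int num_blocks = batch_size * max_num_blocks_per_seq;

        auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
        auto k_cache = torch::empty({num_blocks, page_block_size, num_heads_k, head_size}, opts);
        auto v_cache = torch::empty({num_blocks, page_block_size, num_heads_k, head_size}, opts);

        // Simplest mapping: sequence b owns physical blocks [b*max, (b+1)*max).
        auto block_table =
            torch::arange(num_blocks, torch::dtype(torch::kInt32).device(torch::kCUDA))
                .reshape({batch_size, max_num_blocks_per_seq});
        // k_cache, v_cache, and block_table can then be passed to mha_varlen_fwd.
    }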