diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index 5184dd99309b..0fd8790e0d29 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -55,6 +55,7 @@ struct AttentionParameters {
   int v_hidden_size;  // hidden size of V
   int v_head_size;    // hidden size per head of V
   int num_heads;
+  int num_splits;
   bool is_unidirectional;
   bool past_present_share_buffer;
   bool do_rotary;
@@ -95,9 +96,9 @@ struct GroupQueryAttentionParameters {
   int head_size;
   int kv_hidden_size;
   int kv_num_heads;
+  int num_splits;          // number of splits for splitkv
   bool is_unidirectional;  // causal
   float scale;
-  int num_splits;  // number of splits for splitkv
   AttentionQkvFormat qkv_format;
   AttentionQkvFormat past_kv_format;
 };
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc
index 0dc7de0e9e51..bf6431cf1afb 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -135,8 +135,24 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
     use_flash_attention = false;
   }
+  // Allocate buffers
+  size_t softmax_lse_accum_bytes = 0;
+  size_t out_accum_bytes = 0;
+  if (use_flash_attention) {
+    using namespace std;
+    auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
+        parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads,
+        parameters.head_size, device_prop.multiProcessorCount);
+    parameters.num_splits = num_splits;
+    softmax_lse_accum_bytes = slse_accum_bytes;
+    out_accum_bytes = o_accum_bytes;
+  }
+  auto softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
+  auto out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
 #else
   constexpr bool use_flash_attention = false;
+  auto softmax_lse_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());  // nullptr
+  auto out_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());  // nullptr
 #endif
 
   if (!use_flash_attention) {
@@ -279,6 +295,12 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   data.fused_runner = reinterpret_cast<void*>(fused_runner);
   data.use_flash_attention = use_flash_attention;
   data.use_memory_efficient_attention = use_memory_efficient_attention;
+  if (softmax_lse_accum_buffer != nullptr) {
+    data.softmax_lse_accum = reinterpret_cast<CudaT*>(softmax_lse_accum_buffer.get());
+  }
+  if (out_accum_buffer != nullptr) {
+    data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
+  }
 
   return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
 }
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index b4a4ae208ceb..eb9e6d5c6246 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -316,7 +316,9 @@ Status FlashAttention(
   ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
       device_prop, stream, query, key, value, data.output, reinterpret_cast<void*>(data.scratch),
       parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size,
-      parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional));
+      parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional,
+      parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
+      true));
 
   DUMP_TENSOR("flash attention output", data.output,
               parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size);
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index d0a5fb51a25d..3e78978c3cc4 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -88,6 +88,11 @@ struct AttentionData {
   T* v = nullptr;
   T* scratch = nullptr;
   AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+
+  // Flash buffers
+  T* softmax_lse = nullptr;
+  T* softmax_lse_accum = nullptr;
+  T* out_accum = nullptr;
 };
 
 template <typename T>
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
index ff7a22d253a5..89a27c4d2b0d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -140,11 +140,10 @@ void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split
 // So we find the best efficiency, then find the smallest number of splits that gets 85%
 // of the best efficiency.
 int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs,
-                         int max_splits, bool new_kv, bool is_sm8x) {
+                         int max_splits) {
   // This needs to match with run_mha_fwd_splitkv_dispatch
-  const int block_n = is_sm8x ? (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64))
-                              : (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64));
-  const int num_n_blocks = (seqlen_k + (!new_kv ? 0 : seqlen_q) + block_n - 1) / block_n;
+  const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+  const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
   // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
   // In any case we don't expect seqlen_q to be larger than 64 for inference.
   const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
@@ -190,6 +189,26 @@ int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_hea
   return 1;
 }
 
+// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes)
+std::tuple<int, int, int> get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads,
+                                                          int head_size, int num_SMs) {
+  int max_splits = 128;
+  // split kv buffers
+  int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size,
+                                        num_SMs, max_splits);
+  if (num_splits > 1) {
+    // softmax_lse_accum buffer
+    int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q);
+    // out_accum buffer
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int head_size_rounded = round_multiple(head_size, 32);
+    int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded);
+    return {num_splits, softmax_lse_accum_bytes, out_accum_bytes};
+  } else {
+    return {0, 0, 0};
+  }
+}
+
 Status mha_fwd(const cudaDeviceProp& dprops,
                cudaStream_t stream,
                void* q,  // batch_size x seqlen_q x num_heads x head_size
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
index 0a0328edb005..58f430425187 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -31,6 +31,7 @@
 #if USE_FLASH_ATTENTION
 
 #include "core/providers/cuda/cuda_common.h"
+#include <tuple>
 
 namespace onnxruntime {
 namespace flash {
@@ -99,10 +100,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
 );
 
 size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);
-size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q);
-size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded);
 
-int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, int max_splits, bool new_kv, bool is_sm8x);
+std::tuple<int, int, int> get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads,
+                                                          int head_size, int num_SMs);
 
 bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k);
 
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
index 784335a124c7..82dfa59b8f8e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
@@ -123,17 +123,9 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) {
 
 template <typename T, int Headdim>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) {
-  bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0;
   constexpr int kBlockM = 64;  // Fixed for all head dimensions
-  if (!is_sm8x) {  // A100, H100
-    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
-    // and for headdim 192 with block size 64 x 128.
-    constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64);
-    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
-  } else {  // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above
-    constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
-    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
-  }
+  constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+  run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
 }
 
 template <typename T>
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 65d19d447387..67d750aeac11 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -116,22 +116,16 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   size_t out_accum_bytes = 0;
   size_t seqlens_k_bytes = 0;
   if (use_flash_attention) {
+    // softmax buffer
     softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads);
-    // split kv buffers
-    parameters.num_splits = onnxruntime::flash::num_splits_heuristic(
+    // split kv buffer
+    using namespace std;
+    auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
         parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads,
-        parameters.head_size, device_prop.multiProcessorCount, 128, false,
-        device_prop.major == 8 && device_prop.minor > 0);
-    if (parameters.num_splits > 1) {
-      // softmax_lse_accum buffer
-      softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size(
-          parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length);
-      // out_accum buffer
-      auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-      const int head_size_rounded = round_multiple(parameters.head_size, 32);
-      out_accum_bytes = onnxruntime::flash::get_out_accum_size(
-          parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, head_size_rounded);
-    }
+        parameters.head_size, device_prop.multiProcessorCount);
+    parameters.num_splits = num_splits;
+    softmax_lse_accum_bytes = slse_accum_bytes;
+    out_accum_bytes = o_accum_bytes;
     // seqlens_k buffer
     if (past_key != nullptr) {
       seqlens_k_bytes = sizeof(int) * parameters.batch_size;
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index e3f53ca6a63c..ebd66d8c6528 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -153,8 +153,24 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
     use_flash_attention = false;
   }
+  // Allocate buffers
+  size_t softmax_lse_accum_bytes = 0;
+  size_t out_accum_bytes = 0;
+  if (use_flash_attention) {
+    using namespace std;
+    auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
+        parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads,
+        parameters.head_size, device_prop.multiProcessorCount);
+    parameters.num_splits = num_splits;
+    softmax_lse_accum_bytes = slse_accum_bytes;
+    out_accum_bytes = o_accum_bytes;
+  }
+  auto softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
+  auto out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
 #else
   constexpr bool use_flash_attention = false;
+  auto softmax_lse_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());  // nullptr
+  auto out_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());  // nullptr
 #endif
 
   bool use_fused_cross_attention = !use_flash_attention &&
@@ -291,6 +307,12 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   data.use_memory_efficient_attention = use_memory_efficient_attention;
   data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_);
   data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_);
+  if (softmax_lse_accum_buffer != nullptr) {
+    data.softmax_lse_accum = reinterpret_cast<CudaT*>(softmax_lse_accum_buffer.get());
+  }
+  if (out_accum_buffer != nullptr) {
+    data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
+  }
 
   cublasHandle_t cublas = GetCublasHandle(context);
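
Note on the caller-side pattern: the three CUDA operators touched above (Attention, MultiHeadAttention, GroupQueryAttention) now share the same split-KV setup. The condensed sketch below only restates the usage already visible in the hunks; `parameters`, `device_prop`, `context`, and `use_flash_attention` stand in for the operator's local state, and the remark about per-split float accumulators reflects the upstream FlashAttention convention rather than anything introduced by this diff.

  // Sketch (not part of the diff): how an operator consumes get_num_splits_and_buffer_sizes.
  size_t softmax_lse_accum_bytes = 0;
  size_t out_accum_bytes = 0;
  if (use_flash_attention) {
    // Pick num_splits from the problem shape and SM count; both sizes come back as 0 when num_splits <= 1.
    auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
        parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length,
        parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount);
    parameters.num_splits = num_splits;
    softmax_lse_accum_bytes = slse_accum_bytes;
    out_accum_bytes = o_accum_bytes;
  }
  // Zero-byte requests yield null scratch buffers, so the kernel sees nullptr accumulators when
  // split-KV is disabled; otherwise these hold the per-split partial LSE / partial output (float).
  auto softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
  auto out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());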