diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index cacd65313ebc..3b7f980ba188 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -149,8 +149,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == relative_position_bias && parameters.past_sequence_length == 0 && parameters.hidden_size == parameters.v_hidden_size && - FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, true); + FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, true); if (use_causal_fused_runner) { // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. if (nullptr == fused_fp16_runner_.get()) { @@ -171,8 +171,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == present && nullptr == relative_position_bias && parameters.hidden_size == parameters.v_hidden_size && - FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, false); + FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, false); if (use_fused_runner) { // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. @@ -184,8 +184,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } // In case some kernel not loaded due to shared memory limit, we need to double check here. - const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length); - if (fused_fp16_runner_->isValid(S)) { + const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length); + if (fused_fp16_runner_->IsValid(normalized_seq_len)) { fused_runner = fused_fp16_runner_.get(); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 150079cdf157..997493acd9cb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -245,12 +245,10 @@ Status FusedTrtSelfAttention( FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(data.fused_runner); - const int S = causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length); + const int s = causal ? sequence_length : fused_fp16_runner->NormalizeSequenceLength(sequence_length); // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed. - const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size); - - fused_fp16_runner->setup(S, B); + const int b = (nullptr == data.mask_index ? 
batch_size : 2 * batch_size); if (!causal) { assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H); @@ -261,12 +259,12 @@ Status FusedTrtSelfAttention( packed_qkv = data.query; } - fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); + fused_fp16_runner->Run(b, s, packed_qkv, sequence_offset, data.output, stream); DUMP_TENSOR("fused output", data.output, batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); } else { assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); - fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); + fused_fp16_runner->Run(b, s, data.gemm_buffer, sequence_offset, data.output, stream); DUMP_TENSOR("fused causal output", data.output, batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index b96140f3897f..663bd020ddac 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -193,8 +193,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { (nullptr == key_padding_mask || is_mask_1d_seq_len) && parameters.hidden_size == parameters.v_hidden_size && parameters.sequence_length == parameters.kv_sequence_length && - FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, false); + FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, false); if (use_fused_runner) { // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node. if (nullptr == fused_fp16_runner_.get()) { @@ -206,8 +206,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } // In case some kernel not loaded due to shared memory limit, we need to double check here. - const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length); - if (fused_fp16_runner_->isValid(S)) { + const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length); + if (fused_fp16_runner_->IsValid(normalized_seq_len)) { fused_runner = fused_fp16_runner_.get(); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index a1149ddbf99f..d1c6993d48e6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -55,11 +55,11 @@ MHARunner* TrtFusedAttention::GetFusedRunner(const cudaDeviceProp& device_pro // Check whether we can use fused kernel int sm = device_prop.major * 10 + device_prop.minor; - bool is_fMHA_supported = FusedMHARunnerFP16v2::is_supported(sm, - parameters.head_size, - parameters.sequence_length, - enable_trt_flash_attention_, - false /*causal*/); + bool is_fMHA_supported = FusedMHARunnerFP16v2::IsSupported(sm, + parameters.head_size, + parameters.sequence_length, + enable_trt_flash_attention_, + false /*causal*/); if (!is_fMHA_supported) { return fused_runner; @@ -72,8 +72,8 @@ MHARunner* TrtFusedAttention::GetFusedRunner(const cudaDeviceProp& device_pro } // In case some kernel not loaded due to shared memory limit, we need to double check here. 
-  const int S = fused_fp16_runner_->getSFromMaxSeqLen(parameters.sequence_length);
-  if (fused_fp16_runner_->isValid(S)) {
+  const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(parameters.sequence_length);
+  if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
     fused_runner = fused_fp16_runner_.get();
   }
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
index db9f30c25c01..ac2cb5165a94 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
@@ -459,10 +459,9 @@ Status FusedScaledDotProductAttention(
                               parameters.token_count, stream);
 
   FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);
-  const int S = fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
-  fused_fp16_runner->setup(S, batch_size);
-
-  fused_fp16_runner->run(data.workspace, data.cumulative_sequence_length, data.output, stream);
+  const int normalized_seq_len = fused_fp16_runner->NormalizeSequenceLength(sequence_length);
+  fused_fp16_runner->Run(batch_size, normalized_seq_len,
+                         data.workspace, data.cumulative_sequence_length, data.output, stream);
 
   return Status::OK();
 }
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
index 3e168189be3d..b4ca0194b08b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
@@ -575,10 +575,8 @@ Status FusedAttentionTrt(
   }
 
   FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);
-  const int S = fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
-  fused_fp16_runner->setup(S, batch_size);
-
-  fused_fp16_runner->run(qkv, data.cumulative_sequence_length, data.output, stream);
+  const int normalized_seq_len = fused_fp16_runner->NormalizeSequenceLength(sequence_length);
+  fused_fp16_runner->Run(batch_size, normalized_seq_len, qkv, data.cumulative_sequence_length, data.output, stream);
 
   return Status::OK();
 }
diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu
index 4a4e3eeecf64..8af28e874729 100644
--- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu
@@ -14,6 +14,10 @@
  * limitations under the License.
  */
 
+// Modifications: Update interface and implementation to be thread-safe
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention_v2.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/flash_attention/fmha_flash_attention.h" @@ -34,28 +38,28 @@ void set_alpha_fp16(uint32_t& alpha, float norm) { alpha = temp.u32; } -class FusedMHARunnerFP16v2::mhaImpl { +class FusedMHARunnerFP16v2::FmhaImpl { public: - mhaImpl(FusedMHARunnerFP16v2* interface) - : interface(interface), - sm(interface->mSm), - xmmaKernel(getXMMAKernelsV2(DATA_TYPE_FP16, sm)) { + FmhaImpl(FusedMHARunnerFP16v2* interface, int sm) + : interface_(interface), + sm_(sm), + xmma_kernel_(getXMMAKernelsV2(DATA_TYPE_FP16, sm)) { ORT_ENFORCE((sm == kSM_70 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_89), "Unsupported architecture"); - flash_attention_kernel = nullptr; - if (interface->mEnableFlashAttention) { - flash_attention_kernel = get_flash_attention_kernels(DATA_TYPE_FP16, sm); + flash_kernel_ = nullptr; + if (interface_->enable_flash_attention_) { + flash_kernel_ = get_flash_attention_kernels(DATA_TYPE_FP16, sm); } - - params.clear(); } - ~mhaImpl() {} + ~FmhaImpl() {} - void setup(const int seq_len, const int B) { - // For bert and vit, use flash attention when sequence length is larger than the threshold. - use_flash_attention = is_flash_attention(seq_len); + void Setup(Fused_multihead_attention_params_v2& params, + int sequence_length, // normalized sequence length + int batch_size, + bool& use_flash_attention) const { + use_flash_attention = UseFlashAttention(sequence_length); params.force_unroll = use_flash_attention; @@ -67,27 +71,27 @@ class FusedMHARunnerFP16v2::mhaImpl { warps_m = 4; warps_n = 1; } else { - if (sm == 70) { - if (seq_len == 64 || seq_len == 96) { + if (sm_ == 70) { + if (sequence_length == 64 || sequence_length == 96) { warps_m = 2; warps_n = 2; - } else if (seq_len == 128) { + } else if (sequence_length == 128) { warps_m = 1; warps_n = 4; - } else if (seq_len == 256 || seq_len == 384) { + } else if (sequence_length == 256 || sequence_length == 384) { warps_m = 1; warps_n = 8; } else { ORT_ENFORCE(false, "Unsupported sequence length"); } } else { - if (seq_len == 32 || seq_len == 64 || seq_len == 96 || seq_len == 128) { + if (sequence_length == 32 || sequence_length == 64 || sequence_length == 96 || sequence_length == 128) { warps_m = 2; warps_n = 2; - } else if (seq_len == 192 || seq_len == 256) { + } else if (sequence_length == 192 || sequence_length == 256) { warps_m = 1; warps_n = 4; - } else if (seq_len == 384) { + } else if (sequence_length == 384) { warps_m = 1; warps_n = 8; } else { @@ -97,11 +101,11 @@ class FusedMHARunnerFP16v2::mhaImpl { } // The number of threads per CTA. - threads_per_cta = warps_m * warps_n * warps_k * 32; + size_t threads_per_cta = warps_m * warps_n * warps_k * 32; // The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension. 
- xmmas_m = (seq_len + 16 * warps_m - 1) / (16 * warps_m); + size_t xmmas_m = (sequence_length + 16 * warps_m - 1) / (16 * warps_m); - const float scale_bmm1 = interface->mScale; + const float scale_bmm1 = interface_->scale_; const float scale_softmax = 1.f; // Seems to be only required for int8 const float scale_bmm2 = 1.f; @@ -109,20 +113,21 @@ class FusedMHARunnerFP16v2::mhaImpl { set_alpha_fp16(params.scale_softmax, scale_softmax); set_alpha_fp16(params.scale_bmm2, scale_bmm2); - params.b = B; - params.h = interface->mNumHeads; - params.s = seq_len; - params.d = interface->mHeadSize; + params.b = batch_size; + params.h = interface_->num_heads_; + params.s = sequence_length; + params.d = interface_->head_size_; - params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half); + params.qkv_stride_in_bytes = 3 * interface_->num_heads_ * interface_->head_size_ * sizeof(half); params.packed_mask_stride_in_bytes = xmmas_m * threads_per_cta * sizeof(uint32_t); - params.o_stride_in_bytes = interface->mNumHeads * interface->mHeadSize * sizeof(half); - - has_causal_mask = false; + params.o_stride_in_bytes = interface_->num_heads_ * interface_->head_size_ * sizeof(half); } - void setup_causal_masked_fmha(const int seq_len, const int B) { - const float scale_bmm1 = interface->mScale; + void SetupCausal(Fused_multihead_attention_params_v2& params, + int sequence_length, // normalized sequence length + int batch_size, + bool& use_flash_attention) const { + const float scale_bmm1 = interface_->scale_; const float scale_softmax = 1.f; // Seems to be only required for int8 const float scale_bmm2 = 1.f; @@ -130,16 +135,17 @@ class FusedMHARunnerFP16v2::mhaImpl { set_alpha_fp16(params.scale_softmax, scale_softmax); set_alpha_fp16(params.scale_bmm2, scale_bmm2); - params.b = B; - params.h = interface->mNumHeads; - params.s = seq_len; - params.d = interface->mHeadSize; + params.b = batch_size; + params.h = interface_->num_heads_; + params.s = sequence_length; + params.d = interface_->head_size_; + + params.qkv_stride_in_bytes = 3 * interface_->num_heads_ * interface_->head_size_ * sizeof(half); + params.o_stride_in_bytes = interface_->num_heads_ * interface_->head_size_ * sizeof(half); - params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half); - params.o_stride_in_bytes = interface->mNumHeads * interface->mHeadSize * sizeof(half); + use_flash_attention = interface_->enable_flash_attention_; - // fallback to original fmha_v2 when head_size <= 64 and seq_len <- 128 - use_flash_attention = interface->mEnableFlashAttention; + // fallback to original fmha_v2 when head_size <= 64 and sequence_length <= 128 if (params.d <= 64 && params.s <= 128) { use_flash_attention = false; // get max sequence length @@ -152,97 +158,87 @@ class FusedMHARunnerFP16v2::mhaImpl { // set flags params.force_unroll = use_flash_attention; - has_causal_mask = true; } - void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) { + void Run(Fused_multihead_attention_params_v2& params, + const void* input, + const void* cu_seqlens, + void* output, + cudaStream_t stream, + bool use_flash_attention, + bool has_causal_mask) const { params.qkv_ptr = const_cast(input); params.o_ptr = output; params.cu_seqlens = static_cast(const_cast(cu_seqlens)); - if (use_flash_attention && flash_attention_kernel != nullptr && !has_causal_mask) { - flash_attention_kernel->run(params, stream); + if (use_flash_attention && flash_kernel_ != nullptr && !has_causal_mask) 
{ + flash_kernel_->run(params, stream); } else { - xmmaKernel->run(params, stream, use_flash_attention, has_causal_mask); + xmma_kernel_->run(params, stream, use_flash_attention, has_causal_mask); } CUDA_CALL_THROW(cudaPeekAtLastError()); } - bool isValid(int s) const { - if (is_flash_attention(s)) { - return (flash_attention_kernel != nullptr) && flash_attention_kernel->isValid(s); + bool IsValid(int sequence_length) const { + if (UseFlashAttention(sequence_length)) { + return (flash_kernel_ != nullptr) && flash_kernel_->isValid(sequence_length); } - return xmmaKernel->isValid(s); + return xmma_kernel_->isValid(sequence_length); } - int getSFromMaxSeqLen(const int max_seq_len) const { - if (is_flash_attention(max_seq_len)) { + int NormalizeSequenceLength(int max_seq_len) const { + if (UseFlashAttention(max_seq_len)) { return max_seq_len; } - int seq_len = max_seq_len; + int sequence_length = max_seq_len; if (max_seq_len <= 32) { - seq_len = (sm == 70) ? 64 : 32; + sequence_length = (sm_ == 70) ? 64 : 32; } else if (max_seq_len <= 64) { - seq_len = 64; + sequence_length = 64; } else if (max_seq_len <= 96) { - seq_len = 96; + sequence_length = 96; } else if (max_seq_len <= 128) { - seq_len = 128; + sequence_length = 128; } else if (max_seq_len <= 192) { - seq_len = (sm == 70) ? 256 : 192; + sequence_length = (sm_ == 70) ? 256 : 192; } else if (max_seq_len <= 256) { - seq_len = 256; + sequence_length = 256; } else if (max_seq_len <= 384) { - seq_len = 384; + sequence_length = 384; } - return seq_len; + return sequence_length; } protected: - bool is_flash_attention(const int seq_len) const { - ORT_ENFORCE(interface->mHasCausalMask == false); - return interface->mEnableFlashAttention && seq_len >= kMinSequenceLengthFlashAttention; + bool UseFlashAttention(int sequence_length) const { + ORT_ENFORCE(interface_->is_causal_ == false); + return interface_->enable_flash_attention_ && sequence_length >= kMinSequenceLengthFlashAttention; } private: - FusedMHARunnerFP16v2* interface; - Fused_multihead_attention_params_v2 params; - int sm; - const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernel; - const FusedMultiHeadFlashAttentionKernel* flash_attention_kernel; - size_t xmmas_m; - size_t threads_per_cta; - bool use_flash_attention = false; - bool has_causal_mask = false; + FusedMHARunnerFP16v2* interface_; + int sm_; + const FusedMultiHeadAttentionXMMAKernelV2* xmma_kernel_; + const FusedMultiHeadFlashAttentionKernel* flash_kernel_; }; -FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads, - const int headSize, - const int sm, - bool causal_mask, +FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(int num_heads, + int head_size, + int sm, + bool causal, bool enable_flash_attention, - const float scale) - : MHARunner(numHeads, headSize, 2, causal_mask, scale), - mSm(sm), - mEnableFlashAttention(enable_flash_attention), - pimpl(new mhaImpl(this)) { + float scale) + : MHARunner(num_heads, head_size, causal, scale), + enable_flash_attention_(enable_flash_attention), + impl_(new FmhaImpl(this, sm)) { } -void FusedMHARunnerFP16v2::setup(const int seq_len, const int B) { - MHARunner::setup(seq_len, B); - if (mHasCausalMask) { - pimpl->setup_causal_masked_fmha(seq_len, B); - } else { - pimpl->setup(seq_len, B); - } -} - -bool FusedMHARunnerFP16v2::is_supported(int sm, int head_size, int sequence_length, - bool enable_flash_attention, bool causal) { +bool FusedMHARunnerFP16v2::IsSupported(int sm, int head_size, int sequence_length, + bool enable_flash_attention, bool causal) { if (causal) { if (!(sm == kSM_70 
        || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_89)) {
      return false;
@@ -284,34 +280,44 @@ bool FusedMHARunnerFP16v2::is_supported(int sm, int head_size, int sequence_leng
   return sequence_length <= max_sequence_length;
 }
 
-size_t FusedMHARunnerFP16v2::getWorkspaceSize() const {
-  return 0;
-}
+void FusedMHARunnerFP16v2::Run(int batch_size,
+                               int normalized_sequence_length,
+                               const void* input,
+                               const void* cu_seqlens,
+                               void* output,
+                               cudaStream_t stream) const {
+  Fused_multihead_attention_params_v2 params;
+  bool use_flash_attention = false;
+  if (is_causal_) {
+    impl_->SetupCausal(params, normalized_sequence_length, batch_size, use_flash_attention);
+  } else {
+    impl_->Setup(params, normalized_sequence_length, batch_size, use_flash_attention);
+  }
 
-void FusedMHARunnerFP16v2::run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) {
-  pimpl->run(input, cu_seqlens, output, stream);
+  impl_->Run(params, input, cu_seqlens, output, stream, use_flash_attention, is_causal_);
 }
 
-bool FusedMHARunnerFP16v2::isValid(int s) const {
-  return pimpl->isValid(s);
+bool FusedMHARunnerFP16v2::IsValid(int normalized_sequence_length) const {
+  return impl_->IsValid(normalized_sequence_length);
 }
 
-int FusedMHARunnerFP16v2::getSFromMaxSeqLen(const int max_seq_len) const {
-  return pimpl->getSFromMaxSeqLen(max_seq_len);
+int FusedMHARunnerFP16v2::NormalizeSequenceLength(int max_seq_len) const {
+  return impl_->NormalizeSequenceLength(max_seq_len);
 }
 
-std::unique_ptr<MHARunner> FusedMHARunnerFP16v2::Create(const int numHeads,
-                                                        const int headSize,
-                                                        const int sm,
-                                                        bool causal_mask,
-                                                        bool enable_flash_attention,
-                                                        const float scale) {
+std::unique_ptr<MHARunner> FusedMHARunnerFP16v2::Create(int num_heads,
+                                                        int head_size,
+                                                        int sm,
+                                                        bool causal,
+                                                        bool enable_flash_attention,
+                                                        const float scale) {
 #ifdef _MSC_VER
-  return std::make_unique<FusedMHARunnerFP16v2>(numHeads, headSize, sm, causal_mask, enable_flash_attention, scale);
+  return std::make_unique<FusedMHARunnerFP16v2>(num_heads, head_size, sm, causal, enable_flash_attention, scale);
 #else
-  // Linux build has error using make_unique: invalid application of ‘sizeof’ to incomplete type ‘onnxruntime::contrib::cuda::FusedMHARunnerFP16v2::mhaImpl
+  // Linux build has error using make_unique: invalid application of ‘sizeof’ to
+  // incomplete type ‘onnxruntime::contrib::cuda::FusedMHARunnerFP16v2::FmhaImpl
   std::unique_ptr<MHARunner> runner;
-  runner.reset(new FusedMHARunnerFP16v2(numHeads, headSize, sm, causal_mask, enable_flash_attention, scale));
+  runner.reset(new FusedMHARunnerFP16v2(num_heads, head_size, sm, causal, enable_flash_attention, scale));
   return runner;
 #endif
 }
diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h
index f7c1dc85361d..82914b07e524 100644
--- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h
+++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h
@@ -14,6 +14,10 @@
  * limitations under the License.
  */
 
+// Modifications: Update interface and implementation to be thread-safe
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ #pragma once #include @@ -25,103 +29,70 @@ namespace cuda { constexpr int kMinSequenceLengthFlashAttention = 385; -// Multi-Head Attention runner class MHARunner { public: - MHARunner(const int numHeads, const int headSize, const int wordSize, bool causal_mask, const float scale) - : mS(0), - mB(0), - mOmatSize(0), - mNumMats(0), - mNumHeads(numHeads), - mHeadSize(headSize), - mWordSize(wordSize), - mLdQKV(0), - mStrideQKV(0), - mLdOut(0), - mStrideOut(0), - mScale(scale == 0.0f ? 1.f / sqrtf(static_cast(headSize)) - : scale), - mHasCausalMask(causal_mask) { + MHARunner(int num_heads, int head_size, bool causal, float scale) + : num_heads_(num_heads), + head_size_(head_size), + scale_(scale == 0.0f ? 1.f / sqrtf(static_cast(head_size)) : scale), + is_causal_(causal) { } virtual ~MHARunner() = default; - virtual void setup(const int S, const int B) { - ORT_ENFORCE(S > 0); - ORT_ENFORCE(B > 0); - - mB = B; - mS = S; - - mLdQKV = 3 * B * mNumHeads * mHeadSize; - mStrideQKV = 3 * mHeadSize; - - mLdOut = B * mNumHeads * mHeadSize; - mStrideOut = mHeadSize; - mOmatSize = S * S; - mNumMats = B * mNumHeads; - } - - virtual void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) = 0; - - virtual size_t getWorkspaceSize() const = 0; + virtual int NormalizeSequenceLength(int max_seq_len) const = 0; - virtual bool isValid(int s) const = 0; + virtual bool IsValid(int normalized_sequence_length) const = 0; - virtual int getSFromMaxSeqLen(const int max_seq_len) const = 0; + virtual void Run(int batch_size, + int normalized_sequence_length, + const void* input, + const void* cu_seqlens, + void* output, + cudaStream_t stream) const = 0; protected: - int mS; - int mB; - int mOmatSize; - int mNumMats; - int mNumHeads; - int mHeadSize; - int mWordSize; - int mLdQKV; - int mStrideQKV; - int mLdOut; - int mStrideOut; - - float mScale; - bool mHasCausalMask; + int num_heads_; + int head_size_; + float scale_; + bool is_causal_; }; class FusedMHARunnerFP16v2 : public MHARunner { public: - FusedMHARunnerFP16v2(const int numHeads, - const int headSize, - const int sm, - bool causal_mask, + FusedMHARunnerFP16v2(int num_heads, + int head_size, + int sm, + bool causal, bool enable_flash_attention, - const float scale); - ~FusedMHARunnerFP16v2() = default; // for pimpl - - virtual void setup(const int S, const int B) override; + float scale); - static bool is_supported(int sm, int head_size, int sequence_length, bool enable_flash_attention, bool causal); + ~FusedMHARunnerFP16v2() = default; // for impl_ - void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) override; + static bool IsSupported(int sm, int head_size, int sequence_length, bool enable_flash_attention, bool causal); - size_t getWorkspaceSize() const override; + static std::unique_ptr Create(int num_heads, + int head_size, + int sm, + bool causal, + bool enable_flash_attention, + float scale); - bool isValid(int s) const override; + bool IsValid(int normalized_sequence_length) const override; - int getSFromMaxSeqLen(const int max_seq_len) const override; + int NormalizeSequenceLength(int max_seq_len) const override; - static std::unique_ptr Create(const int numHeads, - const int headSize, - const int sm, - bool causal_mask, - bool enable_flash_attention, - const float scale); + void Run(int batch_size, + int normalized_sequence_length, + const void* input, + const void* cu_seqlens, + void* output, + cudaStream_t stream) const override; private: - int mSm; - bool mEnableFlashAttention; - class 
mhaImpl; - std::unique_ptr pimpl; + bool enable_flash_attention_; + class FmhaImpl; + std::unique_ptr impl_; }; } // namespace cuda diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 797461bae2ef..111c417479d2 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -156,6 +156,49 @@ def shape_dict(self, input_format=None): ) return shapes + def symbolic_shape_dict(self, input_format=None): + input_format = input_format or self.input_format + if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: + # cross attention does not have past state + return { + "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "key": ("batch_size", self.num_heads, "sequence_length", self.head_size), + "value": ("batch_size", self.num_heads, "sequence_length", self.head_size), + "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), + } + + if self.use_kv_cache: + shapes = { + "past_key": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), + "past_value": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), + "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), + "present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), + } + else: + shapes = { + "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), + } + + if input_format == InputFormats.QKV_BSN3H: + shapes.update({"query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size)}) + elif input_format == InputFormats.Q_KV_BSNH_BSN2H: + shapes.update( + { + "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size), + } + ) + else: # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH + shapes.update( + { + "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "key": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "value": ("batch_size", "sequence_length", self.num_heads * self.head_size), + } + ) + return shapes + def random_inputs(self, seed: int = 123): device = self.device dtype = self.dtype @@ -215,7 +258,7 @@ def random_inputs(self, seed: int = 123): def get_input_output_names(self): if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - return ["query", "key"], ["output"] + return ["query", "key", "value"], ["output"] if self.input_format == InputFormats.QKV_BSN3H: inputs, outputs = ["query"], ["output"] @@ -235,7 +278,7 @@ def fill_optional_mha_inputs(input_names): return input_names[:-2] + [""] * (len(inputs) - len(input_names)) + input_names[-2:] -def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig): +def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use_symbolic_shape=False): input_names, output_names = config.get_input_output_names() float_type = TensorProto.FLOAT16 if config.dtype == torch.float16 else TensorProto.FLOAT @@ -252,7 +295,7 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig): ), ] - shape_dict = config.shape_dict() + shape_dict = config.symbolic_shape_dict() if use_symbolic_shape else config.shape_dict() inputs = [ helper.make_tensor_value_info(input_name, float_type, list(shape_dict[input_name])) for 
input_name in input_names diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index 5335e7115ad7..ff473cc2ced9 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -7,17 +7,39 @@ Test MultiHeadAttention operator for CUDA and CPU. """ +import concurrent.futures import itertools import unittest -from typing import Optional +from enum import IntEnum +from typing import Dict, List, Optional import numpy import torch -from benchmark_mha import InputFormats, MultiHeadAttentionConfig, OrtMultiHeadAttention +from benchmark_mha import ( + InputFormats, + MultiHeadAttentionConfig, + OrtMultiHeadAttention, + create_multi_head_attention_onnx_model, +) from einops import rearrange from parameterized import parameterized import onnxruntime +from onnxruntime import InferenceSession + + +class SdpaKernel(IntEnum): + """Bit flags for sdpa_kernel CUDA provider option""" + + DEFAULT = 0 + FLASH_ATTENTION = 1 + EFFICIENT_ATTENTION = 2 + TRT_FUSED_ATTENTION = 4 + CUDNN_FLASH_ATTENTION = 8 + MATH = 16 + TRT_FLASH_ATTENTION = 32 + TRT_CROSS_ATTENTION = 64 + TRT_CAUSAL_ATTENTION = 128 def attention_reference( @@ -105,9 +127,16 @@ def mha_with_past_reference( def get_provider_support_info(provider: str, use_kv_cache: bool): if provider == "CUDAExecutionProvider": - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H] if not use_kv_cache: - formats.append(InputFormats.Q_K_V_BSNH_BSNH_BSNH) + formats = [ + InputFormats.Q_K_V_BSNH_BSNH_BSNH, + InputFormats.Q_KV_BSNH_BSN2H, + InputFormats.QKV_BSN3H, + InputFormats.Q_K_V_BSNH_BNSH_BNSH, + ] + else: + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + device_id = torch.cuda.current_device() device = torch.device("cuda", device_id) dtype = torch.float16 @@ -121,15 +150,16 @@ def get_provider_support_info(provider: str, use_kv_cache: bool): return device, dtype, formats -def has_cuda_support(): +def get_compute_capability(): if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers(): - major, _ = torch.cuda.get_device_capability() - return major >= 6 - return False + major, minor = torch.cuda.get_device_capability() + sm = major * 10 + minor + return sm + return 0 def no_kv_cache_test_cases(provider: str, comprehensive: bool): - if provider == "CUDAExecutionProvider" and not has_cuda_support(): + if provider == "CUDAExecutionProvider" and get_compute_capability() < 60: return yield @@ -192,7 +222,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): def kv_cache_test_cases(provider: str, comprehensive: bool): - if provider == "CUDAExecutionProvider" and not has_cuda_support(): + if provider == "CUDAExecutionProvider" and get_compute_capability() < 60: return yield @@ -262,6 +292,92 @@ def mha_test_cases(provider: str, comprehensive: bool): ) +def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): + if provider == "CUDAExecutionProvider" and get_compute_capability() < 60: + return + yield + + batch_sizes = [1, 2] + sequence_lengths = [1, 16, 127, 128, 255, 256, 383, 384, 400] if comprehensive else [1, 64, 128, 256] + heads = [4] + head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] if comprehensive else [32, 64] + + device, dtype, formats = get_provider_support_info(provider, False) + + for format in formats: + for causal in [True, False]: + for num_heads in heads: + for head_size in head_sizes: + configs = [] # 
list of configurations to run in parallel + for batch_size in batch_sizes: + for sequence_length in sequence_lengths: + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=0, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=False, + share_past_present_buffer=False, + input_format=format, + ) + configs.append(config) + yield configs + + +def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): + if provider == "CUDAExecutionProvider" and get_compute_capability() < 60: + return + yield + + batch_sizes = [1, 2] + sequence_lengths = [1, 32, 127, 128, 383, 384, 400] if comprehensive else [1, 32, 127, 128] + heads = [4] + head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] if comprehensive else [32, 64] + + sequence_length = 1 + device, dtype, formats = get_provider_support_info(provider, True) + + for format in formats: + for causal in [True, False]: + for num_heads in heads: + for head_size in head_sizes: + configs = [] + for batch_size in batch_sizes: + for past_sequence_length in sequence_lengths: + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=past_sequence_length, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + share_past_present_buffer=False, + input_format=format, + ) + configs.append(config) + yield configs + + +def multi_thread_test_cases(provider: str, comprehensive: bool): + return itertools.chain( + no_kv_cache_multi_thread_test_cases(provider, comprehensive), + kv_cache_multi_thread_test_cases(provider, comprehensive), + ) + + def causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) @@ -346,20 +462,189 @@ def parity_check_mha( ) +def parity_check_mha_multi_threading( + test_inputs: List[Dict], + rtol: float = 1e-3, + atol: float = 1e-3, + sdpa_kernel: int = SdpaKernel.DEFAULT, + max_threads: int = 5, + verbose: bool = False, +): + # Use the first config to create a session, which is shared by all configs to run in parallel. + config = test_inputs[0]["config"] + # For now, MHA CUDA kernel does not support causal so skip such test cases. + if config.causal and config.provider == "CUDAExecutionProvider": + return None + # Some kernel does not support certain input format. 
+ if sdpa_kernel not in [ + SdpaKernel.DEFAULT, + SdpaKernel.FLASH_ATTENTION, + SdpaKernel.EFFICIENT_ATTENTION, + ] and config.input_format in [InputFormats.Q_KV_BSNH_BSN2H]: + return None + if verbose: + print(f"create a shared session with {vars(config)}") + onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=True) + if config.provider == "CUDAExecutionProvider": + provider_options = {"arena_extend_strategy": "kSameAsRequested", "sdpa_kernel": int(sdpa_kernel)} + providers = [(config.provider, provider_options), "CPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + ort_session = InferenceSession(onnx_model_str, providers=providers) + + def convert_to_ort_inputs(feed_dict): + ort_inputs = {} + + for k, v in feed_dict.items(): + if isinstance(v, numpy.ndarray): + ort_inputs[k] = v + else: + ort_inputs[k] = v.detach().cpu().numpy() + return ort_inputs + + def check_parity_with_config(i: int): + config = test_inputs[i]["config"] + if verbose: + print(f"Thread {i} with {vars(config)}") + + ort_inputs = test_inputs[i]["ort_inputs"] + + if verbose: + print(f"Thread {i} ort inputs: {ort_inputs}") + ort_outputs = ort_session.run(None, convert_to_ort_inputs(ort_inputs)) + out = numpy.reshape( + ort_outputs[0], (config.batch_size, config.sequence_length, config.num_heads, config.head_size) + ) + + # Create reference inputs + config.input_format = InputFormats.Q_K_V_BSNH_BSNH_BSNH + ref_inputs = test_inputs[i]["ref_inputs"] + if verbose: + print(f"Thread {i} ref inputs: {ref_inputs}") + q = ( + ref_inputs["query"] + .reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + .transpose(1, 2) + ) + k = ( + ref_inputs["key"] + .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) + .transpose(1, 2) + ) + v = ( + ref_inputs["value"] + .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) + .transpose(1, 2) + ) + + mask = None + if config.causal: + mask = causal_mask(config.sequence_length, config.total_sequence_length, device=config.device) + + k_cache = None + v_cache = None + if config.use_kv_cache: + past_k = ref_inputs["past_key"] + past_v = ref_inputs["past_value"] + out_ref, k_cache, v_cache = mha_with_past_reference(config, past_k, past_v, q, k, v, mask=mask) + else: + out_ref = attention_reference(config.head_size, q, k, v, mask=mask) + + try: + numpy.testing.assert_allclose( + out, + out_ref.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=f"output not close: {config=}", + ) + + if config.use_kv_cache: + present_key = ort_outputs[1] + numpy.testing.assert_allclose( + k_cache.detach().cpu().numpy(), + present_key, + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=f"present_key not close: {config=}", + ) + + present_value = ort_outputs[2] + numpy.testing.assert_allclose( + v_cache.detach().cpu().numpy(), + present_value, + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=f"present_value not close: {config=}", + ) + except AssertionError as e: + print(f"Failed with {vars(config)}: {e}") + return e + + if verbose: + print(f"Passed: {vars(config)}") + return None + + num_threads = min(max_threads, len(test_inputs)) + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + future_tasks = [executor.submit(check_parity_with_config, i) for i in range(num_threads)] + for future in concurrent.futures.as_completed(future_tasks): + result = future.result() + if result is not None: + 
return result + + return None + + # Do not run too many tests in CI pipeline. Change it to True to run all combinations in dev machine. comprehensive_mode = False class TestMultiHeadAttention(unittest.TestCase): - # TODO: enable tests on CUDAExecutionProvider after fixing the issue. - # @parameterized.expand(mha_test_cases("CUDAExecutionProvider", comprehensive_mode), skip_on_empty=True) - # def test_mha_cuda(self, config): - # parity_check_mha(config) + @parameterized.expand(mha_test_cases("CUDAExecutionProvider", comprehensive_mode), skip_on_empty=True) + def test_mha_cuda(self, config): + parity_check_mha(config) @parameterized.expand(mha_test_cases("CPUExecutionProvider", comprehensive_mode), skip_on_empty=True) def test_mha_cpu(self, config): parity_check_mha(config) + def run_mha_cuda_multi_threading(self, spda_kernel): + for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode): + test_inputs = [] + for config in configs: + ort_inputs = config.random_inputs() + + # Create reference inputs + old_format = config.input_format + config.input_format = InputFormats.Q_K_V_BSNH_BSNH_BSNH + ref_inputs = config.random_inputs() + config.input_format = old_format + test_inputs.append({"config": config, "ort_inputs": ort_inputs, "ref_inputs": ref_inputs}) + + exception = parity_check_mha_multi_threading(test_inputs, sdpa_kernel=spda_kernel, max_threads=len(configs)) + assert exception is None, f"{spda_kernel=}, {vars(configs[0])}, {exception}" + + def test_mha_cuda_multi_threading(self): + self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT) + + def test_mha_cuda_multi_threading_efficient(self): + self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION) + + def test_mha_cuda_multi_threading_trt(self): + sm = get_compute_capability() + if sm in [75, 80, 86, 89]: + self.run_mha_cuda_multi_threading( + SdpaKernel.TRT_FUSED_ATTENTION + | SdpaKernel.TRT_FLASH_ATTENTION + | SdpaKernel.TRT_CROSS_ATTENTION + | SdpaKernel.TRT_CAUSAL_ATTENTION + ) + if __name__ == "__main__": with torch.no_grad():
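Note (not part of the patch): the refactor above removes the mutable setup()/run() state from MHARunner, so a single InferenceSession can safely serve concurrent Run calls; that is what the new multi-threading tests exercise. The following is a minimal Python sketch of the same usage pattern, assuming a hypothetical model_path and feeds_list; the SdpaKernel bit flags and the "sdpa_kernel" / "arena_extend_strategy" provider options mirror what test_mha.py uses.

import concurrent.futures
from enum import IntEnum

from onnxruntime import InferenceSession


class SdpaKernel(IntEnum):
    # Bit flags for the CUDA provider option "sdpa_kernel" (same values as in test_mha.py).
    DEFAULT = 0
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2
    TRT_FUSED_ATTENTION = 4
    CUDNN_FLASH_ATTENTION = 8
    MATH = 16
    TRT_FLASH_ATTENTION = 32
    TRT_CROSS_ATTENTION = 64
    TRT_CAUSAL_ATTENTION = 128


def run_shared_session(model_path, feeds_list):
    # Restrict the CUDA EP to the TRT fused/flash attention kernels by combining bit flags.
    sdpa_kernel = int(SdpaKernel.TRT_FUSED_ATTENTION | SdpaKernel.TRT_FLASH_ATTENTION)
    providers = [
        ("CUDAExecutionProvider", {"arena_extend_strategy": "kSameAsRequested", "sdpa_kernel": sdpa_kernel}),
        "CPUExecutionProvider",
    ]
    session = InferenceSession(model_path, providers=providers)

    # One session shared by multiple threads; each feed dict may use a different
    # batch size and sequence length now that the runner keeps no per-call state.
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(feeds_list)) as executor:
        futures = [executor.submit(session.run, None, feeds) for feeds in feeds_list]
        return [future.result() for future in futures]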