From 6ffaaebb60cd43cf7749e67a9bb54c3bd2cc4efd Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jul 2024 13:58:54 -0700 Subject: [PATCH] [CUDA] Attention kernel provider option (#21344) ### Description * Add a cuda provider option `sdpa_kernel` to choose which attention kernel to run for testing purpose. * Allow dump which attention kernel is used per node. * Reserve a flag for cudnn flash attention which will be added soon. #### CUDA provider option sdpa_kernel Instead of setting environment variable, we also support setting it in provider option. Note that the setting is global per session. That could help performance testing of each kernel. #### Attention Kernel Debug Info Set an environment variable `ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1`, and ORT will print sdpa kernel used in each node: For example ``` ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1 ./onnxruntime_test_all --gtest_filter=MultiHeadAttentionTest* ``` It will show debug information of kernel used in testing: ``` [ RUN ] MultiHeadAttentionTest.SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=0 TRT_FUSED_ATTENTION=1 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=1 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1 Operator=MultiHeadAttention Node=node1 DataType=fp16 TRT_FUSED_ATTENTION=1 AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=1 TRT_FUSED_ATTENTION=0 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=0 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1 Operator=MultiHeadAttention Node=node1 DataType=fp16 EFFICIENT_ATTENTION=1 ``` In this test case, the debug info shows that one session uses trt fused attention and another session use efficient attention. --- cmake/onnxruntime_rocm_hipify.cmake | 2 + cmake/onnxruntime_unittests.cmake | 3 +- .../providers/cuda/cuda_provider_options.h | 1 + .../contrib_ops/cpu/bert/attention_common.h | 28 ++- .../contrib_ops/cuda/bert/attention.cc | 53 ++--- onnxruntime/contrib_ops/cuda/bert/attention.h | 4 +- .../cuda/bert/attention_kernel_options.cc | 166 +++++++++++++ .../cuda/bert/attention_kernel_options.h | 67 ++++++ .../cuda/bert/group_query_attention.cc | 30 +-- .../cuda/bert/group_query_attention.h | 2 + .../cuda/bert/multihead_attention.cc | 50 ++-- .../cuda/bert/multihead_attention.h | 3 +- .../contrib_ops/cuda/bert/packed_attention.cc | 33 ++- .../contrib_ops/cuda/bert/packed_attention.h | 9 +- .../cuda/bert/packed_multihead_attention.cc | 40 ++-- .../cuda/bert/packed_multihead_attention.h | 4 +- .../providers/cuda/cuda_execution_provider.h | 17 ++ .../cuda/cuda_execution_provider_info.cc | 4 + .../cuda/cuda_execution_provider_info.h | 4 + onnxruntime/core/providers/cuda/cuda_kernel.h | 6 + .../providers/cuda/cuda_provider_factory.cc | 2 + .../multihead_attention_op_test.cc | 4 +- .../attention_kernel_options_test.cc | 221 ++++++++++++++++++ .../test/python/onnxruntime_test_python.py | 2 + 24 files changed, 645 insertions(+), 110 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc create mode 100644 onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h create mode 100644 onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 2966a4624a96..a8c876d30873 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -15,6 +15,8 @@ set(contrib_ops_excluded_files "bert/attention_softmax.h" "bert/attention_softmax.cu" "bert/attention_prepare_qkv.cu" + "bert/attention_kernel_options.h" + "bert/attention_kernel_options.cc" "bert/decoder_attention_impl.h" "bert/decoder_attention_impl.cu" "bert/decoder_masked_multihead_attention.h" diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 0159c35d1941..38ed0b164019 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -786,8 +786,9 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock) + add_dependencies(onnxruntime_providers_cuda_ut onnxruntime_test_utils onnxruntime_common) target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey) - target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) + target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_test_utils onnxruntime_common) if (MSVC) # Cutlass code has an issue with the following: # warning C4100: 'magic': unreferenced formal parameter diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index 6d53760ab60b..01a14de699dc 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -38,4 +38,5 @@ struct OrtCUDAProviderOptionsV2 { int prefer_nhwc = 0; // make the CUDA EP NHWC preferred int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not int use_tf32 = 1; // use TF32 + int sdpa_kernel = 0; // Scaled Dot Product Attention kernel option }; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index a5b9c84c63eb..55292b35e1e3 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -147,6 +147,23 @@ constexpr const char* kDisableSparseAttentionV1 = "ORT_DISABLE_SPARSE_ATTENTION_ } // namespace sparse_attention namespace attention { + +enum class AttentionBackend : int { + FLASH_ATTENTION = 1, + EFFICIENT_ATTENTION = 2, + TRT_FUSED_ATTENTION = 4, + CUDNN_FLASH_ATTENTION = 8, // reserved for cuDNN flash attention. + MATH = 16, // unfused kernel cannot be disabled right now. + + // The following kernels might be deprecated in the future. + TRT_FLASH_ATTENTION = 32, + TRT_CROSS_ATTENTION = 64, + TRT_CAUSAL_ATTENTION = 128, +}; + +// Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled). +constexpr const char* kEnableAttentionKernelDebugInfo = "ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"; + // Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled). constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION"; @@ -157,6 +174,9 @@ constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATT // Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels. constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION"; +// Environment variable to enable or disable cuDNN flash attention. +constexpr const char* kEnableCudnnFlashAttention = "ORT_ENABLE_CUDNN_FLASH_ATTENTION"; + // Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled). constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION"; @@ -166,11 +186,15 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF // Environment variable to enable or disable flash attention. Default is 0 (enabled). constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION"; -// Minimum sequence length to enable memory efficient attention in FP32. -constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256; +// Minimum sequence length to perfer memory efficient attention when data type is float32 +constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32"; + +// Default value for minimum sequence length to enable memory efficient attention in FP32. +constexpr int kDefaultMinSeqLenForEfficientAttentionFp32 = 256; // Minimum sequence length to prefer flash attention when input format is packed QKV for MultiHeadAttention constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV"; + // Default value for the above setting. constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index d9907f09121d..cacd65313ebc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -3,7 +3,6 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/attention.h" #include "contrib_ops/cuda/bert/bert_padding.h" @@ -40,36 +39,17 @@ REGISTER_KERNEL_TYPED(MLFloat16) template Attention::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) { - disable_fused_self_attention_ = - sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); + kernel_options_ = this->GetAttentionKernelOptions(); - enable_trt_flash_attention_ = - sizeof(T) == 2 && - !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); + disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention(); - enable_fused_causal_attention_ = - sizeof(T) == 2 && - ParseEnvironmentVariableWithDefault(attention::kEnableFusedCausalAttention, false); + enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention(); -#if USE_MEMORY_EFFICIENT_ATTENTION - disable_memory_efficient_attention_ = - ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); -#else - disable_memory_efficient_attention_ = true; -#endif + enable_fused_causal_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtCausalAttention(); -#if USE_FLASH_ATTENTION - disable_flash_attention_ = - sizeof(T) != 2 || - onnxruntime::ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); - min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( - attention::kMinSeqLenForFlashAttentionPackedQKV, - attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); -#else - disable_flash_attention_ = true; - min_seq_len_for_flash_attention_packed_qkv_ = 0; -#endif + disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); + + disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); } template @@ -134,7 +114,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.num_heads, parameters.num_heads); // When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512. - if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + if (use_flash_attention && parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } // Allocate buffers @@ -220,7 +200,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == past && nullptr == present && (nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && - (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); if (use_memory_efficient_attention) { @@ -231,6 +211,20 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { constexpr bool use_memory_efficient_attention = false; #endif + if (kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_runner != nullptr) { + debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); + } + + debug_info.Print("Attention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + cublasHandle_t cublas = GetCublasHandle(context); typedef typename ToCudaType::MappedType CudaT; @@ -268,7 +262,6 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_fused_cross_attention, use_memory_efficient_attention); IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); - ; typedef typename ToCudaType::MappedType CudaT; AttentionData data; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h index acafb379d713..0c7d3621f95e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention.h @@ -8,6 +8,7 @@ #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { @@ -27,9 +28,10 @@ class Attention final : public CudaKernel, public AttentionBase { bool enable_trt_flash_attention_; bool enable_fused_causal_attention_; bool disable_memory_efficient_attention_; - int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; mutable std::once_flag fused_fp16_runner_created_; + + const AttentionKernelOptions* kernel_options_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc new file mode 100644 index 000000000000..28a095e68131 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc @@ -0,0 +1,166 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/bert/attention_kernel_options.h" +#include +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/providers/shared_library/provider_api.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" + +using namespace onnxruntime::contrib::attention; + +namespace onnxruntime { +void AttentionKernelOptions::Initialize(int value, bool use_build_flag) { + if (value > 0) { + use_flash_attention_ = (value & static_cast(AttentionBackend::FLASH_ATTENTION)) > 0; + use_efficient_attention_ = (value & static_cast(AttentionBackend::EFFICIENT_ATTENTION)) > 0; + use_trt_fused_attention_ = (value & static_cast(AttentionBackend::TRT_FUSED_ATTENTION)) > 0; + use_cudnn_flash_attention_ = (value & static_cast(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0; + use_unfused_ = (value & static_cast(AttentionBackend::MATH)) > 0; + use_trt_flash_attention_ = (value & static_cast(AttentionBackend::TRT_FLASH_ATTENTION)) > 0; + use_trt_cross_attention_ = (value & static_cast(AttentionBackend::TRT_CROSS_ATTENTION)) > 0; + use_trt_causal_attention_ = (value & static_cast(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0; + } else { + use_flash_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFlashAttention, false); + use_efficient_attention_ = !ParseEnvironmentVariableWithDefault(kDisableMemoryEfficientAttention, false); + use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFusedSelfAttention, false); + use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault(kEnableCudnnFlashAttention, false); + use_unfused_ = true; + use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault(kDisableTrtFlashAttention, false); + use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFusedCrossAttention, false); + use_trt_causal_attention_ = ParseEnvironmentVariableWithDefault(kEnableFusedCausalAttention, false); + } + + enable_kernel_debug_info_ = ParseEnvironmentVariableWithDefault(kEnableAttentionKernelDebugInfo, false); + + // When value is positive, we use 0 as default minimum sequence lengths to align with common usage in testing. + min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( + kMinSeqLenForFlashAttentionPackedQKV, + value > 0 ? 0 : kDefaultMinSeqLenForFlashAttentionPackedQKV); + + min_seq_len_for_efficient_attention_fp32_ = ParseEnvironmentVariableWithDefault( + kMinSeqLenForEfficientAttentionFp32, + value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32); + + if (use_build_flag) { + // Some kernels can be disabled at build time. If they are disabled, we should not use them. +#ifndef USE_FLASH_ATTENTION + use_flash_attention_ = false; +#endif + +#ifndef USE_MEMORY_EFFICIENT_ATTENTION + use_efficient_attention_ = false; +#endif + } +} + +void AttentionKernelOptions::InitializeOnce( + int sdpa_kernel, bool use_build_flag) { + std::call_once(this->initialize_once_flag_, [&]() { + this->Initialize(sdpa_kernel, use_build_flag); + if (this->enable_kernel_debug_info_) { + this->Print(); + } + }); +} + +void AttentionKernelOptions::Print() const { + std::stringstream sstream; + sstream << "AttentionKernelOptions:"; + sstream << " FLASH_ATTENTION=" << int(use_flash_attention_); + sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_); + sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_); + sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_); + sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention_); + sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention_); + sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention_); + sstream << " MATH=" << int(use_unfused_); + + if (!use_unfused_) { + sstream << std::endl + << "Warning: Unfused kernel cannot be disabled right now. MATH=0 is ignored."; + } + + // Output text in Cyan color to make it easier to spot + std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl; +} + +// Classify the kernel used in TRT fused runner. +void AttentionKernelDebugInfo::SetTrtFusedKernel(bool causal, bool enable_trt_flash_attention, int sequence_length) { + if (causal) { + use_trt_causal_attention = true; + } else if (enable_trt_flash_attention && sequence_length >= contrib::cuda::kMinSequenceLengthFlashAttention) { + use_trt_flash_attention = true; + } else { + use_trt_fused_attention = true; + } +} + +void AttentionKernelDebugInfo::Print(const char* operator_name, + const std::string& node_name, + bool is_float16, + bool is_bfloat16) const { + std::stringstream sstream; + sstream << "Operator=" << operator_name; + + if (node_name.length() > 0) { + sstream << " Node=" << node_name; + } + + if (is_bfloat16) { + sstream << " DataType=bf16"; + } else if (is_float16) { + sstream << " DataType=fp16"; + } else { + sstream << " DataType=fp32"; + } + + if (use_flash_attention.has_value() && use_flash_attention.value()) { + sstream << " FLASH_ATTENTION=" << int(use_flash_attention.value()); + } + + if (use_efficient_attention.has_value() && use_efficient_attention.value()) { + sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention.value()); + } + + if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) { + sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention.value()); + } + + if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) { + sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention.value()); + } + + if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) { + sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention.value()); + } + + if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) { + sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention.value()); + } + + if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) { + sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention.value()); + } + + bool use_fused = (use_flash_attention.has_value() && use_flash_attention.value()) || + (use_efficient_attention.has_value() && use_efficient_attention.value()) || + (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) || + (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) || + (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) || + (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) || + (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()); + + // Fall back to unfused when no fused kernel is enabled. + if (!use_fused) { + sstream << " MATH=1"; + } + + // Output text in Cyan color to make it easier to spot. + std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl; +} + +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h new file mode 100644 index 000000000000..bd7df5f490c7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include + +namespace onnxruntime { +struct AttentionKernelDebugInfo { + std::optional use_flash_attention = std::nullopt; + std::optional use_efficient_attention = std::nullopt; + std::optional use_trt_fused_attention = std::nullopt; + std::optional use_cudnn_flash_attention = std::nullopt; + std::optional use_trt_flash_attention = std::nullopt; + std::optional use_trt_cross_attention = std::nullopt; + std::optional use_trt_causal_attention = std::nullopt; + void SetTrtFusedKernel(bool causal, bool enable_trt_flash_attention, int sequence_length); + void Print(const char* operator_name, const std::string& node_name, bool is_float16, bool is_bfloat16) const; +}; + +class AttentionKernelOptions { + public: + void InitializeOnce(int sdpa_kernel, bool use_build_flag); + + bool UseFlashAttention() const { return use_flash_attention_; } + bool UseEfficientAttention() const { return use_efficient_attention_; } + bool UseTrtFusedAttention() const { return use_trt_fused_attention_; } + bool UseCudnnFlashAttention() const { return use_cudnn_flash_attention_; } + bool UseUnfusedAttention() const { return use_unfused_; } + bool UseTrtFlashAttention() const { return use_trt_flash_attention_; } + bool UseTrtCrossAttention() const { return use_trt_cross_attention_; } + bool UseTrtCausalAttention() const { return use_trt_causal_attention_; } + + bool AllowDebugInfo() const { return enable_kernel_debug_info_; } + + int MinSeqLenForFlashAttentionPackedQkv() const { return min_seq_len_for_flash_attention_packed_qkv_; } + int MinSeqLenForEfficientAttentionFp32() const { return min_seq_len_for_efficient_attention_fp32_; } + + protected: + void Print() const; + + void Initialize(int value, bool use_build_flag); + + private: + bool use_flash_attention_{true}; + bool use_efficient_attention_{true}; + bool use_trt_fused_attention_{true}; + bool use_cudnn_flash_attention_{false}; + bool use_unfused_{true}; + + bool use_trt_flash_attention_{true}; + bool use_trt_cross_attention_{true}; + + // Causal attention is disabled by default in #14732. + bool use_trt_causal_attention_{false}; + + bool enable_kernel_debug_info_{false}; + + int min_seq_len_for_flash_attention_packed_qkv_{0}; + + int min_seq_len_for_efficient_attention_fp32_{0}; + + std::once_flag initialize_once_flag_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 3b6ad238cc82..797f9b0a1ea4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -52,20 +52,13 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; scale_ = info.GetAttrOrDefault("scale", 0.0f); -#if USE_FLASH_ATTENTION - disable_flash_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); -#else - disable_flash_attention_ = true; -#endif + kernel_options_ = this->GetAttentionKernelOptions(); + + disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); -#if USE_MEMORY_EFFICIENT_ATTENTION // Memory efficient attention only supports float and float16, not bfloat16. - disable_memory_efficient_attention_ = std::is_same::value || - ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); -#else - disable_memory_efficient_attention_ = true; -#endif + disable_memory_efficient_attention_ = std::is_same::value || !kernel_options_->UseEfficientAttention(); + if (!disable_flash_attention_) { zeros_ = this->GetScratchBuffer(kZerosCount, nullptr); } @@ -161,7 +154,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_memory_efficient_attention_ && local_window_size_ == -1 && - (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.head_size); if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -201,6 +194,17 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { auto unpacked_qkv_buffer = GetScratchBuffer(0, context->GetComputeStream()); #endif + if (kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_efficient_attention = use_memory_efficient_attention; + + debug_info.Print("GroupQueryAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + // seqlens_k buffer size_t seqlens_k_bytes = 0; seqlens_k_bytes = sizeof(int) * parameters.batch_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 15573ece166f..4ff5b0a59f02 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -6,6 +6,7 @@ #include #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { @@ -32,6 +33,7 @@ class GroupQueryAttention final : public CudaKernel { bool disable_memory_efficient_attention_; static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) IAllocatorUniquePtr zeros_; + const AttentionKernelOptions* kernel_options_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index ba8b00df07e0..b96140f3897f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include "core/providers/cuda/cuda_common.h" -#include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/multihead_attention.h" #include "contrib_ops/cpu/bert/multihead_attention_helper.h" @@ -47,31 +46,16 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead."); - disable_fused_self_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); + kernel_options_ = this->GetAttentionKernelOptions(); - enable_trt_flash_attention_ = sizeof(T) == 2 && - !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); + disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention(); + enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention(); -#if USE_FLASH_ATTENTION - disable_flash_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); - min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( - attention::kMinSeqLenForFlashAttentionPackedQKV, - attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); -#else - disable_flash_attention_ = true; - min_seq_len_for_flash_attention_packed_qkv_ = 0; -#endif + disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); -#if USE_MEMORY_EFFICIENT_ATTENTION - disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); -#else - disable_memory_efficient_attention_ = true; -#endif + disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); - disable_fused_cross_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedCrossAttention, false); + disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention(); // Allocate cache buffers constexpr size_t cache_bytes = sizeof(int32_t) * (static_cast(kCumulatedSequenceLengthCacheMaxBatchSize) + 1); @@ -155,7 +139,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.num_heads); // When input is packed QKV format, TensorRT kernel might be faster than flash attention when sequence length <= 512. if (use_flash_attention && key == nullptr && value == nullptr && - parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } // Allocate buffers @@ -229,9 +213,10 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } #if USE_MEMORY_EFFICIENT_ATTENTION + int length_threshold = this->kernel_options_->MinSeqLenForEfficientAttentionFp32(); bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16 - parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32 || - parameters.kv_sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32; + parameters.sequence_length >= length_threshold || + parameters.kv_sequence_length >= length_threshold; bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; @@ -249,6 +234,21 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { constexpr bool use_memory_efficient_attention = false; #endif + if (kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_fp16_runner_ != nullptr) { + debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); + } + + debug_info.Print("MultiHeadAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + // When packed kv or packed qkv is used, there is no needed for add bias transpose thus no qkv workspace. // TODO(tianleiwu): flash attention or memory efficient attention might not need qkv workspace sometime. bool no_qkv_workspace = nullptr == value && diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 86a32c92ce00..26e38dbad9fd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -8,6 +8,7 @@ #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h" #include "contrib_ops/cuda/bert/attention_impl.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { @@ -31,12 +32,12 @@ class MultiHeadAttention final : public CudaKernel { bool disable_fused_cross_attention_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; - int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; mutable std::once_flag fused_fp16_runner_created_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_; + const AttentionKernelOptions* kernel_options_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index 0146cce30c7d..a1149ddbf99f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -33,12 +33,11 @@ REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) template -TrtFusedAttention::TrtFusedAttention() { - disable_fused_runner_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); - - enable_trt_flash_attention_ = sizeof(T) == 2 && - !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); +TrtFusedAttention::TrtFusedAttention(const OpKernelInfo& info) + : CudaKernel(info) { + kernel_options_ = this->GetAttentionKernelOptions(); + disable_fused_runner_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention(); + enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention(); } template @@ -86,7 +85,8 @@ template class TrtFusedAttention; template class TrtFusedAttention; template -PackedAttention::PackedAttention(const OpKernelInfo& info) : TrtFusedAttention(), CudaKernel(info) { +PackedAttention::PackedAttention(const OpKernelInfo& info) + : TrtFusedAttention(info) { int64_t num_heads = 0; ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); num_heads_ = static_cast(num_heads); @@ -268,7 +268,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* relative_position_bias = context->Input(5); PackedAttentionParameters parameters; - parameters.use_tf32 = UseTF32(); + parameters.use_tf32 = this->UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), @@ -295,6 +295,19 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { } #endif + if (this->kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_runner != nullptr) { + debug_info.SetTrtFusedKernel(false /*causal*/, this->enable_trt_flash_attention_, parameters.sequence_length); + } + + debug_info.Print("PackedAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + typedef typename ToCudaType::MappedType CudaT; CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); @@ -313,7 +326,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, this->UseTF32())); constexpr size_t element_size = sizeof(T); constexpr bool no_qkv_workspace = false; // need workspace to add bias @@ -341,7 +354,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.use_memory_efficient_attention = use_memory_efficient_attention; - return QkvToContext(device_prop, cublas, Stream(context), parameters, data); + return QkvToContext(device_prop, cublas, this->Stream(context), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h index f00c112fc73d..67b420764169 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h @@ -9,6 +9,7 @@ #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { @@ -17,14 +18,16 @@ namespace cuda { using namespace onnxruntime::cuda; template -class TrtFusedAttention { +class TrtFusedAttention : public CudaKernel { public: - TrtFusedAttention(); + TrtFusedAttention(const OpKernelInfo& info); protected: MHARunner* GetFusedRunner(const cudaDeviceProp& device_prop, const PackedAttentionParameters& parameters) const; protected: + const AttentionKernelOptions* kernel_options_; + bool disable_fused_runner_; bool enable_trt_flash_attention_; mutable std::unique_ptr fused_fp16_runner_; @@ -32,7 +35,7 @@ class TrtFusedAttention { }; template -class PackedAttention final : public TrtFusedAttention, public CudaKernel { +class PackedAttention final : public TrtFusedAttention { public: PackedAttention(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 3fbbafc01254..53e96fc732a3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -35,30 +35,16 @@ REGISTER_KERNEL_TYPED(MLFloat16) template PackedMultiHeadAttention::PackedMultiHeadAttention(const OpKernelInfo& info) - : TrtFusedAttention(), CudaKernel(info) { + : TrtFusedAttention(info) { int64_t num_heads = 0; ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); num_heads_ = static_cast(num_heads); scale_ = info.GetAttrOrDefault("scale", 0.0f); -#if USE_FLASH_ATTENTION - disable_flash_attention_ = sizeof(T) != 2 || onnxruntime::ParseEnvironmentVariableWithDefault( - attention::kDisableFlashAttention, false); - min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( - attention::kMinSeqLenForFlashAttentionPackedQKV, - attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); -#else - disable_flash_attention_ = true; - min_seq_len_for_flash_attention_packed_qkv_ = 0; -#endif + disable_flash_attention_ = sizeof(T) != 2 || !this->kernel_options_->UseFlashAttention(); -#if USE_MEMORY_EFFICIENT_ATTENTION - disable_memory_efficient_attention_ = onnxruntime::ParseEnvironmentVariableWithDefault( - attention::kDisableMemoryEfficientAttention, false); -#else - disable_memory_efficient_attention_ = true; -#endif + disable_memory_efficient_attention_ = !this->kernel_options_->UseEfficientAttention(); } template @@ -228,7 +214,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co const Tensor* relative_position_bias = context->Input(6); PackedAttentionParameters parameters; - parameters.use_tf32 = UseTF32(); + parameters.use_tf32 = this->UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(), key, value, @@ -255,7 +241,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co // When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512. if (use_flash_attention && key == nullptr && value == nullptr && - parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + parameters.sequence_length < this->kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } } @@ -271,11 +257,25 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0; use_memory_efficient_attention = is_good_for_rpb && - (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); } #endif + if (this->kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_runner != nullptr) { + debug_info.SetTrtFusedKernel(false /*causal*/, this->enable_trt_flash_attention_, parameters.sequence_length); + } + + debug_info.Print("PackedMultiHeadAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + typedef typename ToCudaType::MappedType CudaT; cublasHandle_t cublas = this->GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h index e30c603dc30a..9b52a70fc618 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h @@ -4,13 +4,14 @@ #pragma once #include "contrib_ops/cuda/bert/packed_attention.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { namespace cuda { template -class PackedMultiHeadAttention final : public TrtFusedAttention, public CudaKernel { +class PackedMultiHeadAttention final : public TrtFusedAttention { public: PackedMultiHeadAttention(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; @@ -32,7 +33,6 @@ class PackedMultiHeadAttention final : public TrtFusedAttention, public CudaK bool disable_memory_efficient_attention_; bool disable_flash_attention_; - int min_seq_len_for_flash_attention_packed_qkv_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index f53779058a8a..9c8a8712ca51 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -17,6 +17,10 @@ #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/tunable/cuda_tuning_context.h" +#ifndef DISABLE_CONTRIB_OPS +#include "contrib_ops/cuda/bert/attention_kernel_options.h" +#endif + namespace onnxruntime { void RunOnUnload(std::function function); @@ -80,6 +84,14 @@ class CUDAExecutionProvider : public IExecutionProvider { bool IsNHWCPreferred() const { return info_.prefer_nhwc; } bool UseTF32() const { return info_.use_tf32; } +#ifndef DISABLE_CONTRIB_OPS + // Attention kernel options parsed from sdpa_kernel cuda provider option. + const AttentionKernelOptions* GetAttentionKernelOptions() const { + attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true); + return &attention_kernel_options_; + } +#endif + ProviderOptions GetProviderOptions() const override { return CUDAExecutionProviderInfo::ToProviderOptions(info_); } @@ -110,6 +122,11 @@ class CUDAExecutionProvider : public IExecutionProvider { // the tuning context might be altered when calling into a TunableOp mutable cuda::tunable::CudaTuningContext tuning_context_; +#ifndef DISABLE_CONTRIB_OPS + // Attention kernel options parsed from sdpa_kernel cuda provider option. + mutable AttentionKernelOptions attention_kernel_options_; +#endif + class PerThreadContext final { public: PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy, diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index c96381e3e68b..31cf991a34fc 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -34,6 +34,7 @@ constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_s constexpr const char* kPreferNHWCMode = "prefer_nhwc"; constexpr const char* kUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; constexpr const char* kUseTF32 = "use_tf32"; +constexpr const char* kSdpaKernel = "sdpa_kernel"; } // namespace provider_option_names } // namespace cuda @@ -117,6 +118,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kPreferNHWCMode, info.prefer_nhwc) .AddAssignmentToReference(cuda::provider_option_names::kUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) .AddAssignmentToReference(cuda::provider_option_names::kUseTF32, info.use_tf32) + .AddAssignmentToReference(cuda::provider_option_names::kSdpaKernel, info.sdpa_kernel) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -170,6 +172,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, + {cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)}, }; return options; @@ -192,6 +195,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, + {cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index 1cac3d151369..0efad80f743d 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -79,6 +79,8 @@ struct CUDAExecutionProviderInfo { // By default, enable TF32 to speed up float GEMM/MatMul or cuDNN convolution of float matrices. bool use_tf32{true}; + int sdpa_kernel{0}; + static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info); static ProviderOptions ToProviderOptions(const OrtCUDAProviderOptionsV2& info); @@ -91,6 +93,7 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> { size_t value{0xbc9f1d34}; // seed // Bits: device_id (16), arena_extend_strategy/cudnn_conv_algo_search (reserved 2), boolean options (1 each) + // Do not exceed 32 bits here otherwise some bits will be lost in x86. size_t data = static_cast(info.device_id) ^ (static_cast(info.arena_extend_strategy) << 16) ^ (static_cast(info.cudnn_conv_algo_search) << 18) ^ @@ -109,6 +112,7 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> { onnxruntime::HashCombine(info.gpu_mem_limit, value); onnxruntime::HashCombine(info.tunable_op.max_tuning_duration_ms, value); + onnxruntime::HashCombine(info.sdpa_kernel, value); // Memory pointers onnxruntime::HashCombine(reinterpret_cast(info.user_compute_stream), value); diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 288da23f35ec..9d37a9775872 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -94,6 +94,12 @@ class CudaKernel : public OpKernel { return provider_->UseTF32(); } +#ifndef DISABLE_CONTRIB_OPS + const AttentionKernelOptions* GetAttentionKernelOptions() const { + return provider_->GetAttentionKernelOptions(); + } +#endif + tunable::CudaTuningContext* GetTuningContext() const { return static_cast(provider_->GetTuningContext()); } diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 7851da7fa91a..b1d54e56ded4 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -226,6 +226,7 @@ struct CUDA_Provider : Provider { info.enable_skip_layer_norm_strict_mode = params->enable_skip_layer_norm_strict_mode != 0; info.use_ep_level_unified_stream = params->use_ep_level_unified_stream != 0; info.use_tf32 = params->use_tf32 != 0; + info.sdpa_kernel = params->sdpa_kernel; return std::make_shared(info); } @@ -260,6 +261,7 @@ struct CUDA_Provider : Provider { cuda_options.prefer_nhwc = internal_options.prefer_nhwc; cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream; cuda_options.use_tf32 = internal_options.use_tf32; + cuda_options.sdpa_kernel = internal_options.sdpa_kernel; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index a61e917b41e5..f0255d7ece84 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -394,8 +394,8 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu } #if USE_MEMORY_EFFICIENT_ATTENTION - if (data.sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32 || - data.kv_sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32) { + if (data.sequence_length >= contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32 || + data.kv_sequence_length >= contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32) { kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( diff --git a/onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc b/onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc new file mode 100644 index 000000000000..b2e986f68076 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef DISABLE_CONTRIB_OPS + +#include "contrib_ops/cuda/bert/attention_kernel_options.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "test/util/include/scoped_env_vars.h" +#include "gtest/gtest.h" + +#include +#include + +using onnxruntime::AttentionKernelOptions; +using onnxruntime::contrib::attention::AttentionBackend; + +namespace onnxruntime { +namespace test { + +TEST(AttentionKernelOptionsTest, NonZeroValue) { + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::FLASH_ATTENTION) | static_cast(AttentionBackend::EFFICIENT_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_TRUE(options.UseFlashAttention()); + ASSERT_TRUE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::TRT_FUSED_ATTENTION) | static_cast(AttentionBackend::MATH); + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_TRUE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_TRUE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::CUDNN_FLASH_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_TRUE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::TRT_FLASH_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_TRUE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::TRT_CROSS_ATTENTION) | static_cast(AttentionBackend::TRT_CAUSAL_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_TRUE(options.UseTrtCrossAttention()); + ASSERT_TRUE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + // Test environment variables are ignored when option value is non-zero + // Test default min sequence lengths are zeros + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}}}; + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::FLASH_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_TRUE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + // Test min sequence lengths can be parsed from environment variables when option value is non-zero + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}, + {onnxruntime::contrib::attention::kMinSeqLenForFlashAttentionPackedQKV, "128"}, + {onnxruntime::contrib::attention::kMinSeqLenForEfficientAttentionFp32, "256"}}}; + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::FLASH_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_TRUE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 128); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 256); + } +} + +// Test all environment variables take effect when option value is 0. +TEST(AttentionKernelOptionsTest, DefaultOptionWithEnvVar) { + constexpr int value = 0; + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}, + {onnxruntime::contrib::attention::kMinSeqLenForFlashAttentionPackedQKV, "128"}, + {onnxruntime::contrib::attention::kMinSeqLenForEfficientAttentionFp32, "256"}}}; + AttentionKernelOptions options; + options.InitializeOnce(value, false); + ASSERT_TRUE(options.UseFlashAttention()); + ASSERT_TRUE(options.UseEfficientAttention()); + ASSERT_TRUE(options.UseTrtFusedAttention()); + ASSERT_TRUE(options.UseCudnnFlashAttention()); + ASSERT_TRUE(options.UseUnfusedAttention()); + ASSERT_TRUE(options.UseTrtFlashAttention()); + ASSERT_TRUE(options.UseTrtCrossAttention()); + ASSERT_TRUE(options.UseTrtCausalAttention()); + ASSERT_TRUE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 128); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 256); +} + +// Test default min sequence lengths when environment variables are not set. +TEST(AttentionKernelOptionsTest, DefaultMinSeqLens) { + constexpr int value = 0; + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}}}; + AttentionKernelOptions options; + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_TRUE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), + onnxruntime::contrib::attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), + onnxruntime::contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32); +} + +} // namespace test +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index e4814aa7fc03..892e7de8bb6e 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -446,6 +446,8 @@ def test_get_and_set_option_with_values(option_name, option_values): test_get_and_set_option_with_values("use_tf32", ["1", "0"]) + test_get_and_set_option_with_values("sdpa_kernel", ["0", "1", "2"]) + option["gpu_external_alloc"] = "0" option["gpu_external_free"] = "0" option["gpu_external_empty_cache"] = "0"