[CUDA] Attention kernel provider option (#21344)
### Description
* Add a CUDA provider option `sdpa_kernel` to choose which attention kernel to run, for testing purposes.
* Allow dumping which attention kernel is used per node.
* Reserve a flag for cuDNN flash attention, which will be added soon.

#### CUDA provider option sdpa_kernel
In addition to the environment variables, the kernel selection can now be set through a provider option. Note that the setting applies to the whole session rather than to individual nodes, which makes it convenient for performance testing of each kernel; see the sketch below.
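
A minimal sketch of setting the option from the C++ API (not part of this change; it assumes the option is exposed under the string key `sdpa_kernel` of the V2 CUDA provider options, and `model.onnx` is a placeholder path):
```
// Sketch: select flash attention (AttentionBackend::FLASH_ATTENTION == 1) for a session.
#include <onnxruntime_cxx_api.h>

int main() {
  const OrtApi& api = Ort::GetApi();

  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));

  // The value is an integer bitmask of AttentionBackend flags; "1" selects flash attention.
  const char* keys[] = {"sdpa_kernel"};
  const char* values[] = {"1"};
  Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys, values, 1));

  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "sdpa_kernel_demo");
  Ort::SessionOptions session_options;
  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, cuda_options));

  // The selection applies to every attention node created by this session.
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);

  api.ReleaseCUDAProviderOptions(cuda_options);
  return 0;
}
```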

#### Attention Kernel Debug Info
Set the environment variable `ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1`,
and ORT will print the SDPA kernel used by each node.

For example:
```
ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1 ./onnxruntime_test_all --gtest_filter=MultiHeadAttentionTest*
```
It shows debug information about the kernels used in the test:
```
[ RUN      ] MultiHeadAttentionTest.SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV
AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=0 TRT_FUSED_ATTENTION=1 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=1 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1
Operator=MultiHeadAttention Node=node1 DataType=fp16 TRT_FUSED_ATTENTION=1
AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=1 TRT_FUSED_ATTENTION=0 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=0 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1
Operator=MultiHeadAttention Node=node1 DataType=fp16 EFFICIENT_ATTENTION=1
```
In this test case, the debug info shows that one session uses TRT fused attention and another session uses efficient attention.
tianleiwu authored Jul 19, 2024
1 parent 01df8c7 commit 6ffaaeb
Showing 24 changed files with 645 additions and 110 deletions.
2 changes: 2 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -15,6 +15,8 @@ set(contrib_ops_excluded_files
"bert/attention_softmax.h"
"bert/attention_softmax.cu"
"bert/attention_prepare_qkv.cu"
"bert/attention_kernel_options.h"
"bert/attention_kernel_options.cc"
"bert/decoder_attention_impl.h"
"bert/decoder_attention_impl.cu"
"bert/decoder_masked_multihead_attention.h"
3 changes: 2 additions & 1 deletion cmake/onnxruntime_unittests.cmake
@@ -786,8 +786,9 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $<TARGET_OBJECTS:onnxruntime_providers_cuda_obj>)
config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut)
onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock)
add_dependencies(onnxruntime_providers_cuda_ut onnxruntime_test_utils onnxruntime_common)
target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_test_utils onnxruntime_common)
if (MSVC)
# Cutlass code has an issue with the following:
# warning C4100: 'magic': unreferenced formal parameter
@@ -38,4 +38,5 @@ struct OrtCUDAProviderOptionsV2 {
int prefer_nhwc = 0; // make the CUDA EP NHWC preferred
int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not
int use_tf32 = 1; // use TF32
int sdpa_kernel = 0; // Scaled Dot Product Attention kernel option
};
28 changes: 26 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -147,6 +147,23 @@ constexpr const char* kDisableSparseAttentionV1 = "ORT_DISABLE_SPARSE_ATTENTION_
} // namespace sparse_attention

namespace attention {

enum class AttentionBackend : int {
FLASH_ATTENTION = 1,
EFFICIENT_ATTENTION = 2,
TRT_FUSED_ATTENTION = 4,
CUDNN_FLASH_ATTENTION = 8, // reserved for cuDNN flash attention.
MATH = 16, // unfused kernel cannot be disabled right now.

// The following kernels might be deprecated in the future.
TRT_FLASH_ATTENTION = 32,
TRT_CROSS_ATTENTION = 64,
TRT_CAUSAL_ATTENTION = 128,
};

// Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled).
constexpr const char* kEnableAttentionKernelDebugInfo = "ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO";

// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";

@@ -157,6 +174,9 @@ constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATT
// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels.
constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION";

// Environment variable to enable or disable cuDNN flash attention.
constexpr const char* kEnableCudnnFlashAttention = "ORT_ENABLE_CUDNN_FLASH_ATTENTION";

// Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled).
constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION";

@@ -166,11 +186,15 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF
// Environment variable to enable or disable flash attention. Default is 0 (enabled).
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";

// Minimum sequence length to enable memory efficient attention in FP32.
constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256;
// Minimum sequence length to prefer memory efficient attention when data type is float32
constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32";

// Default value for minimum sequence length to enable memory efficient attention in FP32.
constexpr int kDefaultMinSeqLenForEfficientAttentionFp32 = 256;

// Minimum sequence length to prefer flash attention when input format is packed QKV for MultiHeadAttention
constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV";

// Default value for the above setting.
constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513;

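The `AttentionBackend` values added above are distinct bits, so a single `sdpa_kernel` value can select several kernels at once. A short illustration of composing such a value (the constant and the chosen combination are illustrative, not part of this commit):
```
// Illustrative only: compose an sdpa_kernel provider-option value from the
// AttentionBackend bit flags introduced above.
enum class AttentionBackend : int {
  FLASH_ATTENTION = 1,
  EFFICIENT_ATTENTION = 2,
  TRT_FUSED_ATTENTION = 4,
  CUDNN_FLASH_ATTENTION = 8,
  MATH = 16,
  TRT_FLASH_ATTENTION = 32,
  TRT_CROSS_ATTENTION = 64,
  TRT_CAUSAL_ATTENTION = 128,
};

// Allow flash attention, efficient attention, and the unfused math fallback: 1 | 2 | 16.
constexpr int kSdpaKernelValue =
    static_cast<int>(AttentionBackend::FLASH_ATTENTION) |
    static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION) |
    static_cast<int>(AttentionBackend::MATH);

static_assert(kSdpaKernelValue == 19, "pass \"19\" as the sdpa_kernel option value");
```
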
53 changes: 23 additions & 30 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -3,7 +3,6 @@

#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/attention.h"
#include "contrib_ops/cuda/bert/bert_padding.h"
@@ -40,36 +39,17 @@ REGISTER_KERNEL_TYPED(MLFloat16)

template <typename T>
Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) {
disable_fused_self_attention_ =
sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
kernel_options_ = this->GetAttentionKernelOptions();

enable_trt_flash_attention_ =
sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();

enable_fused_causal_attention_ =
sizeof(T) == 2 &&
ParseEnvironmentVariableWithDefault<bool>(attention::kEnableFusedCausalAttention, false);
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();

#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ =
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
disable_memory_efficient_attention_ = true;
#endif
enable_fused_causal_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtCausalAttention();

#if USE_FLASH_ATTENTION
disable_flash_attention_ =
sizeof(T) != 2 ||
onnxruntime::ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
attention::kMinSeqLenForFlashAttentionPackedQKV,
attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
#else
disable_flash_attention_ = true;
min_seq_len_for_flash_attention_packed_qkv_ = 0;
#endif
disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();

disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
}

template <typename T>
@@ -134,7 +114,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.num_heads,
parameters.num_heads);
// When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512.
if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
if (use_flash_attention && parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) {
use_flash_attention = false;
}
// Allocate buffers
@@ -220,7 +200,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == past &&
nullptr == present &&
(nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
(sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) &&
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);

if (use_memory_efficient_attention) {
@@ -231,6 +211,20 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
constexpr bool use_memory_efficient_attention = false;
#endif

if (kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;
debug_info.use_efficient_attention = use_memory_efficient_attention;
if (fused_runner != nullptr) {
debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length);
}

debug_info.Print("Attention",
this->Node().Name(),
std::is_same<T, MLFloat16>::value,
std::is_same<T, BFloat16>::value);
}

cublasHandle_t cublas = GetCublasHandle(context);

typedef typename ToCudaType<T>::MappedType CudaT;
@@ -268,7 +262,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
use_fused_cross_attention,
use_memory_efficient_attention);
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());
;

typedef typename ToCudaType<T>::MappedType CudaT;
AttentionData<CudaT> data;
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.h
@@ -8,6 +8,7 @@
#include "core/providers/cuda/cuda_kernel.h"
#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/attention_kernel_options.h"

namespace onnxruntime {
namespace contrib {
@@ -27,9 +28,10 @@ class Attention final : public CudaKernel, public AttentionBase {
bool enable_trt_flash_attention_;
bool enable_fused_causal_attention_;
bool disable_memory_efficient_attention_;
int min_seq_len_for_flash_attention_packed_qkv_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
mutable std::once_flag fused_fp16_runner_created_;

const AttentionKernelOptions* kernel_options_;
};

} // namespace cuda
166 changes: 166 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -0,0 +1,166 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cuda/bert/attention_kernel_options.h"
#include <iomanip>
#include <iostream>
#include <sstream>
#include "contrib_ops/cpu/bert/attention_common.h"
#include "core/providers/shared_library/provider_api.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"

using namespace onnxruntime::contrib::attention;

namespace onnxruntime {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
if (value > 0) {
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
use_trt_fused_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FUSED_ATTENTION)) > 0;
use_cudnn_flash_attention_ = (value & static_cast<int>(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0;
use_unfused_ = (value & static_cast<int>(AttentionBackend::MATH)) > 0;
use_trt_flash_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FLASH_ATTENTION)) > 0;
use_trt_cross_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CROSS_ATTENTION)) > 0;
use_trt_causal_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0;
} else {
use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFlashAttention, false);
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);
use_unfused_ = true;
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
use_trt_causal_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableFusedCausalAttention, false);
}

enable_kernel_debug_info_ = ParseEnvironmentVariableWithDefault<bool>(kEnableAttentionKernelDebugInfo, false);

// When value is positive, we use 0 as default minimum sequence lengths to align with common usage in testing.
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
kMinSeqLenForFlashAttentionPackedQKV,
value > 0 ? 0 : kDefaultMinSeqLenForFlashAttentionPackedQKV);

min_seq_len_for_efficient_attention_fp32_ = ParseEnvironmentVariableWithDefault<int>(
kMinSeqLenForEfficientAttentionFp32,
value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32);

if (use_build_flag) {
// Some kernels can be disabled at build time. If they are disabled, we should not use them.
#ifndef USE_FLASH_ATTENTION
use_flash_attention_ = false;
#endif

#ifndef USE_MEMORY_EFFICIENT_ATTENTION
use_efficient_attention_ = false;
#endif
}
}

void AttentionKernelOptions::InitializeOnce(
int sdpa_kernel, bool use_build_flag) {
std::call_once(this->initialize_once_flag_, [&]() {
this->Initialize(sdpa_kernel, use_build_flag);
if (this->enable_kernel_debug_info_) {
this->Print();
}
});
}

void AttentionKernelOptions::Print() const {
std::stringstream sstream;
sstream << "AttentionKernelOptions:";
sstream << " FLASH_ATTENTION=" << int(use_flash_attention_);
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_);
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_);
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_);
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention_);
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention_);
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention_);
sstream << " MATH=" << int(use_unfused_);

if (!use_unfused_) {
sstream << std::endl
<< "Warning: Unfused kernel cannot be disabled right now. MATH=0 is ignored.";
}

// Output text in Cyan color to make it easier to spot
std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl;
}

// Classify the kernel used in TRT fused runner.
void AttentionKernelDebugInfo::SetTrtFusedKernel(bool causal, bool enable_trt_flash_attention, int sequence_length) {
if (causal) {
use_trt_causal_attention = true;
} else if (enable_trt_flash_attention && sequence_length >= contrib::cuda::kMinSequenceLengthFlashAttention) {
use_trt_flash_attention = true;
} else {
use_trt_fused_attention = true;
}
}

void AttentionKernelDebugInfo::Print(const char* operator_name,
const std::string& node_name,
bool is_float16,
bool is_bfloat16) const {
std::stringstream sstream;
sstream << "Operator=" << operator_name;

if (node_name.length() > 0) {
sstream << " Node=" << node_name;
}

if (is_bfloat16) {
sstream << " DataType=bf16";
} else if (is_float16) {
sstream << " DataType=fp16";
} else {
sstream << " DataType=fp32";
}

if (use_flash_attention.has_value() && use_flash_attention.value()) {
sstream << " FLASH_ATTENTION=" << int(use_flash_attention.value());
}

if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention.value());
}

if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention.value());
}

if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) {
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention.value());
}

if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) {
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention.value());
}

if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) {
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention.value());
}

if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) {
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention.value());
}

bool use_fused = (use_flash_attention.has_value() && use_flash_attention.value()) ||
(use_efficient_attention.has_value() && use_efficient_attention.value()) ||
(use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) ||
(use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) ||
(use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) ||
(use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) ||
(use_trt_causal_attention.has_value() && use_trt_causal_attention.value());

// Fall back to unfused when no fused kernel is enabled.
if (!use_fused) {
sstream << " MATH=1";
}

// Output text in Cyan color to make it easier to spot.
std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl;
}

} // namespace onnxruntime