From 10cd02799c2142eaf3bd881fef3c7ad5a1eab8f7 Mon Sep 17 00:00:00 2001 From: mindest Date: Sun, 29 Sep 2024 14:51:02 +0800 Subject: [PATCH] DecoderMaskedMultiHeadAttention CPU kernel. --- docs/ContribOperators.md | 10 +- .../contrib_ops/cpu/bert/attention_common.h | 39 ++ .../contrib_ops/cpu/bert/attention_cpu_base.h | 50 +- .../decoder_masked_multihead_attention.cc | 479 ++++++++++++++++++ .../bert/decoder_masked_multihead_attention.h | 72 +++ .../decoder_masked_multihead_attention_impl.h | 39 -- .../core/graph/contrib_ops/bert_defs.cc | 14 +- 7 files changed, 632 insertions(+), 71 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc create mode 100644 onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 09a7e47fc9913..0b966c813746e 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1175,9 +1175,9 @@ This version of the operator has been available since version 1 of the 'com.micr
past_sequence_length (optional) : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0). Cross Attention doesn't need this input.
beam_width (optional) : M
-The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.
+The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.
cache_indirection (optional) : M
-A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration
+A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration
bias (optional) : T
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
@@ -1192,7 +1192,7 @@ This version of the operator has been available since version 1 of the 'com.micr
present_value (optional) : T
present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
qk (optional) : V
-normalized Q * K, of shape (batch_size, num_heads, 1, head_size).
+normalized Q * K, of shape (batch_size, num_heads, 1, total_sequence_length).
#### Type Constraints @@ -1261,9 +1261,9 @@ This version of the operator has been available since version 1 of the 'com.micr
past_sequence_length : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
beam_width (optional) : M
-The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.
+The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.
cache_indirection (optional) : M
-A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration
+A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration
#### Outputs diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index e0fa581c8071d..46638555576a9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -79,6 +79,45 @@ struct AttentionParameters { AttentionQkvFormat qkv_format; }; +struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters { + int beam_width = 1; + + // Only NeoX style rotary embedding is supported + int rotary_embedding_dim = 0; + int t_step = 0; + + // Whether to use multihead attention(excludes matmul and bias) + bool is_mha = false; + bool is_cross_attention = false; + bool is_packed_qkv = false; + + // Useful to better use global memory bandwidth on certain CUDA architectures. + // Turned off by default for now until we fully understand performance implications + // for all types of workloads. + // Can be turned on by appropriate environment variable (see attention_common.h). + bool kv_data_in_flight = false; + + void* q = nullptr; + void* q_bias = nullptr; + + void* k = nullptr; + void* k_bias = nullptr; + + void* v = nullptr; + void* v_bias = nullptr; + + void* attention_bias = nullptr; + + void* k_cache = nullptr; + void* v_cache = nullptr; + + void* out = nullptr; + void* out_qk = nullptr; + + const int32_t* cache_indir = nullptr; + const int32_t* mask = nullptr; // [B, total_sequence_length] +}; + // Parameters deduced from node attributes and inputs/outputs. struct PackedAttentionParameters { int batch_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index ae2eaf0204026..55e0c143d019e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -36,7 +36,9 @@ class AttentionCPUBase : public AttentionBase { int v_head_size, // head size of V (H_v) int v_hidden_size, // hidden size of V (D_v) const Tensor* attn_bias, // additive bias applied on scaled QK. - OpKernelContext* context) const { + OpKernelContext* context, + Tensor* scaled_qk = nullptr // output buffer for QK (if needed) + ) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -71,7 +73,7 @@ class AttentionCPUBase : public AttentionBase { if (mask_data != nullptr) { // Convert mask from boolean (0/1) to float (mask_filter_value/0.0f). - // Merge padding mask with causual mask, and broadcast to 3D (BxSxT). + // Merge padding mask with causal mask, and broadcast to 3D (BxSxT). PrepareMask(mask_index_data, mask_index_dims, static_cast(mask_data), causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_); DUMP_CPU_TENSOR("Mask3D", static_cast(mask_data), batch_size, sequence_length, total_sequence_length); @@ -85,6 +87,7 @@ class AttentionCPUBase : public AttentionBase { T* present_key_data = present_key != nullptr ? present_key->MutableData() : nullptr; const T* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; + T* scaled_qk_data = scaled_qk != nullptr ? scaled_qk->MutableData() : nullptr; const T* attn_bias_data = (attn_bias != nullptr) ? attn_bias->Data() : nullptr; auto attn_bias_dims = (attn_bias != nullptr) ? 
attn_bias->Shape().GetDims() : gsl::span{}; @@ -97,7 +100,7 @@ class AttentionCPUBase : public AttentionBase { static_cast(mask_data), batch_size, sequence_length, kv_sequence_length, past_sequence_length, qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data, - present_data, present_key_data, tp, scale, attn_bias_data, attn_bias_dims); + present_data, present_key_data, tp, scale, attn_bias_data, attn_bias_dims, scaled_qk_data); // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) auto out_tmp_data = @@ -117,23 +120,24 @@ class AttentionCPUBase : public AttentionBase { // 1 x mask_data(B, N, S, T) // attention_probs(B, N, S, T) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. Its size is BxNxLxH - T* mask_data, // buffer for mask data. - int batch_size, // batch size of self-attention - int sequence_length, // sequence length of self-attention (S) - int kv_sequence_length, // sequence length of cross-attention (L) - int past_sequence_length, // sequence length of past state - int head_size, // head size of self-attention - const T* past, // past state - const T* past_key, // past key only (if not using past state) - T* present, // present state - T* present_key, // present key only (if not using present state) - ThreadPool* tp, // thread pool - float scale, // scale factor - const T* attn_bias_data, // attention bias - gsl::span attn_bias_dims // attention bias shape + void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + T* mask_data, // buffer for mask data. + int batch_size, // batch size of self-attention + int sequence_length, // sequence length of self-attention (S) + int kv_sequence_length, // sequence length of cross-attention (L) + int past_sequence_length, // sequence length of past state + int head_size, // head size of self-attention + const T* past, // past state + const T* past_key, // past key only (if not using past state) + T* present, // present state + T* present_key, // present key only (if not using present state) + ThreadPool* tp, // thread pool + float scale, // scale factor + const T* attn_bias_data, // attention bias + gsl::span attn_bias_dims, // attention bias shape + T* scaled_qk_data = nullptr // scaled output QK buffer ) const { const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L const size_t past_chunk_length = static_cast(past_sequence_length) * head_size; // P x H @@ -230,6 +234,12 @@ class AttentionCPUBase : public AttentionBase { }); } + if (scaled_qk_data != nullptr) { + // Output the scaled Q*K^T if needed. + memcpy(scaled_qk_data, attention_probs, + SafeInt(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T)); + } + DUMP_CPU_TENSOR("QK (scaled)", attention_probs, batch_size, num_heads_, sequence_length, total_sequence_length); // attention_probs(B, N, S, T) = Softmax(attention_probs) diff --git a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc new file mode 100644 index 0000000000000..e41efe3b212db --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc @@ -0,0 +1,479 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#include "attention_cpu_base.h" +#include "attention_utils.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/cpu/bert/decoder_masked_multihead_attention.h" + +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { + +// TODO: refactor +static constexpr int kPastSequenceLengthInputIndex = 7; +static constexpr int kBeamWidthInputIndex = 8; +static constexpr int kCacheIndirectionInputIndex = 9; +static constexpr int kPastInputIndex = 5; +static constexpr int kPresentOutputIndex = 1; +static constexpr int kQKOutputIndex = 3; +static constexpr int kBiasIndex = 10; + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DecoderMaskedMultiHeadAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(kPastInputIndex, kPresentOutputIndex) \ + .MayInplace(kPastInputIndex + 1, kPresentOutputIndex + 1) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex) \ + .InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex), \ + DecoderMaskedMultiHeadAttention); + +REGISTER_KERNEL_TYPED(float) + +template +DecoderMaskedMultiHeadAttention::DecoderMaskedMultiHeadAttention(const OpKernelInfo& info) + : OpKernel(info), AttentionCPUBase(info, false) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + scale_ = info.GetAttrOrDefault("scale", 0.0f); + past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL); + output_qk_ = info.GetAttrOrDefault("output_qk", 0LL); +} + +template +Status DecoderMaskedMultiHeadAttention::Compute(OpKernelContext* context) const { + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + const Tensor* mask_index = context->Input(3); + const Tensor* attention_bias = context->Input(4); + const Tensor* past_key = context->Input(kPastInputIndex); + const Tensor* past_value = context->Input(kPastInputIndex + 1); + const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); + const Tensor* beam_width = context->Input(kBeamWidthInputIndex); + const Tensor* cache_indir = context->Input(kCacheIndirectionInputIndex); + const Tensor* bias = context->Input(kBiasIndex); + + DecoderMaskedMultiHeadAttentionParams parameters; + + parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault( + attention::kDecoderMaskedAttentionLoadKVDataInFlight, false); + + bool is_unidirectional = false; + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, + key, + value, + bias, + mask_index, + attention_bias, + past_key, + past_value, + past_seq_len, + ¶meters, + num_heads_, + mask_filter_value_, + scale_, + is_unidirectional, + past_present_share_buffer_, + kDecoderMaskedMultiHeadAttention)); + + int batch_size = parameters.batch_size; + int sequence_length = parameters.sequence_length; + int head_size = parameters.head_size; + int v_head_size = parameters.v_head_size; + int hidden_size = parameters.hidden_size; + int v_hidden_size = parameters.v_hidden_size; + + // This kernel is for decoding only (i.e.) 
sequence length has to be 1 + if (sequence_length != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention. " + "Actual length is ", + sequence_length); + } + + if (head_size != v_head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "QK head size should be same as V head size to use DecoderMaskedMultiHeadAttention"); + } + + if (parameters.mask_type != AttentionMaskType::MASK_2D_KEY_PADDING && + parameters.mask_type != AttentionMaskType::MASK_NONE) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "DecoderMaskedMultiHeadAttention only supports no mask or 2D key " + "padding mask of shape [batch, total_seq_length] currently"); + } + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(v_hidden_size); + Tensor* output = context->Output(0, output_shape); + + std::vector present_dims{ + parameters.batch_size, parameters.num_heads, + past_present_share_buffer_ ? parameters.max_sequence_length : parameters.total_sequence_length, + head_size}; + TensorShape present_shape(present_dims); + Tensor* present_key = context->Output(kPresentOutputIndex, present_shape); + Tensor* present_value = context->Output(kPresentOutputIndex + 1, present_shape); + Tensor* cross_qk = nullptr; + + parameters.is_mha = true; + + // Update the q buffers + parameters.q = const_cast(query->Data()); + + // Update the attention bias for self attention + if (attention_bias != nullptr) { + parameters.attention_bias = const_cast(attention_bias->Data()); + } + + // Decoder cross-attention + if (past_key == nullptr && present_key == nullptr) { + if (attention_bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "DecoderMaskedMultiHeadAttention does not support attention bias for cross-attention"); + } + + parameters.is_cross_attention = true; + parameters.total_sequence_length = parameters.kv_sequence_length; + parameters.max_sequence_length = parameters.kv_sequence_length; + // parameters.k and parameters.v are nullptr + parameters.k_cache = const_cast(key->Data()); + parameters.v_cache = const_cast(value->Data()); + parameters.k_bias = nullptr; + parameters.v_bias = nullptr; + + } else { + // Sanity check + ORT_ENFORCE(past_present_share_buffer_); + ORT_ENFORCE(past_key != nullptr && past_value != nullptr); + + auto* present_key_data = present_key->MutableData(); + auto* present_value_data = present_value->MutableData(); + auto* past_key_data = past_key->Data(); + auto* past_value_data = past_value->Data(); + + if (present_key_data != past_key_data) { + std::memcpy(present_key_data, past_key_data, past_key->SizeInBytes()); + } + if (present_value_data != past_value_data) { + std::memcpy(present_value_data, past_value_data, past_value->SizeInBytes()); + } + + parameters.is_cross_attention = false; + + bool is_packed_qkv = (key == nullptr && value == nullptr); + parameters.is_packed_qkv = is_packed_qkv; + + parameters.k = is_packed_qkv + ? const_cast(query->Data() + hidden_size) + : const_cast(key->Data()); + parameters.v = is_packed_qkv + ? 
const_cast(query->Data() + 2 * static_cast(hidden_size)) + : const_cast(value->Data()); + parameters.k_cache = present_key_data; + parameters.v_cache = present_value_data; + } + + if (output_qk_) { + int64_t qk_dims[] = {parameters.batch_size, parameters.num_heads, 1, parameters.total_sequence_length}; + TensorShape qk_shape(&qk_dims[0], sizeof(qk_dims) / sizeof(qk_dims[0])); + cross_qk = context->Output(kQKOutputIndex, qk_shape); + parameters.out_qk = cross_qk->MutableData(); + } + + parameters.out = output->MutableDataRaw(); + + // Beam width (in case we are using this op inside BeamSearch) + if (beam_width != nullptr) { + parameters.beam_width = static_cast(*beam_width->Data()); + } + + // Cache indirection (in case we are using this op inside BeamSearch) + if (parameters.beam_width > 1 && cache_indir == nullptr) { + // If beam width > 1, then cache indirection buffer MUST be present + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "If beam width is greater than 1, then cache indirection buffer MUST be present"); + } + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + OrtValue Q; + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( + context, allocator, batch_size, num_heads_, 1, head_size, query, bias, 0, Q)); + + // Cross-attention case + if (parameters.is_cross_attention) { + return ApplyAttention(Q.GetMutable()->MutableData(), + key->Data(), + value->Data(), + mask_index, nullptr /* past */, past_key, past_value, output, present_key, present_value, + batch_size, 1 /* sequence_length */, parameters.kv_sequence_length, + head_size, v_head_size, v_hidden_size, attention_bias, context, cross_qk); + } + + OrtValue K, V; + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( + context, allocator, batch_size, num_heads_, 1, head_size, key, bias, hidden_size, K)); + ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( + context, allocator, batch_size, num_heads_, 1, v_head_size, value, bias, 2 * hidden_size, V)); + + // Self-attention, !has_beams + if (parameters.cache_indir == nullptr) { + return ApplyAttention(Q.GetMutable()->MutableData(), + K.GetMutable()->MutableData(), + V.GetMutable()->MutableData(), + mask_index, nullptr /* past */, past_key, past_value, output, present_key, present_value, + batch_size, 1 /* sequence_length */, parameters.kv_sequence_length, + head_size, v_head_size, v_hidden_size, attention_bias, context, cross_qk); + } + + // Self-attention, has_beams + return ApplyAttentionWithBeams(Q.GetMutable()->MutableData(), + K.GetMutable()->MutableData(), + V.GetMutable()->MutableData(), + mask_index, past_key, past_value, output, present_key, present_value, + batch_size, parameters.past_sequence_length, parameters.max_sequence_length, + head_size, v_head_size, v_hidden_size, attention_bias, cache_indir, context, cross_qk); +} + +template +Status DecoderMaskedMultiHeadAttention::ApplyAttentionWithBeams( + const T* Q, + const T* K, + const T* V, + const Tensor* mask_index, + const Tensor* past_key, + const Tensor* past_value, + Tensor* output, + Tensor* present_key, + Tensor* present_value, + int batch_size, + int past_sequence_length, + int max_sequence_length, + int head_size, + int v_head_size, + int v_hidden_size, + const Tensor* attn_bias, + const Tensor* cache_indir, + OpKernelContext* context, + Tensor* scaled_qk) const { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + auto* tp = context->GetOperatorThreadPool(); + + int total_sequence_length = past_sequence_length 
+ 1; + size_t bytes = SafeInt(batch_size) * num_heads_ * total_sequence_length * sizeof(T); + auto attention_probs = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); + + T* scaled_qk_data = (scaled_qk != nullptr) ? scaled_qk->MutableData() : nullptr; + + ComputeAttentionProbsWithBeams(static_cast(attention_probs), Q, K, mask_index->Data(), batch_size, + past_sequence_length, max_sequence_length, head_size, past_key->Data(), + present_key->MutableData(), tp, attn_bias->Data(), + cache_indir->Data(), scaled_qk_data); + + // Compute the attentionScore * Value: out_tmp(B, N, 1, H_v) = attention_probs(B, N, 1, T) x V(B, N, T, H_v) + auto out_tmp_data = allocator->Alloc(SafeInt(batch_size) * num_heads_ * v_head_size * sizeof(T)); + BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(std::move(allocator))); + + ComputeVxAttentionScoreWithBeams(output->MutableData(), static_cast(out_tmp_data), + static_cast(attention_probs), V, batch_size, + past_sequence_length, max_sequence_length, v_head_size, past_value->Data(), + present_value->MutableData(), cache_indir->Data(), tp); + + return Status::OK(); +} + +template +void DecoderMaskedMultiHeadAttention::ComputeAttentionProbsWithBeams( + T* attention_probs, + const T* Q, + const T* K, + const T* mask_index_data, + int batch_size, + int past_sequence_length, + int max_sequence_length, + int head_size, + const T* past_key_data, + T* present_key_data, + ThreadPool* tp, + const T* attn_bias_data, + const int32_t* cache_indir_data, + T* scaled_qk_data) const { + float scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + + TensorOpCost unit_cost; + auto total_sequence_length = past_sequence_length + 1; + const ptrdiff_t probs_matrix_size = total_sequence_length; + const ptrdiff_t probs_matrix_bytes = probs_matrix_size * sizeof(T); + + unit_cost.compute_cycles = static_cast((SafeInt(2) * head_size - 1) * total_sequence_length); + unit_cost.bytes_loaded = static_cast(SafeInt(2) * head_size * total_sequence_length * sizeof(T)); + unit_cost.bytes_stored = static_cast(head_size * total_sequence_length * sizeof(T)); + + if (attn_bias_data != nullptr) { + unit_cost.bytes_loaded += probs_matrix_bytes * 2; + unit_cost.bytes_stored += probs_matrix_bytes; + } + + if (mask_index_data != nullptr) { + unit_cost.bytes_stored += probs_matrix_bytes; + } + + // Cost of appending current key to present key + unit_cost.compute_cycles += static_cast(head_size); + unit_cost.bytes_loaded += static_cast(head_size); + + // Parallel for loop + const int loop_len = batch_size * num_heads_; + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const int batch_index = static_cast(i) / num_heads_; + const T* q_vec = Q + i * head_size; + + { + // Calculate the latest position of the attention_probs + // b,n,s,h b,n,t,h b,n,s,t + auto last_offset = past_sequence_length + i * probs_matrix_size; + T* attention_probs_ptr = reinterpret_cast(attention_probs) + last_offset; + math::Dot(head_size, q_vec, + K + i * head_size, + attention_probs_ptr, + nullptr); + + // Apply the attention bias and mask + if (attn_bias_data != nullptr) { + *attention_probs_ptr += attn_bias_data[last_offset]; + } + bool is_masked = (mask_index_data != nullptr) && + (mask_index_data[(batch_index + 1) * total_sequence_length - 1] == 0); + if (is_masked) { + *attention_probs_ptr += mask_filter_value_; + } + *attention_probs_ptr *= scale; + } + 
+ { + // Calculate the rest of the attention_probs + for (int j = 0; j < past_sequence_length; ++j) { + const int* beam_indices = &cache_indir_data[batch_index * max_sequence_length]; + const int beam_offset = beam_indices[i] * num_heads_ * max_sequence_length * head_size; + const T* past_k_vec = past_key_data + beam_offset; + T* output = reinterpret_cast(attention_probs) + j + i * probs_matrix_size; + math::Dot(head_size, q_vec, past_k_vec, output, nullptr); + // Apply the attention bias and mask + if (attn_bias_data != nullptr) { + const int attn_bias_beam_offset = beam_indices[i] * num_heads_ * probs_matrix_size; + *output += attn_bias_data[j + attn_bias_beam_offset]; + } + bool is_masked = (mask_index_data != nullptr) && + (mask_index_data[batch_index * total_sequence_length + j] == 0); + if (is_masked) { + *output += mask_filter_value_; + } + } + } + // Append current key to present key (past_present_share_buffer_ is true) + memcpy(present_key_data + i * max_sequence_length * head_size, K + i * head_size, head_size * sizeof(T)); + } + }); + + if (scaled_qk_data != nullptr) { + // Output the scaled Q*K^T if needed. + memcpy(scaled_qk_data, attention_probs, + SafeInt(batch_size) * num_heads_ * total_sequence_length * sizeof(T)); + } + + // attention_probs(B, N, 1, T) = Softmax(attention_probs) + { + const int N = batch_size * num_heads_; + const int D = total_sequence_length; + ComputeAttentionSoftmaxInplace(attention_probs, N, D, tp); + } +} + +template +void DecoderMaskedMultiHeadAttention::ComputeVxAttentionScoreWithBeams( + T* output, + T* tmp_buffer, + const T* attention_probs, + const T* V, + int batch_size, + int past_sequence_length, + int max_sequence_length, + int v_head_size, + const T* past_value_data, + T* present_value_data, + const int32_t* cache_indir_data, + ThreadPool* tp) const { + const int total_sequence_length = past_sequence_length + 1; + + TensorOpCost unit_cost; + unit_cost.compute_cycles = static_cast(SafeInt(2) * v_head_size * total_sequence_length); + unit_cost.bytes_loaded = static_cast(SafeInt(3) * v_head_size * total_sequence_length * sizeof(T)); + unit_cost.bytes_stored = static_cast(SafeInt(2) * v_head_size * total_sequence_length * sizeof(T)); + + // Cost of appending current value to present value + unit_cost.compute_cycles += static_cast(v_head_size); + unit_cost.bytes_loaded += static_cast(v_head_size); + + ThreadPool::TryParallelFor( + tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const int batch_index = static_cast(i / num_heads_); + + // Compute the attention score + { + const T* attn_probs_ptr = attention_probs + (i + 1) * total_sequence_length - 1; + math::Scale(v_head_size, + static_cast(*attn_probs_ptr), + V + i * v_head_size, + output + i * v_head_size, + nullptr); + } + { + for (int j = 0; j < past_sequence_length; ++j) { + const int* beam_indices = &cache_indir_data[batch_index * max_sequence_length]; + const int beam_offset = beam_indices[j] * num_heads_ * max_sequence_length * v_head_size; + const T* past_value_vec = past_value_data + beam_offset; + const T* attn_probs_ptr = attention_probs + j + i * total_sequence_length; + + math::Scale(v_head_size, + static_cast(*attn_probs_ptr), + past_value_vec + j * v_head_size, + tmp_buffer + i * v_head_size, + nullptr); + math::Add(v_head_size, + output + i * v_head_size, + tmp_buffer + i * v_head_size, + output + i * v_head_size, + nullptr); + } + } + // Append current value to present value 
(past_present_share_buffer_ is true) + memcpy(present_value_data + i * max_sequence_length * v_head_size, + V + i * v_head_size, + v_head_size * sizeof(T)); + } + }); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h new file mode 100644 index 0000000000000..5e6dfe29b5b1f --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +template +class DecoderMaskedMultiHeadAttention final : public OpKernel, public AttentionCPUBase { + public: + DecoderMaskedMultiHeadAttention(const OpKernelInfo& info); + Status ApplyAttentionWithBeams(const T* Q, + const T* K, + const T* V, + const Tensor* mask_index, + const Tensor* past_key, + const Tensor* past_value, + Tensor* output, + Tensor* present_key, + Tensor* present_value, + int batch_size, + int past_sequence_length, + int max_sequence_length, + int head_size, + int v_head_size, + int v_hidden_size, + const Tensor* attn_bias, + const Tensor* cache_indir, + OpKernelContext* context, + Tensor* scaled_qk = nullptr) const; + void ComputeAttentionProbsWithBeams(T* attention_probs, + const T* Q, + const T* K, + const T* mask_index_data, + int batch_size, + int past_sequence_length, + int max_sequence_length, + int head_size, + const T* past_key, + T* present_key, + ThreadPool* tp, + const T* attn_bias_data, + const int32_t* cache_indir_data, + T* scaled_qk_data = nullptr) const; + void ComputeVxAttentionScoreWithBeams(T* output, + T* tmp_buffer, + const T* attention_probs, + const T* V, + int batch_size, + int past_sequence_length, + int max_sequence_length, + int v_head_size, + const T* past_value, + T* present_value, + const int32_t* cache_indir_data, + ThreadPool* tp) const; + Status Compute(OpKernelContext* context) const override; + + protected: + int num_heads_; // number of attention heads + float mask_filter_value_; + float scale_; + bool past_present_share_buffer_; + bool output_qk_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h index efad33855328f..0e1c9ce7b108e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h @@ -10,45 +10,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters { - int beam_width = 1; - - // Only NeoX style rotary embedding is supported - int rotary_embedding_dim = 0; - int t_step = 0; - - // Weather to use multihead attention(excludes matmul and bias) - bool is_mha = false; - bool is_cross_attention = false; - bool is_packed_qkv = false; - - // Useful to better use global memory bandwidth on certain CUDA architectures. - // Turned off by default for now until we fully understand performance implications - // for all types of workloads. 
- // Can be turned on by appropriate environment variable (see attention_common.h). - bool kv_data_in_flight = false; - - void* q = nullptr; - void* q_bias = nullptr; - - void* k = nullptr; - void* k_bias = nullptr; - - void* v = nullptr; - void* v_bias = nullptr; - - void* attention_bias = nullptr; - - void* k_cache = nullptr; - void* v_cache = nullptr; - - void* out = nullptr; - void* out_qk = nullptr; - - const int32_t* cache_indir = nullptr; - const int32_t* mask = nullptr; // [B, total_sequence_length] -}; - template < // The type of the inputs. Supported types: float and half. typename T, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index c706c6fc5ff5f..306d23169172a 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -787,14 +787,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "M") .Input(7, "beam_width", - "The beam width that is being used while decoding." + "The beam width that is being used while decoding. " "If not provided, the beam width will be assumed to be 1.", "M", OpSchema::Optional) .Input(8, "cache_indirection", - "A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies" - "which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration", + "A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies " + "which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration", "M", OpSchema::Optional) .Output(0, @@ -902,15 +902,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(8, "beam_width", - "The beam width that is being used while decoding." + "The beam width that is being used while decoding. " "If not provided, the beam width will be assumed to be 1.", "M", OpSchema::Optional) .Input(9, "cache_indirection", // This input is useful for CUDA EP only. - "A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies" - "which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration", + "A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies " + "which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration", "M", OpSchema::Optional) .Input(10, @@ -940,7 +940,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Output(3, "qk", - "normalized Q * K, of shape (batch_size, num_heads, 1, head_size). ", + "normalized Q * K, of shape (batch_size, num_heads, 1, total_sequence_length). ", "V", OpSchema::Optional) .TypeConstraint("V", {"tensor(float)"}, "Constrain qk output types to float32 tensors.")
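
The beam-aware path introduced above is compact but dense, so the following is a minimal scalar sketch of the math it implements for one decoded token, under stated assumptions: a single head of a single (batch, beam) pair, plain std::vector buffers instead of the shared max_sequence_length cache, and no attention bias or key-padding mask. The helper names (SourceBeam, AttendOneStep) and buffer layouts are illustrative only and do not exist in the kernel; what is taken from the patch is the cache_indirection semantics ([i, j, k] gives the beam whose cache holds token k for beam j of batch i), the default scale of 1/sqrt(head_size), and the softmax over total_sequence_length = past_sequence_length + 1.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// cache_indirection is laid out as [batch_size, beam_width, max_output_length]; entry [i, j, k]
// is the beam whose KV cache holds token k of beam j in batch i for the current iteration.
inline int32_t SourceBeam(const std::vector<int32_t>& cache_indirection,
                          int beam_width, int max_output_length,
                          int batch, int beam, int step) {
  return cache_indirection[(static_cast<size_t>(batch) * beam_width + beam) * max_output_length + step];
}

// Single head, single (batch, beam): score the one new query against per-beam gathered past keys
// plus the current key, softmax, then mix the values gathered the same way.
std::vector<float> AttendOneStep(const std::vector<float>& q,                    // [head_size]
                                 const std::vector<std::vector<float>>& past_k,  // [beam][past_len * head_size]
                                 const std::vector<std::vector<float>>& past_v,  // [beam][past_len * head_size]
                                 const std::vector<float>& cur_k,                // [head_size]
                                 const std::vector<float>& cur_v,                // [head_size]
                                 const std::vector<int32_t>& cache_indirection,
                                 int beam_width, int max_output_length,
                                 int batch, int beam, int past_len, int head_size) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));  // kernel default when scale attribute is 0
  std::vector<float> probs(past_len + 1);
  for (int t = 0; t < past_len; ++t) {
    const int32_t src = SourceBeam(cache_indirection, beam_width, max_output_length, batch, beam, t);
    float dot = 0.0f;
    for (int h = 0; h < head_size; ++h) dot += q[h] * past_k[src][static_cast<size_t>(t) * head_size + h];
    probs[t] = dot * scale;
  }
  float dot = 0.0f;
  for (int h = 0; h < head_size; ++h) dot += q[h] * cur_k[h];
  probs[past_len] = dot * scale;  // the newly decoded token attends to itself

  // Softmax over total_sequence_length = past_len + 1.
  const float max_logit = *std::max_element(probs.begin(), probs.end());
  float sum = 0.0f;
  for (float& p : probs) { p = std::exp(p - max_logit); sum += p; }
  for (float& p : probs) p /= sum;

  // out(1, head_size) = probs(1, total_sequence_length) x V(total_sequence_length, head_size)
  std::vector<float> out(head_size, 0.0f);
  for (int t = 0; t < past_len; ++t) {
    const int32_t src = SourceBeam(cache_indirection, beam_width, max_output_length, batch, beam, t);
    for (int h = 0; h < head_size; ++h) out[h] += probs[t] * past_v[src][static_cast<size_t>(t) * head_size + h];
  }
  for (int h = 0; h < head_size; ++h) out[h] += probs[past_len] * cur_v[h];
  return out;
}

The kernel fuses the same steps into ComputeAttentionProbsWithBeams (per-position dot products, optional attention bias and padding mask, softmax) and ComputeVxAttentionScoreWithBeams (probability-weighted value mix), and additionally appends the current K/V into the shared present buffers, which the sketch leaves out.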