DecoderMaskedMultiHeadAttention CPU kernel.
mindest committed Oct 2, 2024
1 parent 7880342 commit 10cd027
Showing 7 changed files with 632 additions and 71 deletions.
10 changes: 5 additions & 5 deletions docs/ContribOperators.md
@@ -1175,9 +1175,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>past_sequence_length</tt> (optional) : M</dt>
<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).Cross Attention doesn't need this input.</dd>
<dt><tt>beam_width</tt> (optional) : M</dt>
<dd>The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.</dd>
<dd>The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.</dd>
<dt><tt>cache_indirection</tt> (optional) : M</dt>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration</dd>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration</dd>
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
</dl>
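The cache_indirection input documented above is what lets beam search share one KV cache: entry `[i, j, k]` names the beam whose cache produced token `k` of beam `j` in batch `i`. As a rough, hypothetical sketch (the helper name and row-major layout assumption are ours, not the kernel's code), a decoding step could resolve the source beam for a past token like this:

```cpp
// Hypothetical helper, not taken from the kernel: resolves which beam's KV cache
// holds the key/value for past token `token_index`, assuming the documented
// [batch_size, beam_width, max_output_length] buffer in row-major order.
#include <cstddef>
#include <cstdint>

inline int32_t SourceBeamForToken(const int32_t* cache_indirection,
                                  int batch, int beam, int token_index,
                                  int beam_width, int max_output_length) {
  const std::size_t offset =
      (static_cast<std::size_t>(batch) * beam_width + beam) * max_output_length + token_index;
  return cache_indirection[offset];  // beam index to read K/V from for this token
}
```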
@@ -1192,7 +1192,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>present_value</tt> (optional) : T</dt>
<dd>present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
<dt><tt>qk</tt> (optional) : V</dt>
<dd>normalized Q * K, of shape (batch_size, num_heads, 1, head_size). </dd>
<dd>normalized Q * K, of shape (batch_size, num_heads, 1, total_sequence_length). </dd>
</dl>

#### Type Constraints
@@ -1261,9 +1261,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>past_sequence_length</tt> : M</dt>
<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).</dd>
<dt><tt>beam_width</tt> (optional) : M</dt>
<dd>The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.</dd>
<dd>The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.</dd>
<dt><tt>cache_indirection</tt> (optional) : M</dt>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration</dd>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration</dd>
</dl>
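When past_present_share_buffer is set, past and present share one preallocated cache with max_sequence_length slots per (batch, head), and each decoding step only appends the new key/value at position past_sequence_length instead of reallocating the tensor. A minimal sketch of that append, assuming a (batch_size, num_heads, max_sequence_length, head_size) layout (the helper below is illustrative, not ONNX Runtime code):

```cpp
// Illustrative append into a shared KV cache laid out as
// [batch_size, num_heads, max_sequence_length, head_size]; the new token's key
// (or value) vector for one (batch, head) is written at slot `past_sequence_length`.
#include <cstddef>
#include <cstring>

template <typename T>
void AppendToSharedKvCache(T* cache, const T* new_kv,
                           int batch, int head, int past_sequence_length,
                           int num_heads, int max_sequence_length, int head_size) {
  const std::size_t slot =
      ((static_cast<std::size_t>(batch) * num_heads + head) * max_sequence_length +
       past_sequence_length) * head_size;
  std::memcpy(cache + slot, new_kv, sizeof(T) * head_size);
}
```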

#### Outputs
39 changes: 39 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -79,6 +79,45 @@ struct AttentionParameters {
AttentionQkvFormat qkv_format;
};

struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters {
int beam_width = 1;

// Only NeoX style rotary embedding is supported
int rotary_embedding_dim = 0;
int t_step = 0;

// Whether to use multihead attention (excludes matmul and bias)
bool is_mha = false;
bool is_cross_attention = false;
bool is_packed_qkv = false;

// Useful to better use global memory bandwidth on certain CUDA architectures.
// Turned off by default for now until we fully understand performance implications
// for all types of workloads.
// Can be turned on by appropriate environment variable (see attention_common.h).
bool kv_data_in_flight = false;

void* q = nullptr;
void* q_bias = nullptr;

void* k = nullptr;
void* k_bias = nullptr;

void* v = nullptr;
void* v_bias = nullptr;

void* attention_bias = nullptr;

void* k_cache = nullptr;
void* v_cache = nullptr;

void* out = nullptr;
void* out_qk = nullptr;

const int32_t* cache_indir = nullptr;
const int32_t* mask = nullptr; // [B, total_sequence_length]
};
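To make the layout above concrete, here is a hedged sketch of how a caller might populate the struct for one self-attention decoding step; the helper, its arguments, and the chosen fields are illustrative assumptions, not the kernel's actual setup code:

```cpp
// Hypothetical setup sketch (assumes the struct above and its AttentionParameters
// base fields such as batch_size / num_heads / head_size are in scope).
#include <cstdint>

void FillDecoderMaskedMhaParams(DecoderMaskedMultiHeadAttentionParams& p,
                                void* q, void* k_cache, void* v_cache,
                                void* out, void* out_qk,
                                const int32_t* cache_indirection,
                                int batch_size, int num_heads, int head_size,
                                int beam_width, int past_sequence_length) {
  p.batch_size = batch_size;          // inherited from AttentionParameters
  p.num_heads = num_heads;
  p.head_size = head_size;
  p.beam_width = beam_width;
  p.t_step = past_sequence_length;    // current decoding step
  p.is_cross_attention = false;       // self-attention path
  p.q = q;                            // query for the single new token
  p.k_cache = k_cache;                // shared key cache (max_sequence_length slots)
  p.v_cache = v_cache;                // shared value cache
  p.cache_indir = cache_indirection;  // [batch_size, beam_width, max_output_length]
  p.out = out;                        // attention output
  p.out_qk = out_qk;                  // optional per-step QK logits, may be nullptr
}
```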

// Parameters deduced from node attributes and inputs/outputs.
struct PackedAttentionParameters {
int batch_size;
50 changes: 30 additions & 20 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -36,7 +36,9 @@ class AttentionCPUBase : public AttentionBase {
int v_head_size, // head size of V (H_v)
int v_hidden_size, // hidden size of V (D_v)
const Tensor* attn_bias, // additive bias applied on scaled QK.
OpKernelContext* context) const {
OpKernelContext* context,
Tensor* scaled_qk = nullptr // output buffer for QK (if needed)
) const {

[GitHub Actions / Optional Lint C++] cpplint warning on line 41 of onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h: Closing ) should be moved to the previous line [whitespace/parens] [2]
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

@@ -71,7 +73,7 @@ class AttentionCPUBase : public AttentionBase {

if (mask_data != nullptr) {
// Convert mask from boolean (0/1) to float (mask_filter_value/0.0f).
// Merge padding mask with causual mask, and broadcast to 3D (BxSxT).
// Merge padding mask with causal mask, and broadcast to 3D (BxSxT).
PrepareMask(mask_index_data, mask_index_dims, static_cast<T*>(mask_data),
causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
DUMP_CPU_TENSOR("Mask3D", static_cast<T*>(mask_data), batch_size, sequence_length, total_sequence_length);
@@ -85,6 +87,7 @@
T* present_key_data = present_key != nullptr ? present_key->MutableData<T>() : nullptr;
const T* past_value_data = past_value != nullptr ? past_value->Data<T>() : nullptr;
T* present_value_data = present_value != nullptr ? present_value->MutableData<T>() : nullptr;
T* scaled_qk_data = scaled_qk != nullptr ? scaled_qk->MutableData<T>() : nullptr;

const T* attn_bias_data = (attn_bias != nullptr) ? attn_bias->Data<T>() : nullptr;
auto attn_bias_dims = (attn_bias != nullptr) ? attn_bias->Shape().GetDims() : gsl::span<const int64_t>{};
@@ -97,7 +100,7 @@
static_cast<T*>(mask_data),
batch_size, sequence_length, kv_sequence_length, past_sequence_length,
qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data,
present_data, present_key_data, tp, scale, attn_bias_data, attn_bias_dims);
present_data, present_key_data, tp, scale, attn_bias_data, attn_bias_dims, scaled_qk_data);

// Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
auto out_tmp_data =
@@ -117,23 +120,24 @@
// 1 x mask_data(B, N, S, T)
// attention_probs(B, N, S, T) = Softmax(attention_probs)
template <typename T>
void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
T* mask_data, // buffer for mask data.
int batch_size, // batch size of self-attention
int sequence_length, // sequence length of self-attention (S)
int kv_sequence_length, // sequence length of cross-attention (L)
int past_sequence_length, // sequence length of past state
int head_size, // head size of self-attention
const T* past, // past state
const T* past_key, // past key only (if not using past state)
T* present, // present state
T* present_key, // present key only (if not using present state)
ThreadPool* tp, // thread pool
float scale, // scale factor
const T* attn_bias_data, // attention bias
gsl::span<const int64_t> attn_bias_dims // attention bias shape
void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
T* mask_data, // buffer for mask data.
int batch_size, // batch size of self-attention
int sequence_length, // sequence length of self-attention (S)
int kv_sequence_length, // sequence length of cross-attention (L)
int past_sequence_length, // sequence length of past state
int head_size, // head size of self-attention
const T* past, // past state
const T* past_key, // past key only (if not using past state)
T* present, // present state
T* present_key, // present key only (if not using present state)
ThreadPool* tp, // thread pool
float scale, // scale factor
const T* attn_bias_data, // attention bias
gsl::span<const int64_t> attn_bias_dims, // attention bias shape
T* scaled_qk_data = nullptr // scaled output QK buffer
) const {
const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L
const size_t past_chunk_length = static_cast<size_t>(past_sequence_length) * head_size; // P x H
@@ -230,6 +234,12 @@
});
}

if (scaled_qk_data != nullptr) {
// Output the scaled Q*K^T if needed.
memcpy(scaled_qk_data, attention_probs,
SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T));
}

DUMP_CPU_TENSOR("QK (scaled)", attention_probs, batch_size, num_heads_, sequence_length, total_sequence_length);

// attention_probs(B, N, S, T) = Softmax(attention_probs)
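The memcpy added above snapshots attention_probs just before the softmax, i.e. the scaled Q*K^T logits with any mask/bias already added, stored row-major as (batch_size, num_heads, sequence_length, total_sequence_length); for single-token decoding sequence_length is 1, which lines up with the qk output shape documented earlier in this commit. A small, assumed-layout sketch of reading one logit back from that buffer (not ONNX Runtime code):

```cpp
// Assumed-layout sketch: read the pre-softmax logit for batch b, head n,
// query position s, key position t from a scaled_qk buffer stored row-major
// as (batch_size, num_heads, sequence_length, total_sequence_length).
#include <cstddef>

template <typename T>
T ReadScaledQk(const T* scaled_qk_data, int b, int n, int s, int t,
               int num_heads, int sequence_length, int total_sequence_length) {
  const std::size_t idx =
      ((static_cast<std::size_t>(b) * num_heads + n) * sequence_length + s) *
          total_sequence_length + t;
  return scaled_qk_data[idx];
}
```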

0 comments on commit 10cd027
