[WIP] DecoderMaskedMultiHeadAttention CPU kernel. #22292

Draft · wants to merge 1 commit into base: main
10 changes: 5 additions & 5 deletions docs/ContribOperators.md
@@ -1175,9 +1175,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>past_sequence_length</tt> (optional) : M</dt>
<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).Cross Attention doesn't need this input.</dd>
<dt><tt>beam_width</tt> (optional) : M</dt>
<dd>The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.</dd>
<dd>The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.</dd>
<dt><tt>cache_indirection</tt> (optional) : M</dt>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration</dd>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration</dd>
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
</dl>
@@ -1192,7 +1192,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>present_value</tt> (optional) : T</dt>
<dd>present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
<dt><tt>qk</tt> (optional) : V</dt>
<dd>normalized Q * K, of shape (batch_size, num_heads, 1, head_size). </dd>
<dd>normalized Q * K, of shape (batch_size, num_heads, 1, total_sequence_length). </dd>
</dl>

#### Type Constraints
@@ -1261,9 +1261,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>past_sequence_length</tt> : M</dt>
<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).</dd>
<dt><tt>beam_width</tt> (optional) : M</dt>
<dd>The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.</dd>
<dd>The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.</dd>
<dt><tt>cache_indirection</tt> (optional) : M</dt>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration</dd>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration</dd>
</dl>

#### Outputs
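For context on the `cache_indirection` input documented in the hunks above: during beam search, the buffer records which source beam's KV cache each generated token belongs to. Below is a minimal sketch of the lookup, assuming the documented `[batch_size, beam_width, max_output_length]` layout; the helper name and arguments are illustrative only and not part of this PR.

```cpp
#include <cstdint>
#include <cstddef>

// Illustrative helper (not part of this PR): returns the index of the beam whose
// KV cache holds the `step`-th generated token for (batch, beam), assuming the
// documented [batch_size, beam_width, max_output_length] layout.
inline int32_t BeamSourceForToken(const int32_t* cache_indirection,
                                  int batch, int beam, int step,
                                  int beam_width, int max_output_length) {
  const std::size_t offset =
      (static_cast<std::size_t>(batch) * beam_width + beam) * max_output_length + step;
  return cache_indirection[offset];
}
```

The decoder can then read past keys and values for that token from the cache slot of the returned beam rather than from its own slot.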
39 changes: 39 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -79,6 +79,45 @@ struct AttentionParameters {
AttentionQkvFormat qkv_format;
};

struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters {
int beam_width = 1;

// Only NeoX style rotary embedding is supported
int rotary_embedding_dim = 0;
int t_step = 0;

// Whether to use multihead attention (excludes matmul and bias)
bool is_mha = false;
bool is_cross_attention = false;
bool is_packed_qkv = false;

// Useful to better use global memory bandwidth on certain CUDA architectures.
// Turned off by default for now until we fully understand performance implications
// for all types of workloads.
// Can be turned on by appropriate environment variable (see attention_common.h).
bool kv_data_in_flight = false;

void* q = nullptr;
void* q_bias = nullptr;

void* k = nullptr;
void* k_bias = nullptr;

void* v = nullptr;
void* v_bias = nullptr;

void* attention_bias = nullptr;

void* k_cache = nullptr;
void* v_cache = nullptr;

void* out = nullptr;
void* out_qk = nullptr;

const int32_t* cache_indir = nullptr;
const int32_t* mask = nullptr; // [B, total_sequence_length]
};

// Parameters deduced from node attributes and inputs/outputs.
struct PackedAttentionParameters {
int batch_size;
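To make the new DecoderMaskedMultiHeadAttentionParams struct above concrete, here is a hedged sketch of how a caller might populate the decoding-specific fields for one self-attention decode step. The wrapper function and its arguments are assumptions for illustration; only the member names come from the struct in this diff.

```cpp
#include <cstdint>

// Illustrative only: wire up the fields added for masked decoder attention.
DecoderMaskedMultiHeadAttentionParams MakeDecodeStepParams(
    void* q, void* k_cache, void* v_cache, void* out, void* out_qk,
    const int32_t* cache_indir, const int32_t* mask,
    int beam_width, int t_step) {
  DecoderMaskedMultiHeadAttentionParams p{};
  p.beam_width = beam_width;    // 1 when beam search is not used
  p.t_step = t_step;            // current decoding step
  p.is_cross_attention = false; // self-attention over the shared KV cache
  p.q = q;                      // query for the newly generated token
  p.k_cache = k_cache;          // key cache (past_present_share_buffer layout)
  p.v_cache = v_cache;          // value cache
  p.cache_indir = cache_indir;  // [batch_size, beam_width, max_output_length]
  p.mask = mask;                // [batch_size, total_sequence_length]
  p.out = out;                  // attention output
  p.out_qk = out_qk;            // optional: scaled Q*K^T before softmax
  return p;
}
```

The base-class fields inherited from AttentionParameters (batch size, head counts, scale, and so on) would still need to be filled in from the node's inputs, as for the existing attention kernels.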
50 changes: 30 additions & 20 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -36,7 +36,9 @@
int v_head_size, // head size of V (H_v)
int v_hidden_size, // hidden size of V (D_v)
const Tensor* attn_bias, // additive bias applied on scaled QK.
OpKernelContext* context) const {
OpKernelContext* context,
Tensor* scaled_qk = nullptr // output buffer for QK (if needed)
) const {

Check warning on line 41 in onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h (GitHub Actions / Optional Lint C++, [cpplint] reported by reviewdog 🐶): Closing ) should be moved to the previous line [whitespace/parens] [2]
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

@@ -71,7 +73,7 @@

if (mask_data != nullptr) {
// Convert mask from boolean (0/1) to float (mask_filter_value/0.0f).
// Merge padding mask with causual mask, and broadcast to 3D (BxSxT).
// Merge padding mask with causal mask, and broadcast to 3D (BxSxT).
PrepareMask(mask_index_data, mask_index_dims, static_cast<T*>(mask_data),
causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
DUMP_CPU_TENSOR("Mask3D", static_cast<T*>(mask_data), batch_size, sequence_length, total_sequence_length);
@@ -85,6 +87,7 @@
T* present_key_data = present_key != nullptr ? present_key->MutableData<T>() : nullptr;
const T* past_value_data = past_value != nullptr ? past_value->Data<T>() : nullptr;
T* present_value_data = present_value != nullptr ? present_value->MutableData<T>() : nullptr;
T* scaled_qk_data = scaled_qk != nullptr ? scaled_qk->MutableData<T>() : nullptr;

const T* attn_bias_data = (attn_bias != nullptr) ? attn_bias->Data<T>() : nullptr;
auto attn_bias_dims = (attn_bias != nullptr) ? attn_bias->Shape().GetDims() : gsl::span<const int64_t>{};
@@ -97,7 +100,7 @@
static_cast<T*>(mask_data),
batch_size, sequence_length, kv_sequence_length, past_sequence_length,
qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data,
present_data, present_key_data, tp, scale, attn_bias_data, attn_bias_dims);
present_data, present_key_data, tp, scale, attn_bias_data, attn_bias_dims, scaled_qk_data);

// Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
auto out_tmp_data =
@@ -117,23 +120,24 @@
// 1 x mask_data(B, N, S, T)
// attention_probs(B, N, S, T) = Softmax(attention_probs)
template <typename T>
void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
T* mask_data, // buffer for mask data.
int batch_size, // batch size of self-attention
int sequence_length, // sequence length of self-attention (S)
int kv_sequence_length, // sequence length of cross-attention (L)
int past_sequence_length, // sequence length of past state
int head_size, // head size of self-attention
const T* past, // past state
const T* past_key, // past key only (if not using past state)
T* present, // present state
T* present_key, // present key only (if not using present state)
ThreadPool* tp, // thread pool
float scale, // scale factor
const T* attn_bias_data, // attention bias
gsl::span<const int64_t> attn_bias_dims // attention bias shape
void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
T* mask_data, // buffer for mask data.
int batch_size, // batch size of self-attention
int sequence_length, // sequence length of self-attention (S)
int kv_sequence_length, // sequence length of cross-attention (L)
int past_sequence_length, // sequence length of past state
int head_size, // head size of self-attention
const T* past, // past state
const T* past_key, // past key only (if not using past state)
T* present, // present state
T* present_key, // present key only (if not using present state)
ThreadPool* tp, // thread pool
float scale, // scale factor
const T* attn_bias_data, // attention bias
gsl::span<const int64_t> attn_bias_dims, // attention bias shape
T* scaled_qk_data = nullptr // scaled output QK buffer
) const {
const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L
const size_t past_chunk_length = static_cast<size_t>(past_sequence_length) * head_size; // P x H
@@ -230,6 +234,12 @@
});
}

if (scaled_qk_data != nullptr) {
// Output the scaled Q*K^T if needed.
memcpy(scaled_qk_data, attention_probs,
SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T));
}

DUMP_CPU_TENSOR("QK (scaled)", attention_probs, batch_size, num_heads_, sequence_length, total_sequence_length);

// attention_probs(B, N, S, T) = Softmax(attention_probs)
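The ComputeAttentionProbs change above snapshots the scaled Q*K^T logits into scaled_qk_data just before softmax, which appears to be what backs the operator's optional qk output. Below is a standalone, simplified sketch of that idea for a single batch and head, ignoring the mask, attention bias, and past/present caches handled by the real kernel; the function and buffer names here are assumptions, not ORT code.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>

// Simplified illustration: compute scaled Q*K^T, optionally copy it out
// (the role of scaled_qk_data in the diff), then apply row-wise softmax in place.
void ScaledDotProductProbs(const float* Q, const float* K,  // Q: [S, H], K: [T, H]
                           float* probs,                    // [S, T] attention probabilities
                           float* scaled_qk,                // optional [S, T] snapshot, may be nullptr
                           int S, int T, int H, float scale) {
  for (int s = 0; s < S; ++s) {
    for (int t = 0; t < T; ++t) {
      float dot = 0.0f;
      for (int h = 0; h < H; ++h) dot += Q[s * H + h] * K[t * H + h];
      probs[s * T + t] = scale * dot;  // scaled logits: what the qk output exposes
    }
  }
  if (scaled_qk != nullptr) {
    // Mirrors the memcpy added in ComputeAttentionProbs: copy before softmax.
    std::memcpy(scaled_qk, probs, sizeof(float) * static_cast<std::size_t>(S) * T);
  }
  for (int s = 0; s < S; ++s) {
    float max_logit = probs[s * T];
    for (int t = 1; t < T; ++t) max_logit = std::max(max_logit, probs[s * T + t]);
    float sum = 0.0f;
    for (int t = 0; t < T; ++t) {
      probs[s * T + t] = std::exp(probs[s * T + t] - max_logit);
      sum += probs[s * T + t];
    }
    for (int t = 0; t < T; ++t) probs[s * T + t] /= sum;
  }
}
```

For the single-token decode case documented above, S is 1 and T is the total sequence length, which matches the corrected qk output shape (batch_size, num_heads, 1, total_sequence_length).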