[CPU] SparseAttention op (microsoft#21110)
Add a SparseAttention CPU implementation.
- [x] Refactor GQAAttentionBase
- [x] Add SparseAttention implementation
- [x] Add test cases

This is the unfused version; a flash attention version will be added later. A minimal sketch of the unfused block-sparse computation follows below.
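For orientation, here is a minimal sketch of the unfused block-sparse attention this op computes, for a single head. It is not the kernel's actual code: the function name, single-head layout, and prompt-phase position assumption (query i sits at key position i) are illustrative. The block mask is CSR over row/column blocks, matching the operator's block_row_indices / block_col_indices inputs, with causal masking assumed to be baked into the block layout.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

void BlockSparseAttentionOneHead(const float* Q,   // [seq_len, head_size]
                                 const float* K,   // [total_seq_len, head_size]
                                 const float* V,   // [total_seq_len, head_size]
                                 const std::vector<int>& block_row_indices,
                                 const std::vector<int>& block_col_indices,
                                 int seq_len, int total_seq_len, int head_size,
                                 int sparse_block_size,
                                 float scale,      // typically 1/sqrt(head_size)
                                 float* out) {     // [seq_len, head_size]
  const float neg_inf = -std::numeric_limits<float>::infinity();
  std::vector<float> scores(total_seq_len);
  for (int i = 0; i < seq_len; ++i) {
    // 1) QK^T row, computed only for columns whose block the CSR mask keeps.
    std::fill(scores.begin(), scores.end(), neg_inf);
    const int row_block = i / sparse_block_size;
    for (int b = block_row_indices[row_block]; b < block_row_indices[row_block + 1]; ++b) {
      const int j0 = block_col_indices[b] * sparse_block_size;
      const int j1 = std::min(j0 + sparse_block_size, total_seq_len);
      for (int j = j0; j < j1; ++j) {
        float dot = 0.0f;
        for (int h = 0; h < head_size; ++h) dot += Q[i * head_size + h] * K[j * head_size + h];
        scores[j] = dot * scale;
      }
    }
    // 2) Softmax over the row; masked entries stay at -inf and get zero weight.
    const float mx = *std::max_element(scores.begin(), scores.end());
    if (mx == neg_inf) {  // row attends to nothing
      std::fill_n(out + i * head_size, head_size, 0.0f);
      continue;
    }
    float sum = 0.0f;
    for (float& s : scores) { s = std::exp(s - mx); sum += s; }
    // 3) Weighted sum over V.
    for (int h = 0; h < head_size; ++h) {
      float acc = 0.0f;
      for (int j = 0; j < total_seq_len; ++j) acc += scores[j] * V[j * head_size + h];
      out[i * head_size + h] = acc / sum;
    }
  }
}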
tianleiwu authored Jul 4, 2024
1 parent 30b6e82 commit 7d9b12a
Showing 16 changed files with 1,034 additions and 165 deletions.
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
@@ -5646,7 +5646,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Type Constraints

<dl>
-<dt><tt>T</tt> : tensor(float16), tensor(bfloat16)</dt>
+<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output to float tensors.</dd>
<dt><tt>M</tt> : tensor(int32)</dt>
<dd>Constrain integer type.</dd>
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -512,6 +512,7 @@ Do not modify directly.*
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
+|SparseAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* block_row_indices:**M**<br> *in* block_col_indices:**M**<br> *in* total_sequence_length:**M**<br> *in* key_total_sequence_lengths:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
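The SparseAttention row above takes its sparse layout in CSR form over blocks (block_row_indices / block_col_indices). A hypothetical helper that emits a plain causal (lower-triangular) block mask in that format, for illustration only:

#include <cstdint>
#include <vector>

// block_row_indices gets num_blocks + 1 entries; block_col_indices lists the
// kept column blocks of each row block, concatenated row by row.
void MakeCausalBlockMask(int num_blocks,
                         std::vector<int32_t>& block_row_indices,
                         std::vector<int32_t>& block_col_indices) {
  block_row_indices.assign(1, 0);
  block_col_indices.clear();
  for (int r = 0; r < num_blocks; ++r) {
    for (int c = 0; c <= r; ++c) {
      block_col_indices.push_back(c);  // row block r attends to column blocks 0..r
    }
    block_row_indices.push_back(static_cast<int32_t>(block_col_indices.size()));
  }
}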
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -68,7 +68,6 @@ class AttentionBase {
const Tensor* past_seq_len = nullptr) const;

int num_heads_; // number of attention heads
-  int kv_num_heads_;  // different for k and v for group query attention
bool is_unidirectional_; // whether every token can only attend to previous tokens.
std::vector<int64_t> qkv_hidden_sizes_; // Q, K, V hidden sizes parsed from the qkv_hidden_sizes attribute.
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
5 changes: 2 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -3,9 +3,8 @@

#pragma once

#include "attention_base.h"
#include "attention_helper.h"

#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cpu/bert/attention_helper.h"
#include "core/common/common.h"
#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"
31 changes: 24 additions & 7 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -3,8 +3,8 @@

#pragma once

#include "attention_base.h"
#include "attention_helper.h"
#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cpu/bert/attention_helper.h"

#include "core/common/common.h"
#include "contrib_ops/cpu/bert/attention_common.h"
@@ -14,14 +14,31 @@
namespace onnxruntime {
namespace contrib {

-class GQAAttentionBase : public AttentionBase {
+class GQAAttentionBase {
 protected:
-  GQAAttentionBase(const OpKernelInfo& info, bool require_same_hidden_size)
-      : AttentionBase(info, require_same_hidden_size) {}
-
-  int local_window_size_;
-  bool do_rotary_;
+  GQAAttentionBase(const OpKernelInfo& info, bool has_local) {
+    int64_t num_heads = 0;
+    ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
+    num_heads_ = static_cast<int>(num_heads);
+
+    int64_t kv_num_heads = 0;
+    ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
+    kv_num_heads_ = static_cast<int>(kv_num_heads);
+
+    scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
+
+    do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
+    rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
+
+    local_window_size_ = has_local ? static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1)) : -1;
+  }
+
+  int num_heads_;            // number of attention heads of Q
+  int kv_num_heads_;         // number of attention heads of K or V
+  float scale_;              // the scaling factor applied before softmax
+  bool do_rotary_;           // whether or not to use rotary embeddings
+  bool rotary_interleaved_;
+  int local_window_size_;

template <typename T>
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
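With this refactor the shared attribute parsing lives in GQAAttentionBase itself, so a derived CPU kernel only forwards the OpKernelInfo and states whether it supports a local (sliding) window. A hypothetical sketch of such a derivation (SparseAttention's own class lives in files not shown in this excerpt):

#include "core/framework/op_kernel.h"
#include "contrib_ops/cpu/bert/gqa_attention_base.h"

namespace onnxruntime {
namespace contrib {

// Hypothetical derived kernel; the real GroupQueryAttention below does
// exactly this with has_local=true.
template <typename T>
class MyGQAStyleAttention final : public OpKernel, public GQAAttentionBase {
 public:
  explicit MyGQAStyleAttention(const OpKernelInfo& info)
      : OpKernel(info), GQAAttentionBase(info, /*has_local=*/false) {}

  // num_heads_, kv_num_heads_, scale_, do_rotary_, rotary_interleaved_ and
  // local_window_size_ are already populated by the base constructor.
  Status Compute(OpKernelContext* context) const override;
};

}  // namespace contrib
}  // namespace onnxruntime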
42 changes: 16 additions & 26 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -1,11 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "group_query_attention.h"
#include "group_query_attention_helper.h"
#include "attention_utils.h"
#include "rotary_embedding.h"
#include "rotary_embedding_helper.h"
#include "contrib_ops/cpu/bert/group_query_attention.h"
#include "contrib_ops/cpu/bert/group_query_attention_helper.h"
#include "contrib_ops/cpu/bert/rotary_helper.h"
#include "contrib_ops/cpu/bert/attention_utils.h"
#include "contrib_ops/cpu/bert/rotary_embedding.h"
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"

#include "core/framework/tensorprotoutils.h"
#include "core/graph/onnx_protobuf.h"
@@ -33,19 +34,8 @@ GroupQueryAttention<float>);
GroupQueryAttention<float>);

template <typename T>
-GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info) : OpKernel(info), GQAAttentionBase(info, false) {
-  int64_t num_heads = 0;
-  int64_t kv_num_heads = 0;
-  ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
-  ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
-  num_heads_ = static_cast<int>(num_heads);
-  kv_num_heads_ = static_cast<int>(kv_num_heads);
-
-  mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
-  local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
-  do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
-  rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
-}
+GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
+    : OpKernel(info), GQAAttentionBase(info, true) {}

template <typename T>
Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
@@ -174,14 +164,14 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
if (packed_qkv) {
const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size;
T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size;
-      ORT_RETURN_IF_ERROR(group_query_attention_helper::PackVIntoRotaryQKV<T>(tp,
-                                                                              parameters.batch_size,
-                                                                              parameters.sequence_length,
-                                                                              parameters.num_heads,
-                                                                              parameters.kv_num_heads,
-                                                                              parameters.head_size,
-                                                                              v_input,
-                                                                              v_rotary));
+      ORT_RETURN_IF_ERROR(rotary_helper::PackVIntoRotaryQKV<T>(tp,
+                                                               parameters.batch_size,
+                                                               parameters.sequence_length,
+                                                               parameters.num_heads,
+                                                               parameters.kv_num_heads,
+                                                               parameters.head_size,
+                                                               v_input,
+                                                               v_rotary));
}
}

32 changes: 0 additions & 32 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -263,38 +263,6 @@ Status CheckInputs(const Tensor* query,

return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale);
}

-template <typename T>
-Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp,
-                          int batch_size,
-                          int sequence_length,
-                          int num_heads,
-                          int kv_num_heads,
-                          int head_size,
-                          const T* input,
-                          T* output) {
-  int seq_stride = head_size;
-  int head_stride = sequence_length * seq_stride;
-  int batch_stride = (num_heads + 2 * kv_num_heads) * head_stride;
-
-  const int loop_len = batch_size * sequence_length * kv_num_heads;
-  const double cost = static_cast<double>(head_size);
-  ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
-    for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
-      const int b = static_cast<int>((ptr / kv_num_heads) / sequence_length);
-      const int s = static_cast<int>((ptr / kv_num_heads) % sequence_length);
-      const int n = static_cast<int>(ptr % kv_num_heads);
-      const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
-      const T* input_data = input + block_offset;
-      T* output_data = output + block_offset;
-      for (int i = 0; i < head_size; i++) {
-        output_data[i] = input_data[i];
-      }
-    }
-  });
-  return Status::OK();
-}

} // namespace group_query_attention_helper
} // namespace contrib
} // namespace onnxruntime
47 changes: 47 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/rotary_helper.h
@@ -0,0 +1,47 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/providers/common.h"
+#include "contrib_ops/cpu/bert/attention_common.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace rotary_helper {
+
+template <typename T>
+Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp,
+                          int batch_size,
+                          int sequence_length,
+                          int num_heads,
+                          int kv_num_heads,
+                          int head_size,
+                          const T* input,
+                          T* output) {
+  int seq_stride = head_size;
+  int head_stride = sequence_length * seq_stride;
+  int batch_stride = (num_heads + 2 * kv_num_heads) * head_stride;
+
+  const int loop_len = batch_size * sequence_length * kv_num_heads;
+  const double cost = static_cast<double>(head_size);
+  ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+    for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
+      const int b = static_cast<int>((ptr / kv_num_heads) / sequence_length);
+      const int s = static_cast<int>((ptr / kv_num_heads) % sequence_length);
+      const int n = static_cast<int>(ptr % kv_num_heads);
+      const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
+      const T* input_data = input + block_offset;
+      T* output_data = output + block_offset;
+      for (int i = 0; i < head_size; i++) {
+        output_data[i] = input_data[i];
+      }
+    }
+  });
+  return Status::OK();
+}
+
+}  // namespace rotary_helper
+}  // namespace contrib
+}  // namespace onnxruntime
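The strides above imply a packed QKV layout of [batch, num_heads + 2 * kv_num_heads, sequence, head]: Q occupies the first num_heads head slots, K the next kv_num_heads, V the last kv_num_heads. A small standalone exercise of the helper under that assumption (sizes arbitrary; tp = nullptr makes TryParallelFor run the loop inline):

#include <vector>
#include "contrib_ops/cpu/bert/rotary_helper.h"

void DemoPackV() {
  const int B = 1, S = 4, N = 8, Nkv = 2, H = 16;
  std::vector<float> packed(B * (N + 2 * Nkv) * S * H, 1.0f);
  std::vector<float> rotary(packed.size(), 0.0f);
  const float* v_in = packed.data() + (N + Nkv) * S * H;  // V starts after Q and K
  float* v_out = rotary.data() + (N + Nkv) * S * H;
  // Copy V through unchanged; the rotary kernels rewrite Q and K separately.
  onnxruntime::Status s = onnxruntime::contrib::rotary_helper::PackVIntoRotaryQKV<float>(
      nullptr /*tp: run inline*/, B, S, N, Nkv, H, v_in, v_out);
  ORT_ENFORCE(s.IsOK());
}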
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -21,6 +21,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
@@ -281,6 +282,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>,
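With the float kernel registered above, a model containing a com.microsoft SparseAttention node loads on the default CPU execution provider with no extra providers appended. A minimal check (model path hypothetical; on Windows the path argument is wide, ORTCHAR_T):

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "sparse_attention_cpu"};
  Ort::SessionOptions opts;  // no execution providers appended -> CPU EP
  Ort::Session session{env, "model_with_sparse_attention.onnx", opts};
  return 0;
}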