Skip to content

Commit

Permalink
[ORT 1.18.0 Release] Cherry pick 3rd/Final round (microsoft#20677)
Browse files Browse the repository at this point in the history
Co-authored-by: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com>
Co-authored-by: rachguo <rachguo@rachguos-Mac-mini.local>
Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net>
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: George Wu <jywu@microsoft.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Co-authored-by: Jian Chen <cjian@microsoft.com>
  • Loading branch information
8 people authored May 15, 2024
1 parent ed349b9 commit 4573740
Show file tree
Hide file tree
Showing 28 changed files with 892 additions and 1,290 deletions.
4 changes: 4 additions & 0 deletions csharp/OnnxRuntime.CSharp.proj
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ CMake creates a target to this project
<PropertyGroup>
<!-- If we create multiple nuget packages in one job, major package and dependent packages version should be the same-->
<!-- CurrentDate and CurrentTime are only used for dev packages-->
<CurrentDate Condition=" '$(BuildDate)'!='' ">$(BuildDate)</CurrentDate>
<CurrentTime Condition=" '$(BuildTime)'!='' ">$(BuildTime)</CurrentTime>
<CurrentDate Condition="'$(CurrentDate)'==''">$([System.DateTime]::UtcNow.ToString(yyyyMMdd))</CurrentDate>
<CurrentTime Condition="'$(CurrentTime)'==''">$([System.DateTime]::UtcNow.ToString(hhmm))</CurrentTime>


</PropertyGroup>

<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
Expand Down
46 changes: 33 additions & 13 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -5553,11 +5553,29 @@ This version of the operator has been available since version 1 of the 'com.micr
When number of sparse layout is 1, all heads have same sparse layout. Otherwise, different layouts are used cyclically.
For example, given 4 layouts (S0, S1, S2, S3), 8 heads will have layouts like (S0, S1, S2, S3, S0, S1, S2, S3).

Padding shall be on the right side.
The block_row_indices and block_col_indices are the CSR representation of block mask. The block_col_indices might contain
paddings at the right side when different layout has different number of non-zeros in block mask.

When do_rotary is True, cos_cache and sin_cache are required.
An example of block mask with 2 layouts where each layout is 4 x 4 blocks:
[[[1, 0, 0, 0],
[1, 1, 0, 0],
[0, 1, 1, 0],
[0, 1, 1, 1]],

[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 0, 1, 1]]]

The corresponding CSR format:
block_col_indices = [[0, 0, 1, 1, 2, 1, 2, 3, -1], [0, 0, 1, 0, 1, 2, 0, 2, 3]]
block_row_indices = [[0, 1, 3, 5, 8], [0, 1, 3, 6, 9]]

When do_rotary is True, cos_cache and sin_cache are required. Note that the maximum sequence length supported by cos
or sin cache can be different from the maximum sequence length used by kv cache.

Only supports unidirectional attention with cache of past key and value in linear buffers.

For performance, past_key and present_key share same memory buffer, and past_value and present_value too.

#### Version
Expand All @@ -5581,7 +5599,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Number of tokens per sparse block. Choices: 16, 32, 64, 128</dd>
</dl>

#### Inputs (8 - 10)
#### Inputs (9 - 11)

<dl>
<dt><tt>query</tt> : T</dt>
Expand All @@ -5590,20 +5608,22 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Key with shape (batch_size, sequence_length, kv_num_heads * head_size)</dd>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Value with shape (batch_size, sequence_length, kv_num_heads * head_size)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>Key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)</dd>
<dt><tt>past_value</tt> (optional) : T</dt>
<dd>Value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)</dd>
<dt><tt>block_mask</tt> : M</dt>
<dd>block mask. 1 indicates attention and 0 no attention. Its shape is (num_layout, max_blocks, max_blocks), where num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.</dd>
<dt><tt>past_key</tt> : T</dt>
<dd>Key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)</dd>
<dt><tt>past_value</tt> : T</dt>
<dd>Value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)</dd>
<dt><tt>block_row_indices</tt> : M</dt>
<dd>The row indices of CSR format of block mask with shape (num_layout, max_blocks + 1).The num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.</dd>
<dt><tt>block_col_indices</tt> : M</dt>
<dd>The col indices of CSR format of block mask with shape (num_layout, max_nnz_blocks).The max_nnz_blocks is the maximum number of non-zeros per layout in block mask.</dd>
<dt><tt>total_sequence_length</tt> : M</dt>
<dd>Scalar tensor of maximum total sequence length (past_sequence_length + sequence_length) among keys.</dd>
<dt><tt>key_total_sequence_lengths</tt> : M</dt>
<dd>1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings.</dd>
<dt><tt>cos_cache</tt> (optional) : T</dt>
<dd>Cos cache of rotary with shape (max_sequence_length, head_size / 2).</dd>
<dd>Cos cache of rotary with shape (max_rotary_sequence_length, head_size / 2).</dd>
<dt><tt>sin_cache</tt> (optional) : T</dt>
<dd>Sin cache of rotary with shape (max_sequence_length, head_size / 2).</dd>
<dd>Sin cache of rotary with shape (max_rotary_sequence_length, head_size / 2).</dd>
</dl>

#### Outputs
Expand All @@ -5612,9 +5632,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, num_heads * head_size)</dd>
<dt><tt>present_key</tt> : T</dt>
<dd>Updated key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).</dd>
<dd>Updated key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).</dd>
<dt><tt>present_value</tt> : T</dt>
<dd>Updated value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).</dd>
<dd>Updated value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).</dd>
</dl>

#### Type Constraints
Expand Down
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ Do not modify directly.*
|SkipGroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *in* skip:**T**<br> *in* bias:**T**<br> *out* Y:**T**<br> *out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SparseAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* block_mask:**M**<br> *in* total_sequence_length:**M**<br> *in* key_total_sequence_lengths:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|SparseAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* block_row_indices:**M**<br> *in* block_col_indices:**M**<br> *in* total_sequence_length:**M**<br> *in* key_total_sequence_lengths:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|UnfoldTensor|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,15 @@ struct SparseAttentionParameters {
bool rotary_interleaved; // whether to use interleaved rotary embedding
int rotary_dim; // rotary embedding dimension
int sparse_block_size; // block size for sparse attention
int num_sparse_layout; // number of sparse layout, or the first dimension of block_mask
int num_sparse_layout; // number of sparse layout
int stride_col_indices; // shape of block_col_indices is [num_sparse_layout, stride_col_indices]
int stride_row_indices; // shape of block_row_indices is [num_sparse_layout, stride_row_indices]
float scale; // scaling factor applied prior to softmax
bool is_packed_qkv; // whether qkv is packed
int total_sequence_length; // maximum total sequence length (past_sequence_length + sequence_length) among keys
int max_sequence_length; // max sequence length allowed
int max_sequence_length; // max sequence length for sparse layout
int max_rotary_sequence_length; // max sequence length for rotary cos/sin cache
int max_cache_sequence_length; // max sequence length for kv cache buffer
bool past_present_share_buffer; // whether past_key and present_key share buffer, so is past_value and present_value
};

Expand Down
63 changes: 22 additions & 41 deletions onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "contrib_ops/cuda/sparse/sparse_attention_impl.h"
#include "contrib_ops/cuda/sparse/sparse_attention.h"
#include "contrib_ops/cuda/sparse/sparse_attention_helper.h"
#include "contrib_ops/cuda/sparse/block_mask.h"
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h"
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h"
#include "core/platform/env_var_utils.h"
Expand All @@ -26,7 +25,7 @@ namespace cuda {
.TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()) \
.MayInplace(3, 1) \
.MayInplace(4, 2) \
.InputMemoryType(OrtMemTypeCPUInput, 6), \
.InputMemoryType(OrtMemTypeCPUInput, 7), \
SparseAttention<T>);

REGISTER_KERNEL_TYPED(MLFloat16)
Expand Down Expand Up @@ -77,15 +76,16 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* value = context->Input<Tensor>(2);
const Tensor* past_key = context->Input<Tensor>(3);
const Tensor* past_value = context->Input<Tensor>(4);
const Tensor* block_mask = context->Input<Tensor>(5);
const Tensor* total_seq_len = context->Input<Tensor>(6);
const Tensor* seqlens_k_total = context->Input<Tensor>(7);
const Tensor* cos_cache = context->Input<Tensor>(8);
const Tensor* sin_cache = context->Input<Tensor>(9);
const Tensor* block_row_indices = context->Input<Tensor>(5);
const Tensor* block_col_indices = context->Input<Tensor>(6);
const Tensor* total_seq_len = context->Input<Tensor>(7);
const Tensor* seqlens_k_total = context->Input<Tensor>(8);
const Tensor* cos_cache = context->Input<Tensor>(9);
const Tensor* sin_cache = context->Input<Tensor>(10);

SparseAttentionParameters parameters;

// Parameters from node attribute
// Parameters from node attribute shall be set before calling CheckInputs
parameters.sparse_block_size = sparse_block_size_;
parameters.num_heads = num_heads_;
parameters.kv_num_heads = kv_num_heads_;
Expand All @@ -101,7 +101,8 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
past_value,
cos_cache,
sin_cache,
block_mask,
block_row_indices,
block_col_indices,
seqlens_k_total,
total_seq_len));
// Some limitations of CUDA kernels
Expand Down Expand Up @@ -177,7 +178,7 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
Tensor* output = context->Output(0, output_shape);

std::vector<int64_t> present_dims = {
parameters.batch_size, parameters.kv_num_heads, parameters.max_sequence_length, parameters.head_size};
parameters.batch_size, parameters.kv_num_heads, parameters.max_cache_sequence_length, parameters.head_size};
TensorShape present_shape(present_dims);
Tensor* present_key = context->Output(1, present_shape);
Tensor* present_value = context->Output(2, present_shape);
Expand All @@ -188,13 +189,12 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
data.block_mask = block_mask->Data<int32_t>();
data.past_key = reinterpret_cast<const CudaT*>(past_key->Data<T>());
data.past_value = reinterpret_cast<const CudaT*>(past_value->Data<T>());
data.seqlens_k_total = (nullptr == seqlens_k_total) ? nullptr : seqlens_k_total->Data<int32_t>();
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
data.present_key = reinterpret_cast<CudaT*>(present_key->MutableData<T>());
data.present_value = reinterpret_cast<CudaT*>(present_value->MutableData<T>());

// Check past and present share buffer.
parameters.past_present_share_buffer = (data.past_key != nullptr && data.past_key == data.present_key);
Expand All @@ -214,45 +214,26 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
// Currently, we use same block size in kernel.
// TODO: support kernel block size that is smaller than sparse_block_size in tunable (need expand block mask).
data.kernel_layout.block_size = parameters.sparse_block_size;
data.kernel_layout.mask = data.block_mask;
data.kernel_layout.num_layout = parameters.num_sparse_layout;
data.kernel_layout.num_cols = parameters.max_sequence_length / data.kernel_layout.block_size;
data.kernel_layout.num_rows = parameters.max_sequence_length / data.kernel_layout.block_size;

// Allocate buffer for CSR col and row indices.
onnxruntime::Stream* stream = context->GetComputeStream();
int dense_blocks = data.kernel_layout.num_layout * data.kernel_layout.num_cols * data.kernel_layout.num_rows;
auto csr_col_indices_buffer = GetScratchBuffer<int>(static_cast<size_t>(dense_blocks), stream);
auto csr_row_indices_buffer = GetScratchBuffer<int>(
static_cast<size_t>(data.kernel_layout.num_layout * (data.kernel_layout.num_rows + 1)), stream);

data.kernel_layout.csr_col_indices = reinterpret_cast<const int*>(csr_col_indices_buffer.get());
data.kernel_layout.csr_row_indices = reinterpret_cast<const int*>(csr_row_indices_buffer.get());

ConvertMaskToCSR(cuda_stream,
data.kernel_layout.mask,
data.kernel_layout.num_layout,
data.kernel_layout.num_rows,
data.kernel_layout.num_cols,
csr_row_indices_buffer.get(),
csr_col_indices_buffer.get(),
device_prop.maxThreadsPerBlock);
data.kernel_layout.csr_col_indices = block_col_indices->Data<int32_t>();
data.kernel_layout.csr_row_indices = block_row_indices->Data<int32_t>();

size_t rotary_buffer_bytes = 0;
if (do_rotary_) {
rotary_buffer_bytes = 2 * sizeof(T) * parameters.batch_size * parameters.num_heads *
parameters.sequence_length * parameters.head_size;
rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length;
}
auto rotary_buffer = GetScratchBuffer<void>(rotary_buffer_bytes, context->GetComputeStream());
onnxruntime::Stream* stream = context->GetComputeStream();
auto rotary_buffer = GetScratchBuffer<void>(rotary_buffer_bytes, stream);
data.rotary_buffer = reinterpret_cast<CudaT*>(rotary_buffer.get());

size_t transposed_q_bytes = 0;
if (!parameters.is_packed_qkv) {
transposed_q_bytes = parameters.batch_size * parameters.sequence_length *
parameters.num_heads * parameters.head_size * sizeof(T);
}
auto transposed_q_buffer = GetScratchBuffer<void>(transposed_q_bytes, context->GetComputeStream());
auto transposed_q_buffer = GetScratchBuffer<void>(transposed_q_bytes, stream);
if (transposed_q_buffer) {
data.transposed_q_buffer = reinterpret_cast<CudaT*>(transposed_q_buffer.get());
}
Expand All @@ -263,7 +244,7 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
(parameters.num_heads + 2 * parameters.kv_num_heads) *
parameters.head_size * sizeof(T));
}
auto unpacked_qkv_buffer = GetScratchBuffer<void>(unpacked_qkv_bytes, context->GetComputeStream());
auto unpacked_qkv_buffer = GetScratchBuffer<void>(unpacked_qkv_bytes, stream);
if (unpacked_qkv_buffer) {
data.unpacked_qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
}
Expand Down Expand Up @@ -327,7 +308,7 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
}

v2_kernel_buffer = GetScratchBuffer<int>(v2_kernel_buffer_size, context->GetComputeStream());
v2_kernel_buffer = GetScratchBuffer<int>(v2_kernel_buffer_size, stream);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(v2_kernel_buffer.get(), v2_kernel_inputs_pinned,
sizeof(int32_t) * v2_kernel_buffer_size,
cudaMemcpyHostToDevice, cuda_stream));
Expand Down
Loading

0 comments on commit 4573740

Please sign in to comment.