Commit e066015

Merge branch 'main' of https://github.com/microsoft/onnxruntime into disablefp8

xadupre committed Nov 17, 2023
2 parents: cf102d8 + 1a29460
Showing 41 changed files with 1,156 additions and 657 deletions.
4 changes: 1 addition & 3 deletions cmake/CMakeLists.txt
@@ -114,9 +114,7 @@ option(onnxruntime_ENABLE_LTO "Enable link time optimization" OFF)
option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF)
option(onnxruntime_GCOV_COVERAGE "Compile with options necessary to run code coverage" OFF)
option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF)

#It's preferred to turn it OFF when onnxruntime is dynamically linked to PROTOBUF. But TensorRT always requires the full version of protobuf.
cmake_dependent_option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF "NOT onnxruntime_USE_TENSORRT" ON)
option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF)
option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir")
option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF)
option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF)
8 changes: 5 additions & 3 deletions docs/ContribOperators.md
@@ -2385,7 +2385,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

Group Query Self/Cross Attention.

Supports different number of heads for q and kv.
Supports different number of heads for q and kv. Only supports causal or local attention.

#### Version

@@ -2396,6 +2396,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dl>
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for k and v</dd>
<dt><tt>local_window_size</tt> : int</dt>
<dd>left_window_size for local attention (like Mistral). Default value is -1 meaning unused.</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q</dd>
<dt><tt>scale</tt> : float</dt>
@@ -5021,7 +5023,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

<dl>
<dt><tt>input</tt> : T</dt>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)</dd>
<dt><tt>position_ids</tt> : M</dt>
<dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
<dt><tt>cos_cache</tt> : T</dt>
@@ -5034,7 +5036,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dd>tensor with same shape as input.</dd>
</dl>

#### Type Constraints
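A note on the new `local_window_size` attribute: it is easiest to read as a mask predicate. Below is a minimal TypeScript sketch of the causal-plus-local masking being documented; the function name `isAttended`, its signature, and the exact window inclusivity are illustrative assumptions, not part of the operator spec.

```ts
// Which key position a query position may attend to under causal attention
// with an optional left window (local_window_size). -1 means no window,
// i.e. plain causal attention; a positive value gives sliding-window
// (local) attention as in Mistral.
const isAttended = (queryPos: number, keyPos: number, localWindowSize: number): boolean => {
  if (keyPos > queryPos) {
    return false;  // causal: future key positions are never visible
  }
  if (localWindowSize === -1) {
    return true;  // default: full causal attention
  }
  // local attention: only the last `localWindowSize` keys to the left are visible
  return queryPos - keyPos <= localWindowSize;
};

// A query at position 10 with a left window of 4 sees keys 6..10 only:
console.log(isAttended(10, 7, 4));   // true
console.log(isAttended(10, 5, 4));   // false
console.log(isAttended(10, 12, 4));  // false (future position)
```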
10 changes: 9 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -646,6 +646,8 @@ export const outputVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, false, components);

export type UniformsArrayType = Array<{name: string; type: string}>;

/**
* A ShaderHelper is a helper class for generating WGSL code.
*/
@@ -697,6 +699,7 @@ export interface ShaderHelper {
* A helper function to register one uniform. Can be called multiple times to register multiple uniforms.
*/
registerUniform(name: string, type: string): ShaderHelper;
registerUniforms(nameToTypeMap: UniformsArrayType): ShaderHelper;
}

class ShaderHelperImpl implements ShaderHelper {
@@ -755,8 +758,13 @@ class ShaderHelperImpl implements ShaderHelper {
return this;
}

registerUniforms(additionalUniforms: UniformsArrayType): ShaderHelper {
this.uniforms = this.uniforms.concat(additionalUniforms);
return this;
}

private indicesHelpers: IndicesHelper[] = [];
private uniforms: Array<{name: string; type: string}> = [];
private uniforms: UniformsArrayType = [];
private uniformDeclaration(): string {
if (this.uniforms.length === 0) {
return '';
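The new `registerUniforms` is a convenience over repeated `registerUniform` calls. A short sketch of how a kernel might use it together with `UniformsArrayType`; the import path, shapes, and `1 /* float32 */` data type are illustrative assumptions:

```ts
import {inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';

// Register two uniforms in one call; equivalent to chaining
// .registerUniform('outputSize', 'u32').registerUniform('starts', 'vec4<u32>').
const uniforms: UniformsArrayType = [
  {name: 'outputSize', type: 'u32'},
  {name: 'starts', type: 'vec4<u32>'},
];

const input = inputVariable('input', 1 /* float32 */, [2, 3, 4]);
const output = outputVariable('output', 1 /* float32 */, [2, 3, 4]);

const getShaderSource = (shaderHelper: ShaderHelper) => `
  ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
    // ...kernel body elided...
  }`;
```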
81 changes: 60 additions & 21 deletions js/web/lib/wasm/jsep/webgpu/ops/slice.ts
@@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, TensorInfo} from '../types';
import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types';

import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';

export interface SliceAttributes extends AttributeWithCacheKey {
readonly starts: number[];
@@ -77,17 +77,26 @@ const fixStartEndValues =
};

const calculateInputIndicesImpl =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]):
string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[],
enableInputShapeUniforms: boolean): string =>
`fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
var inputIndices: ${input.type.indices};
var carry = 0u;
for (var i = ${inputShape.length}; i >= 0; i--) {
let input_shape_i = ${
enableInputShapeUniforms ? `uniforms.input_shape${inputShape.length > 1 ? '[i]' : ''}` : 'inputShape[i]'};
let steps_i = ${
enableInputShapeUniforms ? `uniforms.steps${inputShape.length > 1 ? '[i]' : ''}` : 'steps[i]'};
let signs_i = ${
enableInputShapeUniforms ? `uniforms.signs${inputShape.length > 1 ? '[i]' : ''}` : 'signs[i]'};
let starts_i = ${
enableInputShapeUniforms ? `uniforms.starts${inputShape.length > 1 ? '[i]' : ''}` : 'starts[i]'};
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
var inputIndex = outputIndex * steps[i] + starts[i] + carry;
carry = inputIndex / inputShape[i];
inputIndex = inputIndex % inputShape[i];
if (signs[i] < 0) {
inputIndex = inputShape[i] - inputIndex - 1u + starts[i];
var inputIndex = outputIndex * steps_i + starts_i + carry;
carry = inputIndex / input_shape_i;
inputIndex = inputIndex % input_shape_i;
if (signs_i < 0) {
inputIndex = input_shape_i - inputIndex - 1u + starts_i;
}
${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex;
}
@@ -110,6 +119,10 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: SliceAttributes

const ends = attributes.ends.map((end, i) => fixStartEndValues(end, i, inputShape, axes, steps));

if (axes.length !== starts.length || axes.length !== ends.length) {
throw new Error('start, ends and axes should have the same number of elements');
}

if (axes.length !== inputShape.length) {
for (let i = 0; i < inputShape.length; ++i) {
if (!axes.includes(i)) {
@@ -131,40 +144,66 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: SliceAttributes
array[i] = -step;
}
});
// Output rank is expected to be less than or equal to the input rank.
const enableShapeUniforms = enableShapesUniforms(inputs[0].dims.length);
const inputShapeOrRank = enableShapeUniforms ? inputs[0].dims.length : inputs[0].dims;

const outputShape = inputShape.slice(0);
axes.forEach((axis, _) => {
outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]);
});
const outputShapeOrRank = enableShapeUniforms ? outputShape.length : outputShape;

const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType};

const output = outputVariable('output', inputs[0].dataType, outputShape);
const input = inputVariable('input', inputs[0].dataType, inputShape);
const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank);
const input = inputVariable('input', inputs[0].dataType, inputShapeOrRank);
const outputSize = ShapeUtil.size(outputShape);
const programUniforms: ProgramUniform[] = [];
const uniforms: UniformsArrayType = [];
if (enableShapeUniforms) {
uniforms.push({name: 'starts', type: starts.length > 1 ? `vec${starts.length}<u32>` : 'u32'});
uniforms.push({name: 'signs', type: signs.length > 1 ? `vec${signs.length}<i32>` : 'i32'});
uniforms.push({name: 'steps', type: steps.length > 1 ? `vec${steps.length}<u32>` : 'u32'});
programUniforms.push({type: 'uint32', data: starts});
programUniforms.push({type: 'int32', data: signs});
programUniforms.push({type: 'uint32', data: steps});
}
uniforms.push({name: 'outputSize', type: 'u32'});
programUniforms.push({type: 'uint32', data: outputSize});
if (enableShapeUniforms) {
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
programUniforms.push(...createTensorShapeVariables(outputShape));
}

const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(input, output)}
const signs = array<i32, ${signs.length}>(${signs.map(i => `${i}i`).join(',')});
const starts = array<u32, ${starts.length}>(${starts.map(i => `${i}u`).join(',')});
const ends = array<u32, ${ends.length}>(${ends.map(i => `${i}u`).join(',')});
const steps = array<u32, ${steps.length}>(${steps.map(i => `${i}u`).join(',')});
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
${calculateInputIndicesImpl(input, output, inputShape, outputShape)}
${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
${enableShapeUniforms ? '' : [
`const signs = array<i32, ${signs.length}>(${signs.map(i => `${i}i`).join(',')});`,
`const starts = array<u32, ${starts.length}>(${starts.map(i => `${i}u`).join(',')});`,
`const steps = array<u32, ${steps.length}>(${steps.map(i => `${i}u`).join(',')});`,
`const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});`
].join('\n')}
${calculateInputIndicesImpl(input, output, inputShape, outputShape, enableShapeUniforms)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
let outputIndices = ${output.offsetToIndices('global_idx')};
let inputIndices = calculateInputIndices(outputIndices);
${output.setByOffset('global_idx', input.getByIndices('inputIndices'))}
}`;
return {
name: 'Slice',
shaderCache: {hint: `${attributes.cacheKey}|${inputs[4]?.dims ?? ''}`},
shaderCache: {
hint: enableShapeUniforms ? `${signs.length}_${starts.length}_${steps.length}` :
`${attributes.cacheKey} | ${inputs[4]?.dims ?? ''}`,
inputDependencies: [enableShapeUniforms ? 'rank' : 'dims']
},
getShaderSource,
getRunData: () => ({
outputs: [outputTensorInfo],
dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)},
programUniforms
})
};
};
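For reference, the index mapping that `calculateInputIndices` emits in WGSL can be mirrored in plain TypeScript. A hypothetical CPU-side sketch of the same per-axis arithmetic (not part of the kernel, only an illustration of the math):

```ts
// For each axis i: inputIndex = outputIndex * steps[i] + starts[i] (+ carry),
// and a negative sign flips the axis, mirroring the WGSL above.
const calculateInputIndices = (
    outputIndices: number[], inputShape: number[], starts: number[],
    steps: number[], signs: number[]): number[] => {
  const inputIndices = new Array<number>(inputShape.length).fill(0);
  let carry = 0;
  for (let i = inputShape.length - 1; i >= 0; i--) {
    let inputIndex = outputIndices[i] * steps[i] + starts[i] + carry;
    carry = Math.floor(inputIndex / inputShape[i]);
    inputIndex = inputIndex % inputShape[i];
    if (signs[i] < 0) {
      inputIndex = inputShape[i] - inputIndex - 1 + starts[i];
    }
    inputIndices[i] = inputIndex;
  }
  return inputIndices;
};

// e.g. slicing axis 1 of a [2, 5] tensor with starts=[0, 3], steps=[1, 1]:
// output index [0, 0] reads input index [0, 3].
console.log(calculateInputIndices([0, 0], [2, 5], [0, 3], [1, 1], [1, 1]));  // [0, 3]
```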
23 changes: 23 additions & 0 deletions js/web/test/data/ops/slice.jsonc
@@ -21,6 +21,29 @@
}
]
},
{
"name": "Slice float32 with input[0] dim > 4",
"operator": "Slice",
"attributes": [],
"cases": [
{
"name": "T[1, 1, 1, 1, 5] T[1] T[1] T[1] (float32)",
"inputs": [
{
"data": [
0.3964604139328003, -0.8916832804679871, -1.6578896045684814, 1.960708737373352, 1.181204915046692
],
"dims": [1, 1, 1, 1, 5],
"type": "float32"
},
{ "data": [3], "dims": [1], "type": "int64" },
{ "data": [4], "dims": [1], "type": "int64" },
{ "data": [4], "dims": [1], "type": "int64" }
],
"outputs": [{ "data": [1.960708737373352], "dims": [1, 1, 1, 1, 1], "type": "float32" }]
}
]
},
{
"name": "Slice int32",
"operator": "Slice",
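The expected output of the new 5-D case is easy to verify by hand, since only the last axis is sliced; a quick TypeScript check:

```ts
// Slice the last axis of the [1, 1, 1, 1, 5] tensor from start 3 to end 4 (step 1).
const data = [
  0.3964604139328003, -0.8916832804679871, -1.6578896045684814, 1.960708737373352, 1.181204915046692
];
console.log(data.slice(3, 4));  // [1.960708737373352], dims [1, 1, 1, 1, 1]
```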
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -96,9 +96,9 @@ struct GroupQueryAttentionParameters {
int kv_num_heads;
int num_splits; // number of splits for splitkv
bool is_unidirectional; // causal
int local_window_size;
bool kv_share_buffer;
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
bool left_padding; // copies last token to last index if true
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
17 changes: 13 additions & 4 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -63,6 +63,16 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
const int head_size = parameters.head_size;
const int position_ids_format = parameters.position_ids_format;
const int half_head_size = head_size / 2;
// Default input tensor shape is [batch, seq_len, hidden_size]
int head_stride = head_size;
int seq_stride = num_heads * head_stride;
int batch_stride = sequence_length * seq_stride;
if (parameters.transposed) {
// Transposed input tensor shape is [batch, num_heads, seq_len, head_size]
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
batch_stride = num_heads * head_stride;
}

AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
@@ -76,11 +86,10 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
const int s = static_cast<int>((ptr / num_heads) % sequence_length);
const int n = static_cast<int>(ptr % num_heads);

const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
const int data_offset = block_offset * head_size;
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;

const T* input_data = input_src + data_offset;
T* output_data = output_dest + data_offset;
const T* input_data = input_src + block_offset;
T* output_data = output_dest + block_offset;

// Cache is (M, H/2)
const int position_id = (position_ids_format == 0)
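The stride bookkeeping above generalizes the previous `block_offset * head_size` computation to both layouts. A TypeScript sketch of the same offset math (illustrative helper, mirroring the C++ above):

```ts
// Flat offset of element (b, s, n) for the two supported layouts.
// Default:    input is [batch, seq_len, num_heads * head_size]
// Transposed: input is [batch, num_heads, seq_len, head_size]
const blockOffset = (
    b: number, s: number, n: number, sequenceLength: number,
    numHeads: number, headSize: number, transposed: boolean): number => {
  let headStride = headSize;
  let seqStride = numHeads * headStride;
  let batchStride = sequenceLength * seqStride;
  if (transposed) {
    seqStride = headSize;
    headStride = sequenceLength * seqStride;
    batchStride = numHeads * headStride;
  }
  return b * batchStride + s * seqStride + n * headStride;
};

// Both layouts address the same logical element, just laid out differently:
// batch 0, seq position 2, head 1 with seq_len=4, num_heads=8, head_size=64.
console.log(blockOffset(0, 2, 1, 4, 8, 64, false));  // 2*512 + 1*64  = 1088
console.log(blockOffset(0, 2, 1, 4, 8, 64, true));   // 2*64  + 1*256 = 384
```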
16 changes: 13 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
@@ -18,6 +18,7 @@ struct RotaryParameters {
int num_heads; // num_heads = hidden_size / head_size
int max_sequence_length; // Sequence length used by cos/sin cache
int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
};

template <typename T>
@@ -33,8 +34,8 @@ Status CheckInputs(const T* input,

// Check input
const auto& input_dims = input->Shape().GetDims();
if (input_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ",
if (input_dims.size() != 3 && input_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ",
input_dims.size());
}
// Check position_ids
@@ -63,6 +64,14 @@ Status CheckInputs(const T* input,
int batch_size = static_cast<int>(input_dims[0]);
int sequence_length = static_cast<int>(input_dims[1]);
int hidden_size = static_cast<int>(input_dims[2]);

bool transposed = false;
if (input_dims.size() == 4) {
// input is [batch, num_heads, seq, head_size]
sequence_length = static_cast<int>(input_dims[2]);
hidden_size = static_cast<int>(input_dims[1]) * static_cast<int>(input_dims[3]);
transposed = true;
}
int max_sequence_length = static_cast<int>(cos_cache_dims[0]);
int head_size = static_cast<int>(cos_cache_dims[1]) * 2;
int num_heads = hidden_size / head_size;
@@ -111,11 +120,12 @@ Status CheckInputs(const T* input,
output_parameters->num_heads = num_heads;
output_parameters->max_sequence_length = max_sequence_length;
output_parameters->position_ids_format = position_ids_format;
output_parameters->transposed = transposed;
}

return Status::OK();
}

} // namespace rotary_embedding_helper
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime
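In summary, `CheckInputs` now derives its parameters from the input rank. A hedged TypeScript sketch of that shape handling (the helper name `deriveRotaryDims` is hypothetical):

```ts
// Derive rotary parameters from the input shape: 3-D is (batch, seq, hidden),
// 4-D is (batch, num_heads, seq, head_size), matching the helper above.
const deriveRotaryDims = (dims: number[]) => {
  if (dims.length !== 3 && dims.length !== 4) {
    throw new Error(`Input 'x' is expected to have 3 or 4 dimensions, got ${dims.length}`);
  }
  const batchSize = dims[0];
  let sequenceLength = dims[1];
  let hiddenSize = dims[2];
  let transposed = false;
  if (dims.length === 4) {
    sequenceLength = dims[2];
    hiddenSize = dims[1] * dims[3];  // num_heads * head_size
    transposed = true;
  }
  return {batchSize, sequenceLength, hiddenSize, transposed};
};

console.log(deriveRotaryDims([2, 8, 16, 64]));
// { batchSize: 2, sequenceLength: 16, hiddenSize: 512, transposed: true }
```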
15 changes: 15 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
@@ -69,6 +69,7 @@ struct Flash_fwd_params : public Qkv_params {
int seqlen_q_rounded = 0;
int seqlen_k_rounded = 0;
int d_rounded = 0;
int rotary_dim = 0;

// The scaling factors for the kernel.
float scale_softmax = 0.0;
@@ -92,12 +93,26 @@ struct Flash_fwd_params : public Qkv_params {
index_t knew_head_stride = 0;
index_t vnew_head_stride = 0;

// The cos and sin matrices for rotary embedding.
void* __restrict__ rotary_cos_ptr = nullptr;
void* __restrict__ rotary_sin_ptr = nullptr;

// The indices to index into the KV cache.
int* __restrict__ cache_batch_idx = nullptr;

// Local window size
int window_size_left = -1;
int window_size_right = -1;

bool is_bf16 = false;
bool is_causal = false;

// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative = true;

bool is_rotary_interleaved = false;

int num_splits = 0; // For split-KV version

const cudaDeviceProp* dprops = nullptr;
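The `is_seqlens_k_cumulative` comment describes two encodings of per-batch K lengths; a small TypeScript sketch of the difference (hypothetical helper):

```ts
// If cumulative, cu_seqlens_k holds prefix sums and the length of batch b is
// the difference of adjacent entries; otherwise it holds the lengths directly.
const seqlenK = (cuSeqlensK: number[], bidb: number, isCumulative: boolean): number =>
    isCumulative ? cuSeqlensK[bidb + 1] - cuSeqlensK[bidb] : cuSeqlensK[bidb];

// Cumulative form: batches of length 5, 3, 7 stored as prefix sums [0, 5, 8, 15].
console.log(seqlenK([0, 5, 8, 15], 1, true));  // 3
// Direct form: the same lengths stored as-is.
console.log(seqlenK([5, 3, 7], 1, false));     // 3
```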