From 8c59cd4fce3c81d758256f2f13232ddd17bc8534 Mon Sep 17 00:00:00 2001
From: Xu Xing
Date: Tue, 14 May 2024 00:43:37 +0800
Subject: [PATCH] [js/webgpu] Support GroupQueryAttention (#20237)

TODOs:
1. Handle H * params.kvNumHeads greater than the workgroup size limit.
2. Support BNSH kv cache.

---
 js/web/docs/webgpu-operators.md               |   1 +
 .../lib/wasm/jsep/webgpu/op-resolve-rules.ts  |   2 +
 js/web/lib/wasm/jsep/webgpu/ops/attention.ts  |  33 +-
 .../jsep/webgpu/ops/group-query-attention.ts  | 346 ++++++++++
 .../jsep/webgpu/ops/multihead-attentiion.ts   |   2 +-
 js/web/lib/wasm/jsep/webgpu/ops/tile.ts       |   4 +-
 .../test/data/ops/group-query-attention.jsonc | 616 ++++++++++++++++++
 js/web/test/suite-test-list.jsonc             |   1 +
 .../js/bert/group_query_attention.cc          |  24 +
 .../js/bert/group_query_attention.h           |  43 ++
 .../contrib_ops/js/js_contrib_kernels.cc      |   2 +
 11 files changed, 1059 insertions(+), 15 deletions(-)
 create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
 create mode 100644 js/web/test/data/ops/group-query-attention.jsonc
 create mode 100644 onnxruntime/contrib_ops/js/bert/group_query_attention.cc
 create mode 100644 onnxruntime/contrib_ops/js/bert/group_query_attention.h

diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md
index c93f4f3cce68..3af4942c2e4a 100644
--- a/js/web/docs/webgpu-operators.md
+++ b/js/web/docs/webgpu-operators.md
@@ -54,6 +54,7 @@ Do not modify directly.*
 | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
 | Greater | ai.onnx(7-8,9-12,13+) | |
 | GreaterOrEqual | ai.onnx(12-15,16+) | |
+| GroupQueryAttention | com.microsoft(1+) | |
 | HardSigmoid | ai.onnx(6+) | |
 | If | ai.onnx(1-10,11-12,13-18,19+) | |
 | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
index 7ec0d6e0ff64..78e4871dcec8 100644
--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -18,6 +18,7 @@ import {fastGelu} from './ops/fast-gelu';
 import {gather, parseGatherAttributes} from './ops/gather';
 import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
 import {gemm, parseGemmAttributes} from './ops/gemm';
+import {groupQueryAttention, parseGroupQueryAttentionAttributes} from './ops/group-query-attention';
 import {instanceNorm} from './ops/instance-norm';
 import {layerNorm} from './ops/layer-norm';
 import {matMul} from './ops/matmul';
@@ -88,6 +89,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
   ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
   ['Greater', [binaryOps.greater]],
   ['GreaterOrEqual', [binaryOps.greaterOrEqual]],
+  ['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]],
   ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
   ['InstanceNormalization', [instanceNorm]],
   ['LayerNormalization', [layerNorm]],
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
index 94ad67d3c3b0..fda2ff64b0ac 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -46,20 +46,24 @@ export interface AttentionParameters {
   headSize: number;
   vHeadSize: number;
   numHeads: number;
-  isUnidirectional: boolean;
+  kvNumHeads?: number;
+  nReps?: number;
+  isUnidirectional?: boolean;
   pastPresentShareBuffer: boolean;
-  maskFilterValue: number;
+  maskFilterValue?: number;
   maskType: AttentionMaskType;
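+  // When 0, the attention kernel falls back to the default scale 1/sqrt(headSize).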
   scale: number;
   broadcastResPosBias: boolean;
   passPastInKv: boolean;
   qkvFormat: AttentionQkvFormat;
+  isPastkvBSNH?: boolean;
 }

 export interface AttentionAttrs {
   numHeads: number;
-  isUnidirectional: number;
-  maskFilterValue: number;
+  kvNumHeads?: number;
+  isUnidirectional?: number;
+  maskFilterValue?: number;
   scale: number;
   doRotary: number;
   qkvHiddenSizes: number[];
@@ -443,17 +447,20 @@ const createVxAttentionScoreProgramInfo =
     (_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters,
      pastSequenceLength: number) => {
       const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
-      const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
+      const nReps = params.nReps ? params.nReps : 1;
+      const repeatedVHiddenSize = params.vHiddenSize * nReps;
+      const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize];
       const TILE_SIZE = 12;
       const dispatch = {
         x: Math.ceil(params.vHeadSize / TILE_SIZE),
         y: Math.ceil(params.sequenceLength / TILE_SIZE),
         z: params.batchSize * params.numHeads
       };
+
       const programUniforms: ProgramUniform[] = [
         {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength},
         {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
-        {type: DataType.uint32, data: params.vHiddenSize}
+        {type: DataType.uint32, data: repeatedVHiddenSize}
       ];
       const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
@@ -524,20 +531,22 @@ export const applyAttention =
      relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
       const outputPresentKey = context.outputCount > 1;
       const outputPresentValue = context.outputCount > 2;
-      const pastSequenceLength = (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
+      const pastSequenceLength =
+          parameters.kvNumHeads != null || (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
       const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;

       // Concatenate pastKey and K to produce presentKey.
       const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
       const concatKeyInputs = pastKey ? [pastKey, k] : [k];
-      const key = outputPresentKey ? context.compute(
-                      createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
-                      {inputs: concatKeyInputs, outputs: [1]})[0] :
-                                     k;
+      const key = parameters.kvNumHeads == null && outputPresentKey ?
+          context.compute(
+              createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
+              {inputs: concatKeyInputs, outputs: [1]})[0] :
+          k;

       // Concatenate pastValue and V to produce presentValue.
       const presentValueShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
       const concatValueInputs = pastValue ? [pastValue, v] : [v];
-      const value = outputPresentValue ?
+      const value = parameters.kvNumHeads == null && outputPresentValue ?
           context.compute(
               createConcatProgramInfo(concatValueInputs, 2, presentValueShape, v.dataType),
               {inputs: concatValueInputs, outputs: [2]})[0] :
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
new file mode 100644
index 000000000000..d03820591057
--- /dev/null
+++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -0,0 +1,346 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+import {DataType} from '../../../wasm-common';
+import {TensorView} from '../../tensor-view';
+import {ShapeUtil} from '../../util';
+import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
+import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
+
+import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention';
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
+import {maybeTransposeToBNSHAndAddBias} from './multihead-attentiion';
+import {createTileProgramInfo} from './tile';
+import {createTransposeProgramInfo, TransposeAttributes} from './transpose';
+
+export const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
+  const query = inputs[0];
+  const key = inputs[1];
+  const value = inputs[2];
+  const pastKey = inputs[3];
+  const pastValue = inputs[4];
+
+  // Abbreviations and meanings:
+  // B: batch_size
+  // S: sequence_length (input sequence length of query)
+  // P: past_sequence_length (past sequence length of key or value)
+  // L: kv_sequence_length (input sequence length of key or value)
+  // M: max_sequence_length
+  // T: total_sequence_length = past_sequence_length + kv_sequence_length
+  // N: num_heads
+  // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size
+  // H_v: v_head_size
+  // D_i: input hidden size
+  // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size
+  // D_v: v_hidden_size = num_heads * v_head_size
+
+  // past_key : (B, N, S*, H)
+  // past_value : (B, N, S*, H)
+  // When no packing for q/k/v:
+  //   query (Q) : (B, S, D)
+  //   key (K) : (B, L, D) or (B, N, S*, H)
+  //   value (V) : (B, L, D_v) or (B, N, S*, H)
+  // When packed kv is used:
+  //   query (Q) : (B, S, D)
+  //   key (K) : (B, L, N, 2, H)
+  //   value (V) : None
+  // When packed qkv is used:
+  //   query (Q) : (B, L, N, 3, H) or (B, S, 3*D)
+  //   key (K) : None
+  //   value (V) : None
+
+  if (query.dims.length !== 3 && query.dims.length !== 5) {
+    throw new Error('Input "query" is expected to have 3 or 5 dimensions');
+  }
+
+  const dmmhaPacking = false;
+  const batchSize = query.dims[0];
+  const sequenceLength = query.dims[1];
+  const hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) :
+                                               attributes.numHeads * query.dims[4];
+  let kvSequenceLength = sequenceLength;
+
+  let pastSequenceLength = 0;
+  let maxSequenceLength = 0;
+  const headSize = Math.floor(hiddenSize / attributes.numHeads);
+  const hasPastKey = pastKey && pastKey.dims.length !== 0;
+  const hasPastValue = pastValue && pastValue.dims.length !== 0;
+  // TODO: this should come from attributes.
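+  // kv cache layouts: BSNH stores past kv as (batch, past_sequence_length, kv_num_heads, head_size);
+  // BNSH stores it as (batch, kv_num_heads, past_sequence_length, head_size). Only BSNH is handled below.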
+  const isPastkvBSNH = true;
+  if (hasPastKey && hasPastValue) {
+    if (pastKey.dims.length !== 4) {
+      throw new Error('Input "past_key" is expected to have 4 dimensions');
+    }
+    if (pastValue.dims.length !== 4) {
+      throw new Error('Input "past_value" is expected to have 4 dimensions');
+    }
+    if (isPastkvBSNH) {
+      // For BSNH
+      pastSequenceLength = pastKey.dims[1];
+      maxSequenceLength = pastKey.dims[1];
+    } else {
+      // For BNSH
+      pastSequenceLength = pastKey.dims[2];
+      maxSequenceLength = pastKey.dims[2];
+    }
+  } else if (hasPastKey || hasPastValue) {
+    throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
+  }
+
+  let qkvFormat: AttentionQkvFormat;
+  if (key) {
+    if (query.dims.length !== 3) {
+      throw new Error('Input "query" is expected to have 3 dimensions when key is given');
+    }
+    if (key.dims.length < 3 || key.dims.length > 5) {
+      throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions');
+    }
+    if (query.dims[0] !== key.dims[0]) {
+      throw new Error('Input "query" and "key" shall have same dim 0 (batch size)');
+    }
+
+    if (key.dims.length === 3) {
+      if (query.dims[2] % key.dims[2] !== 0) {
+        throw new Error('Dimension 2 of "query" should be a multiple of "key"');
+      }
+      qkvFormat = AttentionQkvFormat.qkvBSNH;
+      kvSequenceLength = key.dims[1];
+    } else if (key.dims.length === 5) {
+      if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
+        throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv');
+      }
+      if (value) {
+        throw new Error('Expect "value" to be none when "key" has packed kv format.');
+      }
+      qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H;
+      kvSequenceLength = key.dims[1];
+    } else {  // key.dims.length === 4 (cross-attention with past_key)
+      if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) {
+        throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
+      }
+
+      qkvFormat = AttentionQkvFormat.unknown;
+      kvSequenceLength = key.dims[2];
+    }
+  } else {  // packed QKV
+    if (query.dims.length !== 3 && query.dims.length !== 5) {
+      throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty');
+    }
+    if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) {
+      throw new Error('Expect "query" shape (batch_size, sequence_length, num_heads, 3, head_size) for packed qkv');
+    }
+
+    qkvFormat = AttentionQkvFormat.qkvBSN3H;
+  }
+
+  const maskType: AttentionMaskType = AttentionMaskType.none;
+  let passPastInKv = false;
+  let vHiddenSize = hiddenSize;
+  if (value) {
+    if (value.dims.length !== 3 && value.dims.length !== 4) {
+      throw new Error('Input "value" is expected to have 3 or 4 dimensions');
+    }
+
+    if (query.dims[0] !== value.dims[0]) {
+      throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)');
+    }
+
+    if (value.dims.length === 3) {
+      if (kvSequenceLength !== value.dims[1]) {
+        throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)');
+      }
+      vHiddenSize = value.dims[2];
+    } else {
+      if (kvSequenceLength !== value.dims[2]) {
+        throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)');
+      }
+      vHiddenSize = value.dims[1] * value.dims[3];
+      passPastInKv = true;
+    }
+  }
+  const totalSequenceLength = pastSequenceLength + kvSequenceLength;
+  const broadcastResPosBias = false;
+
+  return {
+    batchSize,
+    sequenceLength,
+    pastSequenceLength,
+    kvSequenceLength,
+    totalSequenceLength,
+    maxSequenceLength,
+    inputHiddenSize: 0,
+    hiddenSize,
+    vHiddenSize,
+    headSize,
+    vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads!),
+    numHeads: attributes.numHeads,
+    kvNumHeads: attributes.kvNumHeads,
+    nReps: attributes.numHeads / attributes.kvNumHeads!,
+    pastPresentShareBuffer: false,
+    maskType,
+    scale: attributes.scale,
+    broadcastResPosBias,
+    passPastInKv,
+    qkvFormat,
+    isPastkvBSNH,
+  };
+};
+
+const createConcatProgramInfo =
+    (a: TensorView, b: TensorView|undefined, dataType: DataType, params: AttentionParameters): ProgramInfo => {
+      const outputShape = [params.batchSize, params.totalSequenceLength, params.kvNumHeads!, params.headSize];
+      const component = 4;
+      const outputSize = ShapeUtil.size(outputShape) / component;
+      const presentSequenceLength = params.totalSequenceLength;
+      const output = outputVariable('present_kv', dataType, outputShape.length, component);
+      const inputA = inputVariable('new_kv', a.dataType, a.dims.length, component);
+      const inputB = b ? inputVariable('past_kv', b.dataType, b.dims.length, component) : undefined;
+
+      const H = Math.ceil(params.headSize / component);
+      const dispatch = {x: presentSequenceLength, y: a.dims[0], z: 1};
+
+      const inputDependencies: ProgramInputTensorInfoDependency[] = b ? ['rank', 'rank'] : ['rank'];
+
+      const programUniforms: ProgramUniform[] = [
+        {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: params.pastSequenceLength},
+        {type: DataType.uint32, data: params.kvSequenceLength},
+        {type: DataType.uint32, data: params.totalSequenceLength}
+      ];
+
+      const inputs = [inputA];
+      if (inputB) {
+        programUniforms.push(
+            ...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(b!.dims),
+            ...createTensorShapeVariables(outputShape));
+        inputs.push(inputB);
+      } else {
+        programUniforms.push(...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(outputShape));
+      }
+      const uniforms: UniformsArrayType = [
+        {name: 'output_size', type: 'u32'}, {name: 'past_seqlen', type: 'u32'}, {name: 'new_seqlen', type: 'u32'},
+        {name: 'present_seqlen', type: 'u32'}
+      ];
+
+      const pastStr = `
+        let past_batch_stride = uniforms.past_seqlen * num_heads * H;
+        var past_head_stride = uniforms.past_seqlen * H;
+        if (is_bsnh) {
+          past_head_stride = H;
+        }
+        let in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h;
+        present_kv[out_offset] = past_kv[in_offset];`;
+      const newStr = `
+        let new_batch_stride = uniforms.new_seqlen * num_heads * H;
+        let new_row_stride = num_heads * H;
+        let new_head_stride = H;
+        let in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
+        present_kv[out_offset] = new_kv[in_offset];`;
+      const concatStr = b ? `if (s < past_seqlen) {
+          ${pastStr}
+        } else if (s < past_seqlen + uniforms.new_seqlen) {
+          ${newStr}
+        }` :
+                            `if (s < past_seqlen + uniforms.new_seqlen) {
+          ${newStr}
+        }`;
+
+      // TODO: handle H * params.kvNumHeads greater than the maxComputeInvocationsPerWorkgroup limit.
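+      // Each workgroup copies one (batch, sequence) position of present_kv: the workgroup size is
+      // (H, kvNumHeads, 1), where H = ceil(headSize / 4) vec4 columns, and the dispatch is
+      // (totalSequenceLength, batchSize, 1); hence H * kvNumHeads is bounded by the device's
+      // maxComputeInvocationsPerWorkgroup (256 by default in WebGPU).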
+      const getShaderSource = (shaderHelper: ShaderHelper) => `
+
+  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputs, output)}
+  ${shaderHelper.mainStart([
+        H, params.kvNumHeads!, 1
+      ])}
+    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
+    var indices = ${output.offsetToIndices('global_idx')};
+    let h = local_id.x;
+    let n = local_id.y;
+    let s = workgroup_id.x;
+    let b = workgroup_id.y;
+    let num_heads = ${params.kvNumHeads!}u;
+    let H = ${H}u;
+
+    let present_seqlen = uniforms.present_seqlen;
+    let present_batch_stride = present_seqlen * num_heads * H;
+    var row_stride = H;
+    let is_bsnh = ${params.isPastkvBSNH};
+
+    if (is_bsnh) {
+      row_stride = num_heads * H;
+    }
+    var present_head_stride = present_seqlen * H;
+    if (is_bsnh) {
+      present_head_stride = H;
+    }
+
+    let past_seqlen = uniforms.past_seqlen;
+
+    let out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
+    ${concatStr}
+  }`;
+
+      return {
+        name: 'ConcatPastNew',
+        shaderCache: {hint: `${params.kvNumHeads!}${H}${!!b}`, inputDependencies},
+        getRunData: () => ({
+          outputs: [{dims: outputShape, dataType}],
+          dispatchGroup: dispatch,
+          programUniforms,
+        }),
+        getShaderSource,
+      };
+    };
+
+export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs =>
+    createAttributeWithCacheKey({...attributes});
+
+const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]});
+
+const maybeExpandAndTransposeToBNSH =
+    (context: ComputeContext, input: TensorView, pastKV: TensorView|undefined, params: AttentionParameters,
+     outputIndex: number) => {
+      let reshapedInput = input;
+      const numHeads = params.kvNumHeads!;
+      const nReps = params.nReps!;
+      if (input.dims.length === 3 && params.kvSequenceLength !== 0) {
+        reshapedInput = input.reshape([params.batchSize, params.kvSequenceLength, numHeads, params.headSize]);
+      }
+
+      if (pastKV) {
+        reshapedInput = context.compute(
+            createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params),
+            {inputs: [reshapedInput, pastKV], outputs: [params.isPastkvBSNH ? outputIndex : -1]})[0];
+      } else {
+        reshapedInput = context.compute(
+            createConcatProgramInfo(reshapedInput, undefined, reshapedInput.dataType, params),
+            {inputs: [reshapedInput], outputs: [params.isPastkvBSNH ? outputIndex : -1]})[0];
+      }
+      if (nReps !== 1) {
+        reshapedInput = context.compute(
+            createTileProgramInfo([reshapedInput], [1, 1, 1, nReps]), {inputs: [reshapedInput], outputs: [-1]})[0];
+        reshapedInput =
+            reshapedInput.reshape([params.batchSize, params.totalSequenceLength, numHeads * nReps, params.headSize]);
+      }
+
+      return context.compute(
+          createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm),
+          {inputs: [reshapedInput], outputs: [-1]})[0];
+    };
+
+export const groupQueryAttention = (context: ComputeContext, attributes: AttentionAttrs): void => {
+  const params = validateInputs(context.inputs, attributes);
+  if (context.inputs[0].dims.length === 5) {
+    throw new Error('Packed QKV is not implemented');
+  }
+
+  if (context.inputs[1]?.dims.length === 5) {
+    throw new Error('Packed KV is not implemented');
+  }
+
+  const Q = maybeTransposeToBNSHAndAddBias(
+      context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, context.inputs[0], undefined,
+      0);
+  const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ?
+      context.inputs[3] : undefined;
+  const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
+  const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], pastKey, params, 1);
+  const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], pastValue, params, 2);
+  applyAttention(context, Q, K, V, undefined, undefined, undefined, undefined, undefined, params, attributes);
+};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attentiion.ts
index 4b18a41ccbeb..09fadea66fa1 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attentiion.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attentiion.ts
@@ -286,7 +286,7 @@ const addBiasTranspose =
       {inputs: [qkv, bias], outputs: [-1]})[0];
     };

-const maybeTransposeToBNSHAndAddBias =
+export const maybeTransposeToBNSHAndAddBias =
     (context: ComputeContext, batchSize: number, numHeads: number, sequenceLength: number, headSize: number,
      input: TensorView, bias?: TensorView, biasOffset?: number) => {
       // const newDims = [];
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
index f9728575fe07..d58d71a28c27 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
@@ -47,9 +47,9 @@ const getOutputShape = (inputShape: readonly number[], repeats: readonly number[
   return outputShape;
 };

-export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => {
+export const createTileProgramInfo = (inputs: readonly TensorView[], shape?: number[]): ProgramInfo => {
   const inputShape = inputs[0].dims;
-  const repeats: readonly number[] = getRepeats(inputs[1]);
+  const repeats: readonly number[] = shape == null ?
+      getRepeats(inputs[1]) : shape;
   const outputShape = getOutputShape(inputShape, repeats);
   const outputSize = ShapeUtil.size(outputShape);

diff --git a/js/web/test/data/ops/group-query-attention.jsonc b/js/web/test/data/ops/group-query-attention.jsonc
new file mode 100644
index 000000000000..2a4b26507845
--- /dev/null
+++ b/js/web/test/data/ops/group-query-attention.jsonc
@@ -0,0 +1,616 @@
+[
+  {
+    "name": "GroupQueryAttention Basic",
+    "operator": "GroupQueryAttention",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "num_heads", "data": 4, "type": "int" },
+      { "name": "kv_num_heads", "data": 2, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [
+              1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4,
+              8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4
+            ],
+            "dims": [1, 3, 16],
+            "type": "float32"
+          },
+          // key, BS*
+          {
+            "data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21],
+            "dims": [1, 3, 8],
+            "type": "float32"
+          },
+          // value, BS*
+          {
+            "data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21],
+            "dims": [1, 3, 8],
+            "type": "float32"
+          },
+          // past key, BS*
+          {
+            "data": null,
+            "type": "float32"
+          },
+          // past value, BS*
+          {
+            "data": null,
+            "type": "float32"
+          },
+          // seqlens_k, unimplemented
+          {
+            "data": [1],
+            "dims": [1],
+            "type": "int32"
+          },
+          // total_sequence_length, unimplemented
+          {
+            "data": [1],
+            "dims": [1],
+            "type": "int32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [
+              1, 1, 1, 1, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2,
+              131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21
+            ],
+            "dims": [1, 3, 16],
+            "type": "float32"
+          },
+          {
+            // present key, BS*
+            "data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21],
+            "dims": [1, 3, 2, 4],
+            "type": "float32"
+          },
+          {
+            // present value, BS*
+            "data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21],
+            "dims": [1, 3, 2, 4],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "GroupQueryAttention Scale",
+    "operator": "GroupQueryAttention",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "num_heads", "data": 4, "type": "int" },
+      { "name": "kv_num_heads", "data": 2, "type": "int" },
+      { "name": "scale", "data": 2.0, "type": "float" }
+    ],
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [
+              1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4
+            ],
+            "dims": [1, 4, 8],
+            "type": "float32"
+          },
+          {
+            "data": [1, 9, 1, 1, 2, 2, 2, 2],
+            "dims": [1, 2, 4],
+            "type": "float32"
+          },
+          {
+            "data": [1, 1, 1, 1, 2, 2, 2, 2],
+            "dims": [1, 2, 4],
+            "type": "float32"
+          },
+          // past key, BS*
+          {
+            "data": null,
+            "type": "float32"
+          },
+          // past value, BS*
+          {
+            "data": null,
+            "type": "float32"
+          },
+          // seqlens_k, unimplemented
+          {
+            "data": [1],
+            "dims": [1],
+            "type": "int32"
+          },
+          // total_sequence_length, unimplemented
+          {
+            "data": [1],
+            "dims": [1],
+            "type": "int32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [
+              1.000006079673767, 1.000006079673767, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1,
+              1, 1, 1, 1.9820137023925781, 1.9820137023925781, 1.9999991655349731, 1.9999991655349731
+            ],
+            "dims": [1, 4, 8],
+            "type": "float32"
+          },
+          {
+            // present key, BS*
+            "data": [1,
+              9, 1, 1, 2, 2, 2, 2],
+            "dims": [1, 2, 2, 2],
+            "type": "float32"
+          },
+          {
+            // present value, BS*
+            "data": [1, 1, 1, 1, 2, 2, 2, 2],
+            "dims": [1, 2, 2, 2],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+
+  {
+    "name": "GroupQueryAttention, different sequence length",
+    "operator": "GroupQueryAttention",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "num_heads", "data": 4, "type": "int" },
+      { "name": "kv_num_heads", "data": 2, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [
+              1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4
+            ],
+            "dims": [1, 4, 8],
+            "type": "float32"
+          },
+          {
+            "data": [1, 9, 1, 1, 2, 2, 2, 2],
+            "dims": [1, 2, 4],
+            "type": "float32"
+          },
+          {
+            "data": [1, 1, 1, 1, 2, 2, 2, 2],
+            "dims": [1, 2, 4],
+            "type": "float32"
+          },
+          // past key, BS*
+          {
+            "data": null,
+            "type": "float32"
+          },
+          // past value, BS*
+          {
+            "data": null,
+            "type": "float32"
+          },
+          // seqlens_k, unimplemented
+          {
+            "data": [1],
+            "dims": [1],
+            "type": "int32"
+          },
+          // total_sequence_length, unimplemented
+          {
+            "data": [1],
+            "dims": [1],
+            "type": "int32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [
+              1.014165997505188, 1.014165997505188, 1.0000015497207642, 1.0000015497207642, 1.99828040599823,
+              1.99828040599823, 1.9998981952667236, 1.9998981952667236, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2,
+              1.9995813369750977, 1.9995813369750977, 1.9999752044677734, 1.9999752044677734, 1, 1, 1, 1,
+              1.8044296503067017, 1.8044296503067017, 1.9929646253585815, 1.9929646253585815
+            ],
+            "dims": [1, 4, 8],
+            "type": "float32"
+          },
+          {
+            "data": [1, 9, 1, 1, 2, 2, 2, 2],
+            "dims": [1, 2, 2, 2],
+            "type": "float32"
+          },
+          {
+            "data": [1, 1, 1, 1, 2, 2, 2, 2],
+            "dims": [1, 2, 2, 2],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "GroupQueryAttention Basic, q k v same head number",
+    "operator": "GroupQueryAttention",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "num_heads", "data": 4, "type": "int" },
+      { "name": "kv_num_heads", "data": 4, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [
+              1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4,
+              8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4
+            ],
+            "dims": [1, 3, 16],
+            "type": "float32"
+          },
+          {
+            "data": [
+              1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2,
+              2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
+            ],
+            "dims": [1, 3, 16],
+            "type": "float32"
+          },
+          {
+            "data": [
+              1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1,
+              12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
+            ],
+            "dims": [1, 3, 16],
+            "type": "float32"
+          },
+          // past key, BS*
+          {
+            "data": null,
+            "type": "float32"
+          },
+          // past value, BS*
+          {
+            "data": null,
+            "type": "float32"
+          },
+          // seqlens_k, unimplemented
+          {
+            "data": [1],
+            "dims": [1],
+            "type": "int32"
+          },
+          // total_sequence_length, unimplemented
+          {
+            "data": [1],
+            "dims": [1],
+            "type": "int32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [
+              1, 12, 21, 131, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2,
+              131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21
+            ],
+            "dims": [1, 3, 16],
+            "type": "float32"
+          },
+          {
+            "data": [
+              1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2,
+              2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2,
+              2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
+            ],
+            "dims": [1, 3, 4, 4],
+            "type": "float32"
+          },
+          {
+            "data": [
+              1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1,
+              12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
+            ],
+            "dims": [1, 3, 4, 4],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "GroupQueryAttention, no past kv, used as reference",
+    "operator": "GroupQueryAttention",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "num_heads", "data": 4, "type": "int" },
+      { "name": "kv_num_heads", "data": 2, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112
+            ],
+            "dims": [1, 7, 16],
+            "type": "float32"
+          },
+          {
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
+            ],
+            "dims": [1, 7, 8],
+            "type": "float32"
+          },
+          {
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
+            ],
+            "dims": [1, 7, 8],
+            "type": "float32"
+          },
+          // past key, BS*
+          {
+            "data": null,
+            "type": "float32"
+          },
+          // past value, BS*
+          {
+            "data": null,
+            "type": "float32"
+          },
+          // seqlens_k, unimplemented
+          {
+            "data": [1],
+            "dims": [1],
+            "type": "int32"
+          },
+          // total_sequence_length, unimplemented
+          {
+            "data": [1],
+            "dims": [1],
+            "type": "int32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [
+              48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53,
+              54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51,
+              48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53,
+              54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51,
+              52, 53, 54, 55, 52, 53, 54, 55
+            ],
+            "dims": [1, 7, 16],
+            "type": "float32"
+          },
+          {
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
+            ],
+            "dims": [1, 7, 2, 4],
+            "type": "float32"
+          },
+          {
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
+            ],
+            "dims": [1, 7, 2, 4],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 1",
+    "operator": "GroupQueryAttention",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "num_heads", "data": 4, "type":
"int" }, + { "name": "kv_num_heads", "data": 2, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112 + ], + "dims": [1, 7, 16], + "type": "float32" + }, + // new key, BS* + { + "data": [ + 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 + ], + "dims": [1, 6, 8], + "type": "float32" + }, + // new value, BS* + { + "data": [ + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 + ], + "dims": [1, 6, 8], + "type": "float32" + }, + // past key, BS* + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 1, 2, 4], + "type": "float32" + }, + // past value, BS* + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [1, 1, 2, 4], + "type": "float32" + }, + // seqlens_k, unimplemented + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length, unimplemented + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, + 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, + 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, + 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, + 52, 53, 54, 55, 52, 53, 54, 55 + ], + "dims": [1, 7, 16], + "type": "float32" + }, + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 + ], + "dims": [1, 7, 2, 4], + "type": "float32" + }, + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 + ], + "dims": [1, 7, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 2", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 4, "type": "int" }, + { "name": "kv_num_heads", "data": 2, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 
+              110, 111, 112
+            ],
+            "dims": [1, 7, 16],
+            "type": "float32"
+          },
+          // new key, BS*
+          {
+            "data": [
+              17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
+              43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
+            ],
+            "dims": [1, 5, 8],
+            "type": "float32"
+          },
+          // new value, BS*
+          {
+            "data": [
+              16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
+              42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
+            ],
+            "dims": [1, 5, 8],
+            "type": "float32"
+          },
+          // past key, BS*
+          {
+            "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+            "dims": [1, 2, 2, 4],
+            "type": "float32"
+          },
+          // past value, BS*
+          {
+            "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+            "dims": [1, 2, 2, 4],
+            "type": "float32"
+          },
+          // seqlens_k, unimplemented
+          {
+            "data": [1],
+            "dims": [1],
+            "type": "int32"
+          },
+          // total_sequence_length, unimplemented
+          {
+            "data": [1],
+            "dims": [1],
+            "type": "int32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [
+              48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53,
+              54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51,
+              48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53,
+              54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51,
+              52, 53, 54, 55, 52, 53, 54, 55
+            ],
+            "dims": [1, 7, 16],
+            "type": "float32"
+          },
+          {
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
+            ],
+            "dims": [1, 7, 2, 4],
+            "type": "float32"
+          },
+          {
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
+            ],
+            "dims": [1, 7, 2, 4],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  }
+]
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index c212e4c4a5fa..cfb43f01a824 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -1361,6 +1361,7 @@
       "gemm.jsonc",
       "global-average-pool.jsonc",
       "greater.jsonc",
+      "group-query-attention.jsonc",
       "instance-norm.jsonc",
       "less.jsonc",
       "log.jsonc",
diff --git a/onnxruntime/contrib_ops/js/bert/group_query_attention.cc b/onnxruntime/contrib_ops/js/bert/group_query_attention.cc
new file mode 100644
index 000000000000..3bdd3edcc598
--- /dev/null
+++ b/onnxruntime/contrib_ops/js/bert/group_query_attention.cc
@@ -0,0 +1,24 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
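+//
+// JSEP binding for the com.microsoft GroupQueryAttention contrib op; the WebGPU
+// implementation lives in js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts.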
+
+#include "group_query_attention.h"
+#include "core/providers/js/js_data_types.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace js {
+
+using onnxruntime::js::JsepSupportedFloatTypes;
+
+ONNX_OPERATOR_KERNEL_EX(
+    GroupQueryAttention,
+    kMSDomain,
+    1,
+    kJsExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", JsepSupportedFloatTypes()),
+    GroupQueryAttention);
+
+}  // namespace js
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/js/bert/group_query_attention.h b/onnxruntime/contrib_ops/js/bert/group_query_attention.h
new file mode 100644
index 000000000000..7553883a2478
--- /dev/null
+++ b/onnxruntime/contrib_ops/js/bert/group_query_attention.h
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/js/js_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace js {
+
+using onnxruntime::js::JsKernel;
+
+class GroupQueryAttention : public JsKernel {
+ public:
+  explicit GroupQueryAttention(const OpKernelInfo& info)
+      : JsKernel(info) {
+    int64_t num_heads = 0;
+    int64_t kv_num_heads = 0;
+    ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
+    ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 &&
+                num_heads % kv_num_heads == 0);
+    num_heads_ = static_cast<int>(num_heads);
+    kv_num_heads_ = static_cast<int>(kv_num_heads);
+    scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
+    JSEP_INIT_KERNEL_ATTRIBUTE(GroupQueryAttention, ({
+                                 "numHeads" : $1,
+                                 "kvNumHeads" : $2,
+                                 "scale" : $3,
+                               }),
+                               static_cast<int32_t>(num_heads_),
+                               static_cast<int32_t>(kv_num_heads_),
+                               static_cast<float>(scale_));
+  }
+
+ protected:
+  int num_heads_;     // number of attention heads
+  int kv_num_heads_;  // number of k and v heads
+  float scale_;       // custom scale will be used if specified. Default value is 1/sqrt(head_size)
+};
+
+}  // namespace js
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc
index a6f8aebc2d1e..9d8f79c67d8a 100644
--- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc
@@ -13,6 +13,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding);
@@ -34,6 +35,7 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
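
Note (illustrative, not part of the upstream patch): the kv expansion that maybeExpandAndTransposeToBNSH performs with createTileProgramInfo([reshapedInput], [1, 1, 1, nReps]) followed by a reshape amounts to repeating each kv head nReps = numHeads / kvNumHeads times along the head axis, so the expanded K/V can be fed through the regular attention path. A minimal TypeScript sketch of that semantic, with plain nested arrays standing in for tensors:

// Expand kv from (batch, seq, kvNumHeads, headSize) to (batch, seq, numHeads, headSize)
// by repeating each kv head nReps times, mirroring Tile([1, 1, 1, nReps]) + reshape.
function expandKvHeads(kv: number[][][][], numHeads: number): number[][][][] {
  const kvNumHeads = kv[0][0].length;
  if (numHeads % kvNumHeads !== 0) {
    throw new Error('num_heads must be a multiple of kv_num_heads');
  }
  const nReps = numHeads / kvNumHeads;
  return kv.map((batch) => batch.map((seq) => {
    const expanded: number[][] = [];
    for (const head of seq) {
      // Copy r of kv head n lands at expanded head index n * nReps + r.
      for (let r = 0; r < nReps; r++) {
        expanded.push([...head]);
      }
    }
    return expanded;
  }));
}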