[js/webgpu] Support GroupQueryAttention (#20237)
TODOs:
1. Handle the case where H * params.kvNumHeads exceeds the workgroup size limit.
2. Support the BNSH KV cache layout.
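For background, GroupQueryAttention shares each key/value head across a fixed-size group of query heads instead of giving every query head its own KV pair. A minimal illustrative sketch of that grouping, assuming `numHeads` query heads and `kvNumHeads` KV heads (the helper name and checks are illustrative, not part of this commit):

```ts
// Illustrative sketch: which KV head a given query head reads from in
// grouped-query attention. Assumes numHeads is a multiple of kvNumHeads.
function kvHeadForQueryHead(queryHead: number, numHeads: number, kvNumHeads: number): number {
  if (numHeads % kvNumHeads !== 0) {
    throw new Error('numHeads must be a multiple of kvNumHeads');
  }
  const nReps = numHeads / kvNumHeads;  // query heads per shared KV head
  return Math.floor(queryHead / nReps);
}

// 8 query heads over 2 KV heads: heads 0-3 read KV head 0, heads 4-7 read KV head 1.
console.log(kvHeadForQueryHead(5, 8, 2));  // 1
```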
axinging authored May 13, 2024
1 parent 90d49cc commit 8c59cd4
Showing 11 changed files with 1,059 additions and 15 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
@@ -54,6 +54,7 @@ Do not modify directly.*
 | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
 | Greater | ai.onnx(7-8,9-12,13+) | |
 | GreaterOrEqual | ai.onnx(12-15,16+) | |
+| GroupQueryAttention | com.microsoft(1+) | |
 | HardSigmoid | ai.onnx(6+) | |
 | If | ai.onnx(1-10,11-12,13-18,19+) | |
 | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -18,6 +18,7 @@ import {fastGelu} from './ops/fast-gelu';
 import {gather, parseGatherAttributes} from './ops/gather';
 import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
 import {gemm, parseGemmAttributes} from './ops/gemm';
+import {groupQueryAttention, parseGroupQueryAttentionAttributes} from './ops/group-query-attention';
 import {instanceNorm} from './ops/instance-norm';
 import {layerNorm} from './ops/layer-norm';
 import {matMul} from './ops/matmul';
@@ -88,6 +89,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
   ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
   ['Greater', [binaryOps.greater]],
   ['GreaterOrEqual', [binaryOps.greaterOrEqual]],
+  ['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]],
   ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
   ['InstanceNormalization', [instanceNorm]],
   ['LayerNormalization', [layerNorm]],
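For readers unfamiliar with this table: each entry maps an operator type to its kernel function plus an optional attribute parser, and the backend resolves kernels by name. A simplified sketch of that lookup pattern (the types below are stand-ins, not the real JSEP signatures):

```ts
// Simplified stand-in for the resolve-rule lookup; not the actual JSEP types.
type KernelFn = (context: unknown, attributes?: unknown) => void;
type AttrParser = (raw: unknown) => unknown;
type ResolveRule = [KernelFn] | [KernelFn, AttrParser];

const resolveRules = new Map<string, ResolveRule>();

function resolveKernel(opType: string, rawAttrs: unknown): (context: unknown) => void {
  const rule = resolveRules.get(opType);
  if (!rule) {
    throw new Error(`No WebGPU kernel registered for ${opType}`);
  }
  const [kernel, parseAttrs] = rule;
  const attrs = parseAttrs ? parseAttrs(rawAttrs) : rawAttrs;
  return (context) => kernel(context, attrs);
}
```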
33 changes: 21 additions & 12 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -46,20 +46,24 @@ export interface AttentionParameters {
   headSize: number;
   vHeadSize: number;
   numHeads: number;
-  isUnidirectional: boolean;
+  kvNumHeads?: number;
+  nReps?: number;
+  isUnidirectional?: boolean;
   pastPresentShareBuffer: boolean;
-  maskFilterValue: number;
+  maskFilterValue?: number;
   maskType: AttentionMaskType;
   scale: number;
   broadcastResPosBias: boolean;
   passPastInKv: boolean;
   qkvFormat: AttentionQkvFormat;
+  isPastkvBSNH?: boolean;
 }
 
 export interface AttentionAttrs {
   numHeads: number;
-  isUnidirectional: number;
-  maskFilterValue: number;
+  kvNumHeads?: number;
+  isUnidirectional?: number;
+  maskFilterValue?: number;
   scale: number;
   doRotary: number;
   qkvHiddenSizes: number[];
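The new optional fields let one parameter type cover both paths: plain multi-head attention leaves `kvNumHeads` and `nReps` unset, while GQA sets both. A hedged sketch of how they would plausibly be populated (this derivation is an assumption, consistent with how `nReps` scales `vHiddenSize` in the next hunk):

```ts
// Assumed population of the new optional fields for the GQA path.
function gqaHeadParams(numHeads: number, kvNumHeads: number) {
  // Each of the kvNumHeads KV heads is repeated nReps times to serve
  // all numHeads query heads.
  const nReps = Math.floor(numHeads / kvNumHeads);
  return {numHeads, kvNumHeads, nReps};
}

console.log(gqaHeadParams(32, 8));  // { numHeads: 32, kvNumHeads: 8, nReps: 4 }
```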
@@ -443,17 +447,20 @@ const createVxAttentionScoreProgramInfo =
     (_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters,
      pastSequenceLength: number) => {
       const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
-      const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
+      const nReps = params.nReps ? params.nReps : 1;
+      const repeatedVHiddenSize = params.vHiddenSize * nReps;
+      const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize];
       const TILE_SIZE = 12;
       const dispatch = {
         x: Math.ceil(params.vHeadSize / TILE_SIZE),
         y: Math.ceil(params.sequenceLength / TILE_SIZE),
         z: params.batchSize * params.numHeads
       };
 
       const programUniforms: ProgramUniform[] = [
         {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength},
         {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
-        {type: DataType.uint32, data: params.vHiddenSize}
+        {type: DataType.uint32, data: repeatedVHiddenSize}
       ];
 
       const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
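Because each KV head is repeated `nReps` times, the output's last dimension grows from `vHiddenSize` to `vHiddenSize * nReps`, and the same value is sent to the shader as a uniform. A worked example with assumed sample sizes (2 KV heads of width 64 shared by 8 query heads; the numbers are illustrative, not from the diff):

```ts
// Worked example with assumed sizes; mirrors the arithmetic in this hunk.
const batchSize = 1, sequenceLength = 256, numHeads = 8, kvNumHeads = 2, vHeadSize = 64;
const vHiddenSize = kvNumHeads * vHeadSize;       // 128: V is laid out per KV head
const nReps = numHeads / kvNumHeads;              // 4 query heads per KV head
const repeatedVHiddenSize = vHiddenSize * nReps;  // 512 = numHeads * vHeadSize

const TILE_SIZE = 12;
const outputShape = [batchSize, sequenceLength, repeatedVHiddenSize];  // [1, 256, 512]
const dispatch = {
  x: Math.ceil(vHeadSize / TILE_SIZE),       // ceil(64 / 12)  = 6
  y: Math.ceil(sequenceLength / TILE_SIZE),  // ceil(256 / 12) = 22
  z: batchSize * numHeads,                   // 1 * 8 = 8
};
console.log(outputShape, dispatch);
```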
@@ -524,20 +531,22 @@ export const applyAttention =
         relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
       const outputPresentKey = context.outputCount > 1;
       const outputPresentValue = context.outputCount > 2;
-      const pastSequenceLength = (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
+      const pastSequenceLength =
+          parameters.kvNumHeads != null || (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
       const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
       // Concatenate pastKey and K to produce presentKey.
       const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
       const concatKeyInputs = pastKey ? [pastKey, k] : [k];
-      const key = outputPresentKey ? context.compute(
-                                         createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
-                                         {inputs: concatKeyInputs, outputs: [1]})[0] :
-                                     k;
+      const key = parameters.kvNumHeads == null && outputPresentKey ?
+          context.compute(
+              createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
+              {inputs: concatKeyInputs, outputs: [1]})[0] :
+          k;
 
       // Concatenate pastValue and V to produce presentValue.
       const presentValueShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
       const concatValueInputs = pastValue ? [pastValue, v] : [v];
-      const value = outputPresentValue ?
+      const value = parameters.kvNumHeads == null && outputPresentValue ?
           context.compute(
               createConcatProgramInfo(concatValueInputs, 2, presentValueShape, v.dataType),
               {inputs: concatValueInputs, outputs: [2]})[0] :
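The effect of the two `parameters.kvNumHeads == null` guards: past/present K and V are concatenated inside `applyAttention` only on the plain attention path, while the GQA path (where `kvNumHeads` is set) keeps `k` and `v` as-is, since its KV cache is managed by the GroupQueryAttention implementation itself. A condensed restatement of that branch, with the concat dispatch abstracted away (a sketch, not the actual code):

```ts
// Condensed restatement of the guard added above; concat() stands in for
// the createConcatProgramInfo dispatch.
function selectKV<T>(kvNumHeads: number | undefined, outputPresent: boolean,
                     concat: () => T, current: T): T {
  // Plain attention emitting a present output: fuse past + current.
  // GQA (kvNumHeads set): pass current through; the GQA kernel owns the cache.
  return kvNumHeads == null && outputPresent ? concat() : current;
}
```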