From 0ecdc6c54136f424c54e20fe248581e0e3d24de5 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Thu, 2 Nov 2023 09:43:25 +0800 Subject: [PATCH] [js/webgpu] Add uniforms support to concat op --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 46 +++++++++++++++++------ 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 4b5ca869f0df..26632c186286 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -4,9 +4,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; @@ -35,8 +35,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => { const calculateInputIndexImpl = (numberOfTensors: number): string => ` fn calculateInputIndex(index: u32) -> u32 { - for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) { - if (index < sizeInConcatAxis[i]) { + for (var i: u32 = 0u; i < ${numberOfTensors}; i += 1u ) { + if (index < uniforms.sizeInConcatAxis[i]) { return i; } } @@ -92,40 +92,62 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P const dataType = inputs[0].dataType; let previousSum = 0; + const inputDependencies = []; + const inputShapeOrRanks = []; + const enableInputShapesUniforms = []; + let programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; for (let i = 0; i < 
inputs.length; ++i) { previousSum += inputs[i].dims[adjustedAxis]; sizeInConcatAxis[i] = previousSum; + enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length)); + inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims); + inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]); + inputDependencies.push('rank'); + } + programUniforms.push({type: 'uint32', data: sizeInConcatAxis}); + for (let i = 0; i < inputs.length; ++i) { + if (enableInputShapesUniforms[i]) { + programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); + } + } - inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims); + const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShape)); } - const output = outputVariable('output', dataType, outputShape); + const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; + const output = outputVariable('output', dataType, outputShapeOrRank); const indicesAxis = output.indicesGet('indices', adjustedAxis); const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(...inputVars, output)} + ${ + shaderHelper.registerUniform('outputSize', 'u32') + .registerUniform(`sizeInConcatAxis`, `vec${inputs.length}`) + .declareVariables(...inputVars, output)} - const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); ${calculateInputIndexImpl(sizeInConcatAxis.length)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} var indices = ${output.offsetToIndices('global_idx')}; let inputIndex = calculateInputIndex(${indicesAxis}); if (inputIndex != 0u) { - ${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u]; + ${indicesAxis} -= 
uniforms.sizeInConcatAxis[inputIndex - 1u]; } ${assignOutputData(inputVars, output)} }`; + return { name: 'Concat', - shaderCache: {hint: `${axis}`}, + shaderCache: {hint: `${axis}`, inputDependencies: inputDependencies as ProgramInputTensorInfoDependency[]}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: programUniforms as ProgramUniform[], }), getShaderSource, };