diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 4b5ca869f0dfb..7450ac2266c49 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -4,7 +4,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; @@ -33,14 +33,15 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } }; -const calculateInputIndexImpl = (numberOfTensors: number): string => ` +const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => ` fn calculateInputIndex(index: u32) -> u32 { - for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) { + let sizeInConcatAxis = array(${sizeInConcatAxisStr}); + for (var i: u32 = 0u; i < ${numberOfTensors}; i += 1u ) { if (index < sizeInConcatAxis[i]) { return i; } } - return ${numberOfTensors}u; + return ${numberOfTensors}; }`; const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelper) => { @@ -92,41 +93,57 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P const dataType = inputs[0].dataType; let previousSum = 0; + let inputDependencies = []; for (let i = 0; i < inputs.length; ++i) { previousSum += inputs[i].dims[adjustedAxis]; sizeInConcatAxis[i] = previousSum; inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims); + inputDependencies.push('dims'); } const output = outputVariable('output', dataType, outputShape); const indicesAxis = output.indicesGet('indices', adjustedAxis); + const sizeInConcatAxisStr = + Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(...inputVars, output)} + ${shaderHelper.registerUniform('outputSize', 'u32').declareVariables(...inputVars, output)} - const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); - ${calculateInputIndexImpl(sizeInConcatAxis.length)} + ${(function fun() { + for (let i = 0; i < inputs.length; i++) { + shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); + } + return '' + })()} + ${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} var indices = ${output.offsetToIndices('global_idx')}; let inputIndex = calculateInputIndex(${indicesAxis}); if (inputIndex != 0u) { + let sizeInConcatAxis = array(${sizeInConcatAxisStr}); ${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u]; } ${assignOutputData(inputVars, output)} }`; + const programUniforms = [{type: 'uint32', data: outputSize}]; + for (let i = 0; i < sizeInConcatAxis.length; i++) { + programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]}); + } return { name: 'Concat', - shaderCache: {hint: `${axis}`}, + shaderCache: {hint: `${axis}`, inputDependencies: inputDependencies as ProgramInputTensorInfoDependency[]}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: programUniforms as ProgramUniform[], }), + getShaderSource, }; };