diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 9f0f7ae92d66..01201110b71c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -33,10 +33,11 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } }; -const calculateInputIndexImpl = (numberOfTensors: number): string => ` +const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => ` fn calculateInputIndex(index: u32) -> u32 { + let sizeInConcatAxis = array(${sizeInConcatAxisStr}); for (var i: u32 = 0u; i < ${numberOfTensors}; i += 1u ) { - if (index < uniforms.sizeInConcatAxis[i]) { + if (index < sizeInConcatAxis[i]) { return i; } } @@ -103,8 +104,8 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims); inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]); inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims'); + programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]}); } - programUniforms.push({type: 'uint32', data: sizeInConcatAxis}); for (let i = 0; i < inputs.length; ++i) { if (enableInputShapesUniforms[i]) { programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); @@ -120,13 +121,19 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P const output = outputVariable('output', dataType, outputShapeOrRank); const indicesAxis = output.indicesGet('indices', adjustedAxis); + const sizeInConcatAxisStr = + Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${ - shaderHelper.registerUniform('outputSize', 'u32') - .registerUniform(`sizeInConcatAxis`, `vec${inputs.length}`) - .declareVariables(...inputVars, output)} - ${calculateInputIndexImpl(sizeInConcatAxis.length)} + ${(function() { + shaderHelper.registerUniform('outputSize', 'u32'); + for (let i = 0; i < inputs.length; i++) { + shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); + } + return shaderHelper.declareVariables(...inputVars, output) + })()} + + ${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} @@ -135,7 +142,8 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P let inputIndex = calculateInputIndex(${indicesAxis}); if (inputIndex != 0u) { - ${indicesAxis} -= uniforms.sizeInConcatAxis[inputIndex - 1u]; + let sizeInConcatAxis = array(${sizeInConcatAxisStr}); + ${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u]; } ${assignOutputData(inputVars, output)}