Skip to content

Commit

Permalink
Revert uniform vec
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Nov 3, 2023
1 parent c617826 commit 8fe65c3
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
}
};

const calculateInputIndexImpl = (numberOfTensors: number): string => `
const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => `
fn calculateInputIndex(index: u32) -> u32 {
let sizeInConcatAxis = array<u32, ${numberOfTensors}u>(${sizeInConcatAxisStr});
for (var i: u32 = 0u; i < ${numberOfTensors}; i += 1u ) {
if (index < uniforms.sizeInConcatAxis[i]) {
if (index < sizeInConcatAxis[i]) {
return i;
}
}
Expand Down Expand Up @@ -103,8 +104,8 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims);
inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]);
inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims');
programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]});
}
programUniforms.push({type: 'uint32', data: sizeInConcatAxis});
for (let i = 0; i < inputs.length; ++i) {
if (enableInputShapesUniforms[i]) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
Expand All @@ -120,13 +121,19 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
const output = outputVariable('output', dataType, outputShapeOrRank);

const indicesAxis = output.indicesGet('indices', adjustedAxis);
const sizeInConcatAxisStr =
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
const getShaderSource = (shaderHelper: ShaderHelper) => `
${
shaderHelper.registerUniform('outputSize', 'u32')
.registerUniform(`sizeInConcatAxis`, `vec${inputs.length}<u32>`)
.declareVariables(...inputVars, output)}
${calculateInputIndexImpl(sizeInConcatAxis.length)}
${(function() {
shaderHelper.registerUniform('outputSize', 'u32');
for (let i = 0; i < inputs.length; i++) {
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
}
return shaderHelper.declareVariables(...inputVars, output)
})()}
${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
Expand All @@ -135,7 +142,8 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
let inputIndex = calculateInputIndex(${indicesAxis});
if (inputIndex != 0u) {
${indicesAxis} -= uniforms.sizeInConcatAxis[inputIndex - 1u];
let sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}u>(${sizeInConcatAxisStr});
${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u];
}
${assignOutputData(inputVars, output)}
Expand Down

0 comments on commit 8fe65c3

Please sign in to comment.