Skip to content

Commit

Permalink
[js/webgpu] Add unifroms support to concat op
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Nov 2, 2023
1 parent 178f7ca commit e52f854
Showing 1 changed file with 37 additions and 13 deletions.
50 changes: 37 additions & 13 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';

export interface ConcatAttributes extends AttributeWithCacheKey {
readonly axis: number;
Expand All @@ -33,14 +33,15 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
}
};

const calculateInputIndexImpl = (numberOfTensors: number): string => `
const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => `
fn calculateInputIndex(index: u32) -> u32 {
for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) {
let sizeInConcatAxis = array<u32, ${numberOfTensors}u>(${sizeInConcatAxisStr});
for (var i: u32 = 0u; i < ${numberOfTensors}; i += 1u ) {
if (index < sizeInConcatAxis[i]) {
return i;
}
}
return ${numberOfTensors}u;
return ${numberOfTensors};
}`;

const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelper) => {
Expand Down Expand Up @@ -92,41 +93,64 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
const dataType = inputs[0].dataType;

let previousSum = 0;
let inputDependencies = [];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[adjustedAxis];
sizeInConcatAxis[i] = previousSum;

inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims);
inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims.length);
inputDependencies.push('rank');
}

const output = outputVariable('output', dataType, outputShape);
const output = outputVariable('output', dataType, outputShape.length);

const indicesAxis = output.indicesGet('indices', adjustedAxis);
const sizeInConcatAxisStr =
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(...inputVars, output)}
const sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}>(${sizeInConcatAxis.map(i => `${i}u`).join(',')});
${calculateInputIndexImpl(sizeInConcatAxis.length)}
${(function() {
shaderHelper.registerUniform('outputSize', 'u32');
for (let i = 0; i < inputs.length; i++) {
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
}
return shaderHelper.declareVariables(...inputVars, output)
})()}
${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
var indices = ${output.offsetToIndices('global_idx')};
let inputIndex = calculateInputIndex(${indicesAxis});
if (inputIndex != 0u) {
let sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}u>(${sizeInConcatAxisStr});
${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u];
}
${assignOutputData(inputVars, output)}
}`;
let programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
for (let i = 0; i < sizeInConcatAxis.length; i++) {
programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]});
}

for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
}
programUniforms.push(...createTensorShapeVariables(outputShape));

return {
name: 'Concat',
shaderCache: {hint: `${axis}`},
shaderCache: {hint: `${axis}`, inputDependencies: inputDependencies as ProgramInputTensorInfoDependency[]},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: programUniforms as ProgramUniform[],
}),

getShaderSource,
};
};
Expand Down

0 comments on commit e52f854

Please sign in to comment.