Skip to content

Commit

Permalink
Rename to fix format issue 120
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Jul 26, 2024
1 parent 7fff598 commit f458ee6
Showing 1 changed file with 22 additions and 26 deletions.
48 changes: 22 additions & 26 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ export const computeConv3DInfo =
export const createConv3DNaiveProgramInfo =
(inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[],
filterDims: readonly number[], pads: readonly number[], dataFormat: string): ProgramInfo => {
const isChannelsLast = dataFormat === 'channelsLast';
const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1];
const isChannelLast = dataFormat === 'channelsLast';
const inChannels = isChannelLast ? inputs[0].dims[3] : inputs[0].dims[1];
// TODO: enable vec4.
const isVec4 = false;
const workGroupSize: [number, number, number] = [64, 1, 1];
Expand All @@ -231,7 +231,7 @@ export const createConv3DNaiveProgramInfo =

LOG_DEBUG('verbose', () => `[conv3d_naive_webgpu] dispatch = ${dispatch}`);

const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1;
const innerElementSize = isVec4 ? (isChannelLast && inChannels % 4 !== 0 ? 3 : 4) : 1;
const outputSize = ShapeUtil.size(outputShape);
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: filterDims},
Expand Down Expand Up @@ -271,7 +271,7 @@ export const createConv3DNaiveProgramInfo =
inputVariables.push(bias);
declareFunctions += `
fn getBiasByOutputCoords(coords : array<u32, 5>) -> ${isVec4 ? `vec4<${t}>` : t} {
return bias[${isChannelsLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${
return bias[${isChannelLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${
isVec4 ? '/ 4' : ''}];
}`;
}
Expand All @@ -294,28 +294,24 @@ export const createConv3DNaiveProgramInfo =
let coords = ${output.offsetToIndices('global_idx')};
let batch = ${getElementAt('coords', 0, x.rank)};
let d2 = ${
isChannelsLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)};
isChannelLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)};
let xFRCCorner = vec3<u32>(${
isChannelsLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)},
${isChannelsLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)},
isChannelLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)},
${isChannelLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)},
${
isChannelsLast ? getElementAt('coords', 3, x.rank) :
getElementAt('coords', 4, x.rank)}) * uniforms.strides - uniforms.pads;
isChannelLast ? getElementAt('coords', 3, x.rank) :
getElementAt('coords', 4, x.rank)}) * uniforms.strides - uniforms.pads;
let xFCorner = xFRCCorner.x;
let xRCorner = xFRCCorner.y;
let xCCorner = xFRCCorner.z;
let xShapeY = ${
isChannelsLast ? getElementAt('uniforms.x_shape', 1, x.rank) :
getElementAt('uniforms.x_shape', 2, x.rank)};
isChannelLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)};
let xShapeZ = ${
isChannelsLast ? getElementAt('uniforms.x_shape', 2, x.rank) :
getElementAt('uniforms.x_shape', 3, x.rank)};
isChannelLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)};
let xShapeW = ${
isChannelsLast ? getElementAt('uniforms.x_shape', 3, x.rank) :
getElementAt('uniforms.x_shape', 4, x.rank)};
isChannelLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)};
let xShapeU = ${
isChannelsLast ? getElementAt('uniforms.x_shape', 4, x.rank) :
getElementAt('uniforms.x_shape', 1, x.rank)};
isChannelLast ? getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)};
let inputDepthNearestVec4 = (xShapeU / 4) * 4;
let inputDepthVec4Remainder = xShapeU % 4;
Expand All @@ -340,13 +336,13 @@ export const createConv3DNaiveProgramInfo =
for (var d1 = 0u; d1 < inputDepthNearestVec4; d1 += 4) {
${
isChannelsLast ? `let xValues = vec4<f32>(
isChannelLast ? `let xValues = vec4<f32>(
getX(batch, xF, xR, xC, d1),
getX(batch, xF, xR, xC, d1 + 1),
getX(batch, xF, xR, xC, d1 + 2),
getX(batch, xF, xR, xC, d1 + 3));
` :
`let xValues = vec4<f32>(
`let xValues = vec4<f32>(
getX(batch, d1, xF, xR, xC),
getX(batch, d1 + 1, xF, xR, xC),
getX(batch, d1 + 2, xF, xR, xC),
Expand All @@ -361,17 +357,17 @@ export const createConv3DNaiveProgramInfo =
}
if (inputDepthVec4Remainder == 1) {
${
isChannelsLast ? `value += getX(batch, xF, xR, xC, inputDepthNearestVec4)
isChannelLast ? `value += getX(batch, xF, xR, xC, inputDepthNearestVec4)
* getW(d2, inputDepthNearestVec4, wF, wR, wC);` :
`value += getX(batch, inputDepthNearestVec4, xF, xR, xC)
`value += getX(batch, inputDepthNearestVec4, xF, xR, xC)
* getW(d2, inputDepthNearestVec4, wF, wR, wC);`}
} else if (inputDepthVec4Remainder == 2) {
${
isChannelsLast ? `let xValues = vec2<f32>(
isChannelLast ? `let xValues = vec2<f32>(
getX(batch, xF, xR, xC, inputDepthNearestVec4),
getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1));
` :
`let xValues = vec2<f32>(
`let xValues = vec2<f32>(
getX(batch, inputDepthNearestVec4, xF, xR, xC),
getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC));
`}
Expand All @@ -381,12 +377,12 @@ export const createConv3DNaiveProgramInfo =
value += dot(xValues, wValues);
} else if (inputDepthVec4Remainder == 3) {
${
isChannelsLast ? `let xValues = vec3<f32>(
isChannelLast ? `let xValues = vec3<f32>(
getX(batch, xF, xR, xC, inputDepthNearestVec4),
getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1),
getX(batch, xF, xR, xC, inputDepthNearestVec4 + 2));
` :
`let xValues = vec3<f32>(
`let xValues = vec3<f32>(
getX(batch, inputDepthNearestVec4, xF, xR, xC),
getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC),
getX(batch, inputDepthNearestVec4 + 2, xF, xR, xC));
Expand All @@ -408,7 +404,7 @@ export const createConv3DNaiveProgramInfo =
return {
name: 'Conv3DNaive',
shaderCache:
{hint: `${attributes.cacheKey};${isChannelsLast};${innerElementSize};${hasBias}`, inputDependencies},
{hint: `${attributes.cacheKey};${isChannelLast};${innerElementSize};${hasBias}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
Expand Down

0 comments on commit f458ee6

Please sign in to comment.