Skip to content

Commit

Permalink
[js/webgpu] Add activation for conv3d naive
Browse files (browse the repository at this point in the history)
  • Loading branch information
axinging committed Jul 24, 2024
1 parent dd010ed commit 95eecda
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 8 deletions.
23 changes: 15 additions & 8 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import {ShapeUtil} from '../../../util';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {ConvAttributes} from '../conv';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';
import {typeSnippet} from './activation_util';

const arrayProduct = (arr: number[]) => {
let product = 1;
Expand Down Expand Up @@ -235,6 +237,7 @@ export const createConv3DNaiveProgramInfo =
{type: DataType.uint32, data: pads}, {type: DataType.uint32, data: attributes.strides},
{type: DataType.uint32, data: attributes.dilations}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
const hasBias = inputs.length === 3;
Expand All @@ -251,6 +254,7 @@ export const createConv3DNaiveProgramInfo =
{name: 'strides', type: 'u32', length: attributes.strides.length},
{name: 'dilations', type: 'u32', length: attributes.dilations.length}
];
appendActivationUniforms(attributes, uniforms);
// TODO: support component 2, 3.
const components = isVec4 ? 4 : 1;
const t = tensorTypeToWsglStorageType(inputs[0].dataType);
Expand All @@ -270,6 +274,8 @@ export const createConv3DNaiveProgramInfo =
isVec4 ? '/ 4' : ''}];
}`;
}
const resType = typeSnippet(innerElementSize, t);
const applyActivation = getActivationSnippet(attributes, resType, t);

return `
${declareFunctions}
Expand Down Expand Up @@ -308,7 +314,7 @@ export const createConv3DNaiveProgramInfo =
let inputDepthNearestVec4 = (xShapeU / 4) * 4;
let inputDepthVec4Remainder = xShapeU % 4;
var dotProd = 0.0;
var value = 0.0;
for (var wF = 0u; wF < uniforms.filter_dims[0]; wF++) {
let xF = xFCorner + wF * uniforms.dilations[0];
if (xF < 0 || xF >= xShapeY) {
Expand Down Expand Up @@ -346,13 +352,13 @@ export const createConv3DNaiveProgramInfo =
getW(d2, d1 + 1, wF, wR, wC),
getW(d2, d1 + 2, wF, wR, wC),
getW(d2, d1 + 3, wF, wR, wC));
dotProd += dot(xValues, wValues);
value += dot(xValues, wValues);
}
if (inputDepthVec4Remainder == 1) {
${
isChannelsLast ? `dotProd += getX(batch, xF, xR, xC, inputDepthNearestVec4)
isChannelsLast ? `value += getX(batch, xF, xR, xC, inputDepthNearestVec4)
* getW(d2, inputDepthNearestVec4, wF, wR, wC);` :
`dotProd += getX(batch, inputDepthNearestVec4, xF, xR, xC)
`value += getX(batch, inputDepthNearestVec4, xF, xR, xC)
* getW(d2, inputDepthNearestVec4, wF, wR, wC);`}
} else if (inputDepthVec4Remainder == 2) {
${
Expand All @@ -367,7 +373,7 @@ export const createConv3DNaiveProgramInfo =
let wValues = vec2<f32>(
getW(d2, inputDepthNearestVec4, wF, wR, wC),
getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC));
dotProd += dot(xValues, wValues);
value += dot(xValues, wValues);
} else if (inputDepthVec4Remainder == 3) {
${
isChannelsLast ? `let xValues = vec3<f32>(
Expand All @@ -384,13 +390,14 @@ export const createConv3DNaiveProgramInfo =
getW(d2, inputDepthNearestVec4, wF, wR, wC),
getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC),
getW(d2, inputDepthNearestVec4 + 2, wF, wR, wC));
dotProd += dot(xValues, wValues);
value += dot(xValues, wValues);
}
}
}
}
${hasBias ? 'dotProd = dotProd + getBiasByOutputCoords(coords)' : ''};
result[global_idx] = f32(dotProd);
${hasBias ? 'value = value + getBiasByOutputCoords(coords)' : ''};
${applyActivation}
result[global_idx] = f32(value);
}`;
};
return {
Expand Down
75 changes: 75 additions & 0 deletions js/web/test/data/ops/fused-conv3dncdhw.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
[
// FusedConv test: 3D convolution followed by a Relu activation.
// kernel_shape (2, 1, 2) matches the spatial extent of the input (dims 2 x 1 x 2) and the
// expected output has a single spatial position per channel, so each output channel is one
// dot product of the four input values with that channel's four weights.
{
"name": "conv3d, x=[1, 1, 2, 1, 2], f=[2, 1, 2, 1, 2], s=1, d=1, p=valid, relu",
"operator": "FusedConv",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "activation", "data": "Relu", "type": "string" },
{ "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" },
{ "name": "auto_pad", "data": "VALID", "type": "string" },
{ "name": "strides", "data": [1, 1, 1], "type": "ints" },
{ "name": "dilations", "data": [1, 1, 1], "type": "ints" }
],
"cases": [
{
"name": "T[0]",
"inputs": [
// Input tensor X.
{
"data": [0.25, 0.5, 0.75, 1],
"dims": [1, 1, 2, 1, 2],
"type": "float32"
},
// Weights W: the first four values belong to output channel 0, the last four to channel 1.
{
"data": [-0.125, -0.25, -0.375, 0.5, 0.625, -0.75, -0.875, -1],
"dims": [2, 1, 2, 1, 2],
"type": "float32"
}
],
"outputs": [
// Channel 0: 0.25*-0.125 + 0.5*-0.25 + 0.75*-0.375 + 1*0.5 = 0.0625 (positive, kept as is).
// Channel 1: 0.25*0.625 + 0.5*-0.75 + 0.75*-0.875 + 1*-1 = -1.875, expected as 0 after Relu.
{
"data": [0.0625, 0],
"dims": [1, 2, 1, 1, 1],
"type": "float32"
}
]
}
]
},
// FusedConv test: 3D convolution followed by a Clip activation.
// Same shapes as a whole-input kernel: kernel_shape (2, 1, 2) matches the input's spatial
// extent (2 x 1 x 2), so each output channel is a single dot product of inputs and weights.
{
"name": "fused conv3d with clip",
"operator": "FusedConv",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "activation", "data": "Clip", "type": "string" },
// The expected outputs below are consistent with these params being [min, max] = [1.0, 3.0].
{ "name": "activation_params", "data": [1.0, 3.0], "type": "floats" },
{ "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" },
{ "name": "auto_pad", "data": "VALID", "type": "string" },
{ "name": "strides", "data": [1, 1, 1], "type": "ints" },
{ "name": "dilations", "data": [1, 1, 1], "type": "ints" }
],
"cases": [
{
"name": "T[0]",
"inputs": [
// Input tensor X.
{
"data": [0.25, 0.5, 0.75, 1],
"dims": [1, 1, 2, 1, 2],
"type": "float32"
},
// Weights W: the first four values belong to output channel 0, the last four to channel 1.
{
"data": [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1],
"dims": [2, 1, 2, 1, 2],
"type": "float32"
}
],
"outputs": [
// Channel 0: 0.25*0.125 + 0.5*0.25 + 0.75*0.375 + 1*0.5 = 0.9375, raised to the lower bound 1.
// Channel 1: 0.25*0.625 + 0.5*0.75 + 0.75*0.875 + 1*1 = 2.1875, inside [1, 3] so unchanged.
{
"data": [1, 2.1875],
"dims": [1, 2, 1, 1, 1],
"type": "float32"
}
]
}
]
}
]

0 comments on commit 95eecda

Please sign in to comment.