Skip to content

Commit

Permalink
[js/webgpu] Fix max pool shape end with 0
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Aug 1, 2024
1 parent 5d78b9a commit 80ac91b
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 3 deletions.
15 changes: 12 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
}
};

const removeLastZeroOfArray = (arr: number[]): number[] => {
if (arr[arr.length - 1] === 0) {
arr.pop();
}
return arr;
};

const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
input: TensorView, attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => {
const isChannelsLast = attributes.format === 'NHWC';
Expand All @@ -31,9 +38,11 @@ const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePo
inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position.
}
const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations');
const kernelShape = attributes.kernelShape.slice();
const strides = attributes.strides.slice();
const dilations: number[] = hasDilations ? (attributes as MaxPoolAttributes).dilations.slice() : [];
let kernelShape = removeLastZeroOfArray(attributes.kernelShape.slice());
let strides = removeLastZeroOfArray(attributes.strides.slice());
let dilations: number[] =
removeLastZeroOfArray(hasDilations ? (attributes as MaxPoolAttributes).dilations.slice() : []);

const pads = attributes.pads.slice();
PoolConvUtil.adjustPoolAttributes(isGlobalOperator, inputShapeAsChannelFirst, kernelShape, strides, dilations, pads);

Expand Down
67 changes: 67 additions & 0 deletions js/web/test/data/ops/max-pool.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
[
{
"name": "MaxPool",
"operator": "MaxPool",
"attributes": [
{ "name": "kernel_shape", "data": [3], "type": "ints" },
{ "name": "dilations", "data": [1], "type": "ints" }
],
"cases": [
{
"name": "T[3,5,5] T[3,5,3]",
"inputs": [
{
"data": [
1.764052391052246, 0.40015721321105957, 0.978738009929657, 2.2408931255340576, 1.8675580024719238,
-0.9772778749465942, 0.9500884413719177, -0.15135720372200012, -0.10321885347366333, 0.4105985164642334,
0.14404356479644775, 1.4542734622955322, 0.7610377073287964, 0.12167501449584961, 0.44386324286460876,
0.3336743414402008, 1.4940791130065918, -0.2051582634449005, 0.3130677044391632, -0.8540957570075989,
-2.5529897212982178, 0.653618574142456, 0.8644362092018127, -0.7421650290489197, 2.269754648208618, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 100, 100, 100, 100, 100, 100, 100,
100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100
],
"dims": [3, 5, 5],
"type": "float32"
}
],
"outputs": [
{
"data": [
1.764052391052246, 2.2408931255340576, 2.2408931255340576, 0.9500884413719177, 0.9500884413719177,
0.4105985164642334, 1.4542734622955322, 1.4542734622955322, 0.7610377073287964, 1.4940791130065918,
1.4940791130065918, 0.3130677044391632, 0.8644362092018127, 0.8644362092018127, 2.269754648208618, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
100, 100
],
"dims": [3, 5, 3],
"type": "float32"
}
]
}
]
},
{
"name": "MaxPool",
"operator": "MaxPool",
"attributes": [{ "name": "kernel_shape", "data": [3], "type": "ints" }],
"cases": [
{
"name": "T[1,1,5] T[1,1,3]",
"inputs": [
{
"data": [1.764052391052246, 0.40015721321105957, 0.978738009929657, 2.2408931255340576, 1.8675580024719238],
"dims": [1, 1, 5],
"type": "float32"
}
],
"outputs": [
{
"data": [1.764052391052246, 2.2408931255340576, 2.2408931255340576],
"dims": [1, 1, 3],
"type": "float32"
}
]
}
]
}
]
1 change: 1 addition & 0 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,7 @@
"matmul.jsonc",
"matmulnbits.jsonc",
"matmul-broadcast.jsonc",
"max-pool.jsonc",
"mul.jsonc",
"mul_int32.jsonc",
"multihead-attention.jsonc",
Expand Down

0 comments on commit 80ac91b

Please sign in to comment.