Skip to content

Commit

Permalink
[js/webgpu] Fix max pool shape end with 0
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Jul 30, 2024
1 parent 5d78b9a commit 12eafaf
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 2 deletions.
10 changes: 8 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,14 @@ const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePo
inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position.
}
const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations');
const kernelShape = attributes.kernelShape.slice();
const strides = attributes.strides.slice();
let kernelShape = attributes.kernelShape.slice();
if (kernelShape[kernelShape.length - 1] === 0) {
kernelShape.pop();
}
let strides = attributes.strides.slice();
if (strides[strides.length - 1] === 0) {
strides.pop();
}
const dilations: number[] = hasDilations ? (attributes as MaxPoolAttributes).dilations.slice() : [];
const pads = attributes.pads.slice();
PoolConvUtil.adjustPoolAttributes(isGlobalOperator, inputShapeAsChannelFirst, kernelShape, strides, dilations, pads);
Expand Down
64 changes: 64 additions & 0 deletions js/web/test/data/ops/max-pool.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
[
{
"name": "MaxPool",
"operator": "MaxPool",
"attributes": [{ "name": "kernel_shape", "data": [3], "type": "ints" }],
"cases": [
{
"name": "T[1,3,5,5] T[1,3,1,1]",
"inputs": [
{
"data": [
1.764052391052246, 0.40015721321105957, 0.978738009929657, 2.2408931255340576, 1.8675580024719238,
-0.9772778749465942, 0.9500884413719177, -0.15135720372200012, -0.10321885347366333, 0.4105985164642334,
0.14404356479644775, 1.4542734622955322, 0.7610377073287964, 0.12167501449584961, 0.44386324286460876,
0.3336743414402008, 1.4940791130065918, -0.2051582634449005, 0.3130677044391632, -0.8540957570075989,
-2.5529897212982178, 0.653618574142456, 0.8644362092018127, -0.7421650290489197, 2.269754648208618, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 100, 100, 100, 100, 100, 100, 100,
100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100
],
"dims": [3, 5, 5],
"type": "float32"
}
],
"outputs": [
{
"data": [
1.764052391052246, 2.2408931255340576, 2.2408931255340576, 0.9500884413719177, 0.9500884413719177,
0.4105985164642334, 1.4542734622955322, 1.4542734622955322, 0.7610377073287964, 1.4940791130065918,
1.4940791130065918, 0.3130677044391632, 0.8644362092018127, 0.8644362092018127, 2.269754648208618, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
100, 100
],
"dims": [3, 5, 3],
"type": "float32"
}
]
}
]
},
{
"name": "MaxPool",
"operator": "MaxPool",
"attributes": [{ "name": "kernel_shape", "data": [3], "type": "ints" }],
"cases": [
{
"name": "T[1,3,5,5] T[1,3,1,1]",
"inputs": [
{
"data": [1.764052391052246, 0.40015721321105957, 0.978738009929657, 2.2408931255340576, 1.8675580024719238],
"dims": [1, 1, 5],
"type": "float32"
}
],
"outputs": [
{
"data": [1.764052391052246, 2.2408931255340576, 2.2408931255340576],
"dims": [1, 1, 3],
"type": "float32"
}
]
}
]
}
]
1 change: 1 addition & 0 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,7 @@
"matmul.jsonc",
"matmulnbits.jsonc",
"matmul-broadcast.jsonc",
"max-pool.jsonc",
"mul.jsonc",
"mul_int32.jsonc",
"multihead-attention.jsonc",
Expand Down

0 comments on commit 12eafaf

Please sign in to comment.