Skip to content

Commit

Permalink
[js/webgpu] Fix max pool shape end with 0
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Aug 11, 2024
1 parent 154084e commit 72e7646
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 28 deletions.
67 changes: 67 additions & 0 deletions js/web/test/data/ops/max-pool.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
[
{
"name": "MaxPool",
"operator": "MaxPool",
"attributes": [
{ "name": "kernel_shape", "data": [3], "type": "ints" },
{ "name": "dilations", "data": [1], "type": "ints" }
],
"cases": [
{
"name": "T[3,5,5] T[3,5,3]",
"inputs": [
{
"data": [
1.764052391052246, 0.40015721321105957, 0.978738009929657, 2.2408931255340576, 1.8675580024719238,
-0.9772778749465942, 0.9500884413719177, -0.15135720372200012, -0.10321885347366333, 0.4105985164642334,
0.14404356479644775, 1.4542734622955322, 0.7610377073287964, 0.12167501449584961, 0.44386324286460876,
0.3336743414402008, 1.4940791130065918, -0.2051582634449005, 0.3130677044391632, -0.8540957570075989,
-2.5529897212982178, 0.653618574142456, 0.8644362092018127, -0.7421650290489197, 2.269754648208618, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 100, 100, 100, 100, 100, 100, 100,
100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100
],
"dims": [3, 5, 5],
"type": "float32"
}
],
"outputs": [
{
"data": [
1.764052391052246, 2.2408931255340576, 2.2408931255340576, 0.9500884413719177, 0.9500884413719177,
0.4105985164642334, 1.4542734622955322, 1.4542734622955322, 0.7610377073287964, 1.4940791130065918,
1.4940791130065918, 0.3130677044391632, 0.8644362092018127, 0.8644362092018127, 2.269754648208618, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
100, 100
],
"dims": [3, 5, 3],
"type": "float32"
}
]
}
]
},
{
"name": "MaxPool",
"operator": "MaxPool",
"attributes": [{ "name": "kernel_shape", "data": [3], "type": "ints" }],
"cases": [
{
"name": "T[1,1,5] T[1,1,3]",
"inputs": [
{
"data": [1.764052391052246, 0.40015721321105957, 0.978738009929657, 2.2408931255340576, 1.8675580024719238],
"dims": [1, 1, 5],
"type": "float32"
}
],
"outputs": [
{
"data": [1.764052391052246, 2.2408931255340576, 2.2408931255340576],
"dims": [1, 1, 3],
"type": "float32"
}
]
}
]
}
]
1 change: 1 addition & 0 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,7 @@
"matmul.jsonc",
"matmulnbits.jsonc",
"matmul-broadcast.jsonc",
"max-pool.jsonc",
"mul.jsonc",
"mul_int32.jsonc",
"multihead-attention.jsonc",
Expand Down
5 changes: 2 additions & 3 deletions onnxruntime/core/providers/js/operators/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class ConvBase : public JsKernel {
std::vector<float> activation_params = info.GetAttrsOrDefault<float>("activation_params");
int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault<int64_t>("channels_last", 0);

// currently only support Conv 1D/2D. TODO: support Conv3D and other
JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
"format" : $11 ? "NHWC" : "NCHW",
"auto_pad" : $1,
Expand All @@ -65,8 +64,8 @@ class ConvBase : public JsKernel {
JSEP_HEAP32_INDEX_START(dilations),
JSEP_HEAP32_INDEX_END(dilations),
static_cast<int32_t>(conv_attrs_.group),
JSEP_HEAP32_INDEX_START(kernel_shape),
JSEP_HEAP32_INDEX_END(kernel_shape),
JSEP_HEAP32_INDEX_START(kernel_shapes),
JSEP_HEAP32_INDEX_END(kernel_shapes),
JSEP_HEAP32_INDEX_START(local_pads),
JSEP_HEAP32_INDEX_END(local_pads),
JSEP_HEAP32_INDEX_START(strides),
Expand Down
61 changes: 36 additions & 25 deletions onnxruntime/core/providers/js/operators/pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,45 @@
namespace onnxruntime {
namespace js {

#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \
"format" : $15 ? "NHWC" : "NCHW", \
"auto_pad" : $1, \
"ceil_mode" : $2, \
"count_include_pad" : $3, \
"storage_order" : $4, \
"dilations" : [ $5, $6 ], \
"kernel_shape" : [ $7, $8 ], \
"pads" : [ $9, $10, $11, $12 ], \
"strides" : [ $13, $14 ] \
#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \
"format" : $13 ? "NHWC" : "NCHW", \
"auto_pad" : $1, \
"ceil_mode" : $2, \
"count_include_pad" : $3, \
"storage_order" : $4, \
"dilations" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [], \
"kernel_shape" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [], \
"pads" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], \
"strides" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [] \
})

#define POOL_ATTRIBUTES_PARAM_LIST \
static_cast<int32_t>(pool_attrs_.auto_pad), \
static_cast<int32_t>(pool_attrs_.ceil_mode), \
static_cast<int32_t>(pool_attrs_.count_include_pad), \
static_cast<int32_t>(pool_attrs_.storage_order), \
static_cast<int32_t>(pool_attrs_.dilations.size() > 0 ? pool_attrs_.dilations[0] : 0), \
static_cast<int32_t>(pool_attrs_.dilations.size() > 1 ? pool_attrs_.dilations[1] : 0), \
static_cast<int32_t>(pool_attrs_.kernel_shape.size() > 0 ? pool_attrs_.kernel_shape[0] : 0), \
static_cast<int32_t>(pool_attrs_.kernel_shape.size() > 1 ? pool_attrs_.kernel_shape[1] : 0), \
static_cast<int32_t>(pool_attrs_.pads.size() > 0 ? pool_attrs_.pads[0] : 0), \
static_cast<int32_t>(pool_attrs_.pads.size() > 1 ? pool_attrs_.pads[1] : 0), \
static_cast<int32_t>(pool_attrs_.pads.size() > 2 ? pool_attrs_.pads[2] : 0), \
static_cast<int32_t>(pool_attrs_.pads.size() > 3 ? pool_attrs_.pads[3] : 0), \
static_cast<int32_t>(pool_attrs_.strides.size() > 0 ? pool_attrs_.strides[0] : 0), \
static_cast<int32_t>(pool_attrs_.strides.size() > 1 ? pool_attrs_.strides[1] : 0), \
#define POOL_ATTRIBUTES_PARAM_LIST \
static_cast<int32_t>(pool_attrs_.auto_pad), \
static_cast<int32_t>(pool_attrs_.ceil_mode), \
static_cast<int32_t>(pool_attrs_.count_include_pad), \
static_cast<int32_t>(pool_attrs_.storage_order), \
JSEP_HEAP32_INDEX_START(dilations), \
JSEP_HEAP32_INDEX_END(dilations), \
JSEP_HEAP32_INDEX_START(kernel_shapes), \
JSEP_HEAP32_INDEX_END(kernel_shapes), \
JSEP_HEAP32_INDEX_START(pads), \
JSEP_HEAP32_INDEX_END(pads), \
JSEP_HEAP32_INDEX_START(strides), \
JSEP_HEAP32_INDEX_END(strides), \
static_cast<int32_t>(is_channels_last)

#define GLOBAL_POOL_ATTRIBUTES_JS_OBJ_MAPPING ({"format" : $1 ? "NHWC" : "NCHW"})
#define GLOBAL_POOL_ATTRIBUTES_PARAM_LIST static_cast<int32_t>(is_channels_last)

template <typename Type>
inline const std::vector<Type> CastTensorShapeVector(const TensorShapeVector& shape) {
std::vector<Type> castedShapes(shape.size(), 0);
for (size_t i = 0; i < shape.size(); ++i) {
castedShapes[i] = gsl::narrow_cast<Type>(shape[i]);
}
return castedShapes;
}

template <typename PoolType, bool is_channels_last>
class Pool : public JsKernel, public PoolBase {
public:
Expand All @@ -54,6 +61,10 @@ class Pool : public JsKernel, public PoolBase {
// TODO: GlobalLpPool
}
} else {
auto kernel_shapes{CastTensorShapeVector<int32_t>(pool_attrs_.kernel_shape)};
auto strides{CastTensorShapeVector<int32_t>(pool_attrs_.strides)};
auto dilations{CastTensorShapeVector<int32_t>(pool_attrs_.dilations)};
auto pads{CastTensorShapeVector<int32_t>(pool_attrs_.pads)};
if constexpr (PoolType::type == onnxruntime::PoolType::kAveragePool) {
JSEP_INIT_KERNEL_ATTRIBUTE(AveragePool, POOL_ATTRIBUTES_JS_OBJ_MAPPING, POOL_ATTRIBUTES_PARAM_LIST);
} else if constexpr (PoolType::type == onnxruntime::PoolType::kMaxPool) {
Expand Down

0 comments on commit 72e7646

Please sign in to comment.