From a694541229cdf4caf2c1bba29a1a42ffe3b9168f Mon Sep 17 00:00:00 2001
From: Prashant Rawat
Date: Thu, 2 Jan 2025 02:20:40 -0800
Subject: [PATCH] Accept default padding value for torch.constant_pad_nd (#7469)

Summary: The XNNPACK delegation for the pad op assumes that the pad value is always present. However, torch.constant_pad_nd defaults to a padding value of 0.0 when the argument is omitted from the op. When it is absent, lowering errors out at `padding_value = cast(float, node.args[2])` with `IndexError: tuple index out of range`.

{F1974161274}

This diff falls back to a padding value of 0.0 when the arg is absent from the torch.constant_pad_nd op.

Differential Revision: D67756862
---
 .../operators/op_static_constant_pad.py       |  5 ++-
 .../test/ops/test_static_constant_pad.py      | 43 +++++++++++++++++++
 2 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/backends/xnnpack/operators/op_static_constant_pad.py b/backends/xnnpack/operators/op_static_constant_pad.py
index 7c441ee7c68..3381227a885 100644
--- a/backends/xnnpack/operators/op_static_constant_pad.py
+++ b/backends/xnnpack/operators/op_static_constant_pad.py
@@ -116,11 +116,14 @@ def define_node(
         pre_paddings = all_paddings[-2::-2]  # even index elements in reverse order
         post_paddings = all_paddings[::-2]  # odd index elements in reverse order
 
+        # the padding value, which defaults to 0.0
+        padding_value = cast(float, node.args[2]) if len(node.args) > 2 else 0.0
+
         ser_node = XNode(
             xnode_union=XNNStaticConstantPad(
                 pre_paddings=pre_paddings,
                 post_paddings=post_paddings,
-                padding_value=cast(float, node.args[2]),
+                padding_value=padding_value,
                 input_id=input_id,
                 output_id=output_id,
                 flags=0,
diff --git a/backends/xnnpack/test/ops/test_static_constant_pad.py b/backends/xnnpack/test/ops/test_static_constant_pad.py
index a0a74e3840b..9148e930330 100644
--- a/backends/xnnpack/test/ops/test_static_constant_pad.py
+++ b/backends/xnnpack/test/ops/test_static_constant_pad.py
@@ -114,6 +114,49 @@ def test_fp32_static_constant_pad_functional(self):
         )
         self._test_static_constant_pad_functional(inputs)
 
+    def test_constant_pad_nd(self):
+        class ConstantPad(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y, z):
+                pad_6 = (1, 2, 3, 4, 5, 6)
+                pad_4 = (1, 2, 3, 4)
+                pad_2 = (1, 2)
+                a = torch.constant_pad_nd(
+                    input=x,
+                    pad=pad_6
+                )
+                b = torch.constant_pad_nd(
+                    input=y,
+                    pad=pad_4
+                )
+                c = torch.constant_pad_nd(
+                    input=z,
+                    pad=pad_2
+                )
+
+                return (a + a, b + b, c + c)
+
+        inputs = (
+            torch.randn(size=(5, 4, 3, 2)),
+            torch.randn(size=(5, 3, 2)),
+            torch.randn(size=(4, 3)),
+        )
+        (
+            Tester(ConstantPad(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.constant_pad_nd.default": 3})
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(
+                ["executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default"]
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
     def test_qs8_static_constant_pad_functional(self):
         class Pad(torch.nn.Module):
             def __init__(self):
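
Note (not part of the patch): a minimal eager-mode sketch of the default behavior this fix relies on. torch.constant_pad_nd pads with 0.0 when the value argument is omitted, which is exactly the two-arg form the new test lowers through XNNPACK:

    import torch

    x = torch.ones(2, 2)

    # `value` defaults to 0.0, so the two-arg and three-arg calls agree
    padded_default = torch.constant_pad_nd(x, (1, 1))
    padded_explicit = torch.constant_pad_nd(x, (1, 1), 0.0)
    assert torch.equal(padded_default, padded_explicit)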