From 7f644f2e6519291e99471e5393706b2d89edee52 Mon Sep 17 00:00:00 2001 From: Prashant Rawat Date: Thu, 2 Jan 2025 10:37:20 -0800 Subject: [PATCH] Accept default padding value for torch.constant_pad_nd (#7469) Summary: xnnpack delegation for pad op assumes that the pad value is always present. However, constant_pad_nd defults to padding value of 0.0 if it's not present in the op. When absent, it errors out in `padding_value = cast(float, node.args[2])` with `IndexError: tuple index out of range`. {F1974161274} This diff defaults to padding value of 0.0 if the arg is absent from torch.constant_pad_nd op. Reviewed By: tarun292 Differential Revision: D67756862 --- .../operators/op_static_constant_pad.py | 5 ++- .../test/ops/test_static_constant_pad.py | 34 +++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/backends/xnnpack/operators/op_static_constant_pad.py b/backends/xnnpack/operators/op_static_constant_pad.py index 7c441ee7c68..3381227a885 100644 --- a/backends/xnnpack/operators/op_static_constant_pad.py +++ b/backends/xnnpack/operators/op_static_constant_pad.py @@ -116,11 +116,14 @@ def define_node( pre_paddings = all_paddings[-2::-2] # even index elements in reverse order post_paddings = all_paddings[::-2] # odd index elements in reverse order + # the padding value, which defaults to 0.0 + padding_value = cast(float, node.args[2]) if len(node.args) > 2 else 0.0 + ser_node = XNode( xnode_union=XNNStaticConstantPad( pre_paddings=pre_paddings, post_paddings=post_paddings, - padding_value=cast(float, node.args[2]), + padding_value=padding_value, input_id=input_id, output_id=output_id, flags=0, diff --git a/backends/xnnpack/test/ops/test_static_constant_pad.py b/backends/xnnpack/test/ops/test_static_constant_pad.py index a0a74e3840b..b1b41afe8cf 100644 --- a/backends/xnnpack/test/ops/test_static_constant_pad.py +++ b/backends/xnnpack/test/ops/test_static_constant_pad.py @@ -114,6 +114,40 @@ def test_fp32_static_constant_pad_functional(self): ) self._test_static_constant_pad_functional(inputs) + def test_constant_pad_nd(self): + class ConstantPad(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + pad_6 = (1, 2, 3, 4, 5, 6) + pad_4 = (1, 2, 3, 4) + pad_2 = (1, 2) + a = torch.constant_pad_nd(input=x, pad=pad_6) + b = torch.constant_pad_nd(input=y, pad=pad_4) + c = torch.constant_pad_nd(input=z, pad=pad_2) + + return (a + a, b + b, c + c) + + inputs = ( + torch.randn(size=(5, 4, 3, 2)), + torch.randn(size=(5, 3, 2)), + torch.randn(size=(4, 3)), + ) + ( + Tester(ConstantPad(), inputs) + .export() + .check_count({"torch.ops.aten.constant_pad_nd.default": 3}) + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not( + ["executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default"] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + def test_qs8_static_constant_pad_functional(self): class Pad(torch.nn.Module): def __init__(self):