Accept default padding value for torch.constant_pad_nd (#7469)
Summary:

The XNNPACK delegation for the pad op assumes that the pad value is always present. However, constant_pad_nd defaults to a padding value of 0.0 when the value is not passed in the op. When the arg is absent, delegation errors out at `padding_value = cast(float, node.args[2])` with `IndexError: tuple index out of range`.


This diff defaults to a padding value of 0.0 when the arg is absent from the torch.constant_pad_nd op.
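For reference, the value argument of `torch.constant_pad_nd(input, pad, value=0)` is optional at the Python API level; a minimal sketch (plain PyTorch, nothing ExecuTorch-specific assumed) showing the omitted and explicit forms agree:

```python
import torch

x = torch.randn(4, 3)
# Omitting value falls back to 0.0 ...
padded_default = torch.constant_pad_nd(x, (1, 2))
# ... which matches passing the default explicitly.
padded_explicit = torch.constant_pad_nd(x, (1, 2), 0.0)
assert torch.equal(padded_default, padded_explicit)
```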

Reviewed By: tarun292

Differential Revision: D67756862
pssrawat authored and facebook-github-bot committed Jan 2, 2025
1 parent 3ef78ee commit 7f644f2
Showing 2 changed files with 38 additions and 1 deletion.
backends/xnnpack/operators/op_static_constant_pad.py (4 additions, 1 deletion)

```diff
@@ -116,11 +116,14 @@ def define_node(
         pre_paddings = all_paddings[-2::-2]  # even index elements in reverse order
         post_paddings = all_paddings[::-2]  # odd index elements in reverse order
 
+        # the padding value, which defaults to 0.0
+        padding_value = cast(float, node.args[2]) if len(node.args) > 2 else 0.0
+
         ser_node = XNode(
             xnode_union=XNNStaticConstantPad(
                 pre_paddings=pre_paddings,
                 post_paddings=post_paddings,
-                padding_value=cast(float, node.args[2]),
+                padding_value=padding_value,
                 input_id=input_id,
                 output_id=output_id,
                 flags=0,
```
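To see why the guard is needed, one can inspect the exported graph directly. A hedged sketch (assumes torch.export is available; whether the default value is materialized in node.args can vary across torch versions, but in the pipeline described above the node carried only the input and pad args):

```python
import torch
from torch.export import export

class Pad(torch.nn.Module):
    def forward(self, x):
        # value omitted at the call site
        return torch.constant_pad_nd(x, (1, 2))

ep = export(Pad(), (torch.randn(4, 3),))
for node in ep.graph.nodes:
    if node.target == torch.ops.aten.constant_pad_nd.default:
        # With only (input, pad) recorded, node.args[2] raises IndexError,
        # hence the len(node.args) > 2 check in define_node above.
        print(len(node.args))
```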
backends/xnnpack/test/ops/test_static_constant_pad.py (34 additions, 0 deletions)

```diff
@@ -114,6 +114,40 @@ def test_fp32_static_constant_pad_functional(self):
         )
         self._test_static_constant_pad_functional(inputs)
 
+    def test_constant_pad_nd(self):
+        class ConstantPad(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y, z):
+                pad_6 = (1, 2, 3, 4, 5, 6)
+                pad_4 = (1, 2, 3, 4)
+                pad_2 = (1, 2)
+                a = torch.constant_pad_nd(input=x, pad=pad_6)
+                b = torch.constant_pad_nd(input=y, pad=pad_4)
+                c = torch.constant_pad_nd(input=z, pad=pad_2)
+
+                return (a + a, b + b, c + c)
+
+        inputs = (
+            torch.randn(size=(5, 4, 3, 2)),
+            torch.randn(size=(5, 3, 2)),
+            torch.randn(size=(4, 3)),
+        )
+        (
+            Tester(ConstantPad(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.constant_pad_nd.default": 3})
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(
+                ["executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default"]
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
     def test_qs8_static_constant_pad_functional(self):
         class Pad(torch.nn.Module):
             def __init__(self):
```
