Skip to content

Commit

Permalink
Failing DW test on executorch (pytorch#4929)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#4929

example of failing U55 DW test

Reviewed By: hsharma35

Differential Revision: D61862425

fbshipit-source-id: e56c34382d404bad092c97e3c59e96f1ae51a9b3
  • Loading branch information
Eashan Garg authored and facebook-github-bot committed Sep 12, 2024
1 parent 623b7b6 commit 4c61317
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
4 changes: 3 additions & 1 deletion backends/arm/passes/annotate_channels_last_dim_order_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
prev_node = node.args[0]
if cast(torch.fx.Node, prev_node).op != "placeholder":
return False
return is_consumer_node_depthwise_conv2d(node)
if is_consumer_node_depthwise_conv2d(node):
consumer_node = list(node.users)[0]
return consumer_node.args[1] == node
elif node.op == "placeholder":
# node is an input, weight or bias node
consumer_node = list(node.users)[0]
Expand Down
20 changes: 16 additions & 4 deletions backends/arm/test/ops/test_depthwise_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@
# Fails when enabling CompileSpec.set_quantize_io(True). MLETORCH-191.
testsuite_u55.remove(("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1))

# Add failing test (set_quantize_io=True) temporarily to investigate
testsuite_u55.append(
("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1, True)
)


class TestDepthwiseConv2D(unittest.TestCase):
"""Tests Conv2D where groups == in_channels and out_channels = K * in_channels. This
Expand Down Expand Up @@ -173,13 +178,18 @@ def _test_dw_conv2d_tosa_BI_pipeline(
)

def _test_dw_conv2d_u55_BI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
self,
module: torch.nn.Module,
test_data: Tuple[torch.Tensor],
set_quantize_io: bool = False,
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_u55_compile_spec(permute_memory_to_nhwc=True),
compile_spec=common.get_u55_compile_spec(
permute_memory_to_nhwc=True, quantize_io=set_quantize_io
),
)
.quantize()
.export()
Expand All @@ -202,5 +212,7 @@ def test_dw_conv2d_tosa_BI(self, test_name, model):

@parameterized.expand(testsuite_u55, skip_on_empty=True)
@unittest.expectedFailure
def test_dw_conv2d_u55_BI(self, test_name, model):
self._test_dw_conv2d_u55_BI_pipeline(model, model.get_inputs())
def test_dw_conv2d_u55_BI(self, test_name, model, set_quantize_io=False):
self._test_dw_conv2d_u55_BI_pipeline(
model, model.get_inputs(), set_quantize_io=set_quantize_io
)

0 comments on commit 4c61317

Please sign in to comment.