diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS
index bd42710d7b..05f6095c37 100644
--- a/backends/arm/TARGETS
+++ b/backends/arm/TARGETS
@@ -110,3 +110,14 @@ python_library(
         "//executorch/backends/arm/operators:node_visitor",
     ],
 )
+
+python_library(
+    name = "arm_model_evaluator",
+    srcs = [
+        "util/arm_model_evaluator.py",
+    ],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+    ]
+)
diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py
index c133ce8003..297047963c 100644
--- a/backends/arm/operator_support/__init__.py
+++ b/backends/arm/operator_support/__init__.py
@@ -8,6 +8,7 @@
 from . import (  # noqa
     mean_dim_support,
     right_shift_support,
+    to_copy_support,
     tosa_supported_operators,
     var_correction_support,
 )
diff --git a/backends/arm/operator_support/to_copy_support.py b/backends/arm/operator_support/to_copy_support.py
new file mode 100644
index 0000000000..9bba274804
--- /dev/null
+++ b/backends/arm/operator_support/to_copy_support.py
@@ -0,0 +1,120 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import logging
+
+import torch
+
+import torch.fx as fx
+
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+logger = logging.getLogger(__name__)
+
+
+@register_tosa_support_check
+class ToCopySupported(SupportedTOSAOperatorCheck):
+    targets = [exir_ops.edge.aten._to_copy.default]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
+    ]
+
+    SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]
+
+    @staticmethod
+    def _merge_supported_types(
+        dtypes1: SupportedTypeDict, dtypes2: SupportedTypeDict
+    ) -> SupportedTypeDict:
+        merged_dtypes = dtypes1
+        for k, v in dtypes2.items():
+            merged_dtypes[k] = merged_dtypes.get(k, []) + v
+        return merged_dtypes
+
+    SUPPORTED_INT_TYPES: SupportedTypeDict = {
+        torch.bool: [torch.int8, torch.int16, torch.int32],
+        torch.int8: [torch.bool, torch.int16, torch.int32],
+        torch.int16: [torch.bool, torch.int8, torch.int32],
+        torch.int32: [torch.bool, torch.int8, torch.int16],
+    }
+    SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
+        torch.int8: [torch.float16, torch.bfloat16, torch.float32],
+        torch.int16: [torch.float16, torch.bfloat16, torch.float32],
+        torch.int32: [torch.float16, torch.bfloat16, torch.float32],
+        torch.bfloat16: [torch.int8, torch.int16, torch.int32, torch.float32],
+        torch.float16: [torch.int8, torch.int16, torch.int32, torch.float32],
+        torch.float32: [
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.bfloat16,
+            torch.float16,
+        ],
+    }
+    ALL_SUPPORTED_TYPES = _merge_supported_types(
+        SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
+    )
+    POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}
+
+    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
+        assert node.target in self.targets
+
+        if tosa_spec not in self.tosa_specs:
+            return False
+
+        assert tosa_spec.support_integer()
+        supported_dtypes = (
+            self.ALL_SUPPORTED_TYPES
+            if tosa_spec.support_float()
+            else self.SUPPORTED_INT_TYPES
+        )
+        # Take into account possible type conversions
+        supported_dtypes.update(
+            (k, supported_dtypes[v])
+            for k, v in self.POSSIBLE_TYPE_CONVERSIONS.items()
+            if v in supported_dtypes
+        )
+
+        # Check input type
+        assert len(node.all_input_nodes) == 1
+        input_val = node.all_input_nodes[0].meta["val"]
+        assert isinstance(input_val, torch._subclasses.FakeTensor)
+        input_dtype = input_val.dtype
+        if input_dtype not in supported_dtypes:
+            logger.info(
+                f"Input dtype {input_val.dtype} is not supported in "
+                f"{node.target.name()}."
+            )
+            return False
+
+        # Check output type
+        output_val = node.meta["val"]
+        assert isinstance(output_val, torch._subclasses.FakeTensor)
+        if output_val.dtype not in supported_dtypes[input_dtype]:
+            logger.info(
+                f"Output dtype {output_val.dtype} is not supported in "
+                f"{node.target.name()} for input dtype {input_dtype}. "
+                f"Supported output types: "
+                f"{', '.join(str(t) for t in supported_dtypes[input_dtype])}"
+            )
+            return False
+
+        # Check memory format
+        if "memory_format" in node.kwargs:
+            if node.kwargs["memory_format"] in (torch.preserve_format,):
+                logger.info(
+                    f"Argument 'memory_format' is not supported for "
+                    f"{node.target.name()} right now."
+                )
+                return False
+
+        return True
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index a5c2dd8dc5..8c4aa85e57 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -36,6 +36,7 @@
     op_sub,
     op_sum,
     op_tanh,
+    op_to_copy,
     op_transpose,
     op_unsqueeze,
     op_upsample_nearest2d,
diff --git a/backends/arm/operators/op_to_copy.py b/backends/arm/operators/op_to_copy.py
new file mode 100644
index 0000000000..15077d6df7
--- /dev/null
+++ b/backends/arm/operators/op_to_copy.py
@@ -0,0 +1,43 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from typing import List
+
+import serializer.tosa_serializer as ts
+import torch
+import tosa.Op as TosaOp
+
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+
+
+@register_node_visitor
+class ToCopyVisitor(NodeVisitor):
+    """
+    Implement the type cast functionality of _to_copy.
+
+    Other features like setting of the memory_format or moving a tensor to a
+    different device are not supported.
+
+    Also note that the node should not be quantized.
+    """
+
+    target = "aten._to_copy.default"
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+        assert not is_quant_node, "Casting of quantized values is not supported."
+        assert inputs
+        tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name])
diff --git a/backends/arm/test/TARGETS b/backends/arm/test/TARGETS
new file mode 100644
index 0000000000..ef092c5503
--- /dev/null
+++ b/backends/arm/test/TARGETS
@@ -0,0 +1,23 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+python_library(
+    name = "common",
+    srcs = ["common.py"],
+    deps = [
+        "//executorch/backends/xnnpack/test/tester:tester",
+        "//executorch/backends/arm:arm_backend",
+        "//executorch/exir:lib",
+        "//executorch/exir/backend:compile_spec_schema",
+    ]
+)
+
+python_library(
+    name = "runner_utils",
+    srcs = ["runner_utils.py"],
+    deps = [
+        "//executorch/backends/xnnpack/test/tester:tester",
+        "//executorch/backends/arm:arm_backend",
+        "//executorch/exir:lib",
+        "//executorch/exir/backend:compile_spec_schema",
+    ]
+)
diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py
index 824ec46372..523a90cdc8 100644
--- a/backends/arm/test/ops/test_bmm.py
+++ b/backends/arm/test/ops/test_bmm.py
@@ -22,8 +22,8 @@ class TestBMM(unittest.TestCase):

     class BMM(torch.nn.Module):
         test_parameters = [
-            (torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
             (torch.rand(2, 1, 1), torch.rand(2, 1, 1)),
+            (torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
             (torch.ones(1, 55, 3), torch.ones(1, 3, 44)),
             (10000 * torch.randn(10, 1, 10), torch.randn(10, 10, 5)),
             (-10 * torch.randn(2, 32, 64), 5 + 5 * torch.randn(2, 64, 32)),
@@ -147,32 +147,37 @@ def test_bmm_single_input_tosa_BI(self, operand1: torch.Tensor):

     @parameterized.expand(BMM.test_parameters)
     @unittest.expectedFailure
-    def test_bmm_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
+    def test_bmm_u55_BI_xfails(self, operand1: torch.Tensor, operand2: torch.Tensor):
         test_data = (operand1, operand2)
         self._test_bmm_ethosu_BI_pipeline(
             self.BMM(), common.get_u55_compile_spec(), test_data
         )

-    @parameterized.expand(BMM.test_parameters)
-    @common.expectedFailureOnFVP
+    @parameterized.expand(BMM.test_parameters[:1])
     def test_bmm_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
         test_data = (operand1, operand2)
         self._test_bmm_ethosu_BI_pipeline(
             self.BMM(), common.get_u85_compile_spec(), test_data
         )

+    @parameterized.expand(BMM.test_parameters[1:])
+    @common.expectedFailureOnFVP
+    def test_bmm_u85_BI_xfails(self, operand1: torch.Tensor, operand2: torch.Tensor):
+        test_data = (operand1, operand2)
+        self._test_bmm_ethosu_BI_pipeline(
+            self.BMM(), common.get_u85_compile_spec(), test_data
+        )
+
     # Expected to fail with error: Warning, unsupported fusing of TOSA Rescale previous operator is of type: Memcpy
     @parameterized.expand(BMMSingleInput.test_parameters)
     @unittest.expectedFailure
-    def test_bmm_single_input_u55_BI(self, operand1: torch.Tensor):
+    def test_bmm_single_input_u55_BI_xfails(self, operand1: torch.Tensor):
         test_data = (operand1,)
         self._test_bmm_ethosu_BI_pipeline(
             self.BMMSingleInput(), common.get_u55_compile_spec(), test_data
         )

-    # Numerical issues on FVP, MLETORCH 534
     @parameterized.expand(BMMSingleInput.test_parameters)
-    @common.expectedFailureOnFVP
     def test_bmm_single_input_u85_BI(self, operand1: torch.Tensor):
         test_data = (operand1,)
         self._test_bmm_ethosu_BI_pipeline(
diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py
index 7555fff720..001c4a2bd5 100644
--- a/backends/arm/test/ops/test_conv_combos.py
+++ b/backends/arm/test/ops/test_conv_combos.py
@@ -275,8 +275,6 @@ def test_conv_meandim_u55_BI(self):
             model.get_inputs(),
         )

-    # Numerical Issues on FVP, MLETORCH-520
-    @common.expectedFailureOnFVP
     def test_conv_meandim_u85_BI(self):
         model = ComboConv2dMeandim()
         self._test_conv_combo_ethos_BI_pipeline(
diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py
index 28cb9ac844..d753245f43 100644
--- a/backends/arm/test/ops/test_depthwise_conv.py
+++ b/backends/arm/test/ops/test_depthwise_conv.py
@@ -156,6 +156,19 @@
     ("two_dw_conv2d", two_dw_conv2d),
 ]

+testsuite_conv2d_u85 = [
+    ("2x2_1x6x4x4_gp6_st1", dw_conv2d_2x2_1x6x4x4_gp6_st1),
+    ("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1),
+    ("3x3_1x4x256x256_gp4_st1", dw_conv2d_3x3_1x4x256x256_gp4_st1),
+    ("3x3_1x4x256x256_gp4_nobias", dw_conv2d_3x3_1x4x256x256_gp4_nobias),
+]
+
+testsuite_conv2d_u85_xfails = [
+    ("3x3_2x8x198x198_gp8_st3", dw_conv2d_3x3_2x8x198x198_gp8_st3),
+    ("two_dw_conv2d", two_dw_conv2d),
+]
+
+
 testsuite_conv1d = [
     ("2_1x6x4_gp6_st1", dw_conv1d_2_1x6x4_gp6_st1),
     ("two_dw_conv1d", two_dw_conv1d),
@@ -247,7 +260,7 @@ def test_dw_conv_tosa_BI(self, test_name: str, model: torch.nn.Module):
         )  # Works

     @parameterized.expand(testsuite_conv2d, skip_on_empty=True)
-    @common.expectedFailureOnFVP
+    @unittest.expectedFailure
    def test_dw_conv2d_u55_BI(
         self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
     ):
@@ -274,10 +287,8 @@ def test_dw_conv1d_u55_BI(
             model.get_inputs(),
         )

-    # All test cases except 3x3_1x3x256x256_gp3_st1 have numerical issues on FVP. MLETORCH-520
-    @parameterized.expand(testsuite_conv1d[:-2] + testsuite_conv2d)
-    @common.expectedFailureOnFVP
-    def test_dw_conv_u85_BI_xfails(
+    @parameterized.expand(testsuite_conv1d + testsuite_conv2d_u85)
+    def test_dw_conv_u85_BI(
         self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
     ):
         self._test_dw_conv_ethos_BI_pipeline(
@@ -288,8 +299,10 @@ def test_dw_conv_u85_BI(
             model.get_inputs(),
         )

-    @parameterized.expand(testsuite_conv1d[-2:])
-    def test_dw_conv_u85_BI(
+    # All test cases except 3x3_1x3x256x256_gp3_st1 have numerical issues on FVP. MLETORCH-520
+    @parameterized.expand(testsuite_conv2d_u85_xfails)
+    @common.expectedFailureOnFVP
+    def test_dw_conv_u85_BI_xfails(
         self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
     ):
         self._test_dw_conv_ethos_BI_pipeline(
diff --git a/backends/arm/test/ops/test_div.py b/backends/arm/test/ops/test_div.py
index b3815f3e7c..27febd714e 100644
--- a/backends/arm/test/ops/test_div.py
+++ b/backends/arm/test/ops/test_div.py
@@ -26,18 +26,18 @@
         torch.ones(5),
         None,
     ),
-    (
-        "op_div_rank1_rand",
-        torch.rand(5) * 5,
-        torch.rand(5) * 5,
-        None,
-    ),
     (
         "op_div_rank1_negative_ones",
         torch.ones(5) * (-1),
         torch.ones(5) * (-1),
         None,
     ),
+    (
+        "op_div_rank1_rand",
+        torch.rand(5) * 5,
+        torch.rand(5) * 5,
+        None,
+    ),
     (
         "op_div_rank4_ones",
         torch.ones(5, 10, 25, 20),
@@ -183,9 +183,7 @@ def test_div_tosa_BI(
         test_data = (input_, other_)
         self._test_div_tosa_BI_pipeline(self.Div(), test_data)

-    # Numerical issues on FVP likely due to mul op, MLETORCH-521
-    @parameterized.expand(test_data_suite)
-    @common.expectedFailureOnFVP
+    @parameterized.expand(test_data_suite[:2])
     def test_div_u55_BI(
         self,
         test_name: str,
@@ -199,8 +197,21 @@ def test_div_u55_BI(
         test_data = (input_, other_)
         self._test_div_ethos_BI_pipeline(
             self.Div(), common.get_u55_compile_spec(), test_data
         )

     # Numerical issues on FVP likely due to mul op, MLETORCH-521
-    @parameterized.expand(test_data_suite)
+    @parameterized.expand(test_data_suite[2:])
     @common.expectedFailureOnFVP
+    def test_div_u55_BI_xfails(
+        self,
+        test_name: str,
+        input_: Union[torch.Tensor, torch.types.Number],
+        other_: Union[torch.Tensor, torch.types.Number],
+        rounding_mode: Optional[str] = None,
+    ):
+        test_data = (input_, other_)
+        self._test_div_ethos_BI_pipeline(
+            self.Div(), common.get_u55_compile_spec(), test_data
+        )
+
+    @parameterized.expand(test_data_suite[:2])
     def test_div_u85_BI(
         self,
         test_name: str,
@@ -212,3 +223,18 @@
         self._test_div_ethos_BI_pipeline(
             self.Div(), common.get_u85_compile_spec(), test_data
         )
+
+    # Numerical issues on FVP likely due to mul op, MLETORCH-521
+    @parameterized.expand(test_data_suite[2:])
+    @common.expectedFailureOnFVP
+    def test_div_u85_BI_xfails(
+        self,
+        test_name: str,
+        input_: Union[torch.Tensor, torch.types.Number],
+        other_: Union[torch.Tensor, torch.types.Number],
+        rounding_mode: Optional[str] = None,
+    ):
+        test_data = (input_, other_)
+        self._test_div_ethos_BI_pipeline(
+            self.Div(), common.get_u85_compile_spec(), test_data
+        )
diff --git a/backends/arm/test/ops/test_layer_norm.py b/backends/arm/test/ops/test_layer_norm.py
index 0b06044a59..e84dd4ee58 100644
--- a/backends/arm/test/ops/test_layer_norm.py
+++ b/backends/arm/test/ops/test_layer_norm.py
@@ -158,7 +158,7 @@ def test_layer_norm_tosa_BI(
     # Numerical issues on FVP likely due to mul op, MLETORCH-521
     # Skip tests that require transposes.
     @parameterized.expand(test_data_suite[:-2])
-    @common.expectedFailureOnFVP
+    @unittest.expectedFailure
     def test_layer_norm_u55_BI(
         self,
         test_name: str,
@@ -170,9 +170,8 @@
         )

     # Numerical issues on FVP likely due to mul op, MLETORCH-521
-    @parameterized.expand(test_data_suite[:-1])
-    @common.expectedFailureOnFVP
-    def test_layer_norm_u85_BI_fvp_xfails(
+    @parameterized.expand(test_data_suite[:-2])
+    def test_layer_norm_u85_BI_fvp(
         self,
         test_name: str,
         test_data: torch.Tensor,
@@ -182,7 +181,7 @@ def test_layer_norm_u85_BI_fvp(
             self.LayerNorm(*model_params), common.get_u85_compile_spec(), (test_data,)
         )

-    @parameterized.expand(test_data_suite[-1:])
+    @parameterized.expand(test_data_suite[-2:])
     @unittest.skip  # Flaky
     def test_layer_norm_u85_BI(
         self,
diff --git a/backends/arm/test/ops/test_logsoftmax.py b/backends/arm/test/ops/test_logsoftmax.py
index 5d84fa127f..910384e0a0 100644
--- a/backends/arm/test/ops/test_logsoftmax.py
+++ b/backends/arm/test/ops/test_logsoftmax.py
@@ -17,14 +17,29 @@

 test_data_suite = [
     # (test_name, test_data, dim)
-    ("zeros", torch.zeros(10, 10, 10, 10), 0),
-    ("zeros_neg_dim", torch.zeros(10, 10, 10, 10), -4),
+    ("zeros", torch.zeros(10, 8, 5, 2), 0),
+    ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4),
     ("ones", torch.ones(10, 10), 1),
-    ("rand_neg_dim", torch.rand(10, 10, 10), -1),
-    ("rand", torch.rand(10, 10, 10, 10), 2),
-    ("rand_neg_dim", torch.rand(10, 10, 2, 3), -2),
-    ("randn", torch.randn(10, 10, 5, 10), 3),
-    ("randn_neg_dim", torch.randn(1, 10, 10, 10), -3),
+    ("ones_neg_dim", torch.ones(10, 3, 4), -1),
+    ("rand", torch.rand(1, 2, 5, 8), 2),
+    ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2),
+    ("randn", torch.randn(10, 10, 10, 10), 3),
+    ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
+]
+test_data_suite_u55 = [
+    # (test_name, test_data, dim)
+    ("ones", torch.ones(10, 10), 1),
+    ("ones_neg_dim", torch.ones(10, 3, 4), -1),
+    ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
+]
+
+test_data_suite_u55_xfails = [
+    # (test_name, test_data, dim)
+    ("zeros", torch.zeros(10, 8, 5, 2), 0),
+    ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4),
+    ("rand", torch.rand(1, 2, 5, 8), 2),
+    ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2),
+    ("randn", torch.randn(10, 10, 10, 10), 3),
 ]
@@ -135,7 +150,7 @@ def test_logsoftmax_tosa_BI(
     ):
         self._test_logsoftmax_tosa_BI_pipeline(self.LogSoftmax(dim=dim), (test_data,))

-    @parameterized.expand(test_data_suite)
+    @parameterized.expand(test_data_suite_u55)
     def test_logsoftmax_tosa_u55_BI(
         self,
         test_name: str,
@@ -146,6 +161,19 @@
             self.LogSoftmax(dim=dim), (test_data,)
         )

+    # Expected to fail as this is not supported on u55.
+    @parameterized.expand(test_data_suite_u55_xfails)
+    @unittest.expectedFailure
+    def test_logsoftmax_tosa_u55_BI_xfails(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        dim: int,
+    ):
+        self._test_logsoftmax_tosa_u55_BI_pipeline(
+            self.LogSoftmax(dim=dim), (test_data,)
+        )
+
     @parameterized.expand(test_data_suite)
     def test_logsoftmax_tosa_u85_BI(
         self,
@@ -153,6 +181,6 @@
         test_data: torch.Tensor,
         dim: int,
     ):
-        self._test_logsoftmax_tosa_u55_BI_pipeline(
+        self._test_logsoftmax_tosa_u85_BI_pipeline(
             self.LogSoftmax(dim=dim), (test_data,)
         )
diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py
index e8320cf1df..3cb8c5f815 100644
--- a/backends/arm/test/ops/test_mean_dim.py
+++ b/backends/arm/test/ops/test_mean_dim.py
@@ -269,8 +269,10 @@ def test_meandim_tosa_BI(
     ):
         self._test_meandim_tosa_BI_pipeline(self.MeanDim(dim, keepdim), (test_data,))

+    # Expected to fail as this is not supported on u55.
     @parameterized.expand(MeanDim.test_data_suite)
-    def test_meandim_tosa_u55_BI(
+    @unittest.expectedFailure
+    def test_meandim_tosa_u55_BI_xfails(
         self,
         test_name: str,
         test_data: torch.Tensor,
diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py
index 8f0321ea5f..6d6922628e 100644
--- a/backends/arm/test/ops/test_mul.py
+++ b/backends/arm/test/ops/test_mul.py
@@ -152,9 +152,7 @@ def test_mul_tosa_BI(
         test_data = (input_, other_)
         self._test_mul_tosa_BI_pipeline(self.Mul(), test_data)

-    # Numerical issues on FVP, MLETORCH-521
     @parameterized.expand(test_data_sute)
-    @common.expectedFailureOnFVP
     def test_mul_u55_BI(
         self,
         test_name: str,
@@ -166,10 +164,7 @@ def test_mul_u55_BI(
             common.get_u55_compile_spec(), self.Mul(), test_data
         )

-    # Numerical issues on FVP, MLETORCH-521
-    # test_data_sute[0] works on U85
-    @parameterized.expand(test_data_sute[1:])
-    @common.expectedFailureOnFVP
+    @parameterized.expand(test_data_sute)
     def test_mul_u85_BI(
         self,
         test_name: str,
diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py
index cd3dd72f60..455b484b94 100644
--- a/backends/arm/test/ops/test_scalars.py
+++ b/backends/arm/test/ops/test_scalars.py
@@ -153,9 +153,21 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple):
             .run_method_and_compare_outputs(inputs=test_data)
         )

-    # Most MI tests fail, just show one working for now.
-    @parameterized.expand((tensor_scalar_tests[6],))
+    @parameterized.expand(tensor_scalar_tests)
     def test_MI(self, test_name: str, op: torch.nn.Module, x, y):
+        expected_exception = None
+        if any(token in test_name for token in ("Sub_int", "Sub__int")):
+            expected_exception = RuntimeError
+        elif test_name.endswith("_st"):
+            expected_exception = AttributeError
+
+        if expected_exception:
+            with self.assertRaises(
+                expected_exception, msg=f"Test {test_name} is expected to fail."
+            ):
+                self._test_add_tosa_MI_pipeline(op, (x, y))
+            return
+
         self._test_add_tosa_MI_pipeline(op, (x, y))

     # op(Scalar float, tensor) works if the scalar is constant.
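Note on the test_scalars.py hunk above: instead of decorating the whole parameterized MI suite with a blanket expected failure, each generated case now opts into an exception assertion based on its test name. A minimal, self-contained sketch of that pattern follows; the operator, the "neg" token, and run_pipeline are illustrative stand-ins, not code from this patch.

import unittest

from parameterized import parameterized


def run_pipeline(x):
    # Hypothetical stand-in for the real ArmTester MI pipeline.
    if x < 0:
        raise RuntimeError("negative inputs are rejected in this sketch")
    return x


class ExpectedExceptionPattern(unittest.TestCase):
    @parameterized.expand([("pos_int", 1), ("neg_int", -1)])
    def test_op(self, test_name, x):
        # Route known-bad cases to an exception assertion instead of
        # xfailing the entire suite.
        if "neg" in test_name:
            with self.assertRaises(
                RuntimeError, msg=f"Test {test_name} is expected to fail."
            ):
                run_pipeline(x)
            return

        self.assertEqual(run_pipeline(x), x)

Compared with @unittest.expectedFailure on the whole suite, this keeps the passing cases enforced while documenting exactly which cases are expected to raise, and which exception.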
diff --git a/backends/arm/test/ops/test_softmax.py b/backends/arm/test/ops/test_softmax.py
index f883d6b8de..30215b47f3 100644
--- a/backends/arm/test/ops/test_softmax.py
+++ b/backends/arm/test/ops/test_softmax.py
@@ -28,6 +28,22 @@
     ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
 ]

+test_data_suite_u55 = [
+    # (test_name, test_data, dim)
+    ("ones", torch.ones(10, 10), 1),
+    ("ones_neg_dim", torch.ones(10, 3, 4), -1),
+    ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
+]
+
+test_data_suite_u55_xfails = [
+    # (test_name, test_data, dim)
+    ("zeros", torch.zeros(10, 8, 5, 2), 0),
+    ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4),
+    ("rand", torch.rand(1, 2, 5, 8), 2),
+    ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2),
+    ("randn", torch.randn(10, 10, 10, 10), 3),
+]
+

 class TestSoftmax(unittest.TestCase):
     """Tests softmax."""
@@ -136,7 +152,7 @@ def test_softmax_tosa_BI(
     ):
         self._test_softmax_tosa_BI_pipeline(self.Softmax(dim=dim), (test_data,))

-    @parameterized.expand(test_data_suite)
+    @parameterized.expand(test_data_suite_u55)
     def test_softmax_tosa_u55_BI(
         self,
         test_name: str,
@@ -145,6 +161,17 @@
     ):
         self._test_softmax_tosa_u55_BI_pipeline(self.Softmax(dim=dim), (test_data,))

+    # Expected to fail as this is not supported on u55.
+    @parameterized.expand(test_data_suite_u55_xfails)
+    @unittest.expectedFailure
+    def test_softmax_tosa_u55_BI_xfails(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        dim: int,
+    ):
+        self._test_softmax_tosa_u55_BI_pipeline(self.Softmax(dim=dim), (test_data,))
+
     @parameterized.expand(test_data_suite)
     def test_softmax_tosa_u85_BI(
         self,
diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py
index 9cd63b0a22..111517afbb 100644
--- a/backends/arm/test/ops/test_sum.py
+++ b/backends/arm/test/ops/test_sum.py
@@ -35,6 +35,18 @@ class Sum(torch.nn.Module):
         ((torch.rand(1, 2, 8, 8), [2, 3, 0], True),),
     ]

+    test_parameters_u55: list[Tuple[exampledata_t]] = [
+        ((torch.rand(10), 0, True),),
+        ((torch.rand(10, 10), 1, False),),
+        ((torch.rand(1, 2, 3, 4), 3, True),),
+    ]
+
+    test_parameters_u55_xfails: list[Tuple[exampledata_t]] = [
+        ((torch.rand(10, 10, 10), [-3, 1], True),),
+        ((torch.rand(2, 1, 5, 8), 1, False),),
+        ((torch.rand(1, 2, 8, 8), [2, 3, 0], True),),
+    ]
+
     def forward(self, x: torch.Tensor, dim: int, keepdim: bool):
         return x.sum(dim=dim, keepdim=keepdim)
@@ -112,7 +124,7 @@ def test_sum_tosa_MI(self, test_data: tuple[exampledata_t]):
     def test_sum_tosa_BI(self, test_data: tuple[exampledata_t]):
         self._test_sum_tosa_BI_pipeline(self.Sum(), test_data)

-    @parameterized.expand(Sum.test_parameters)
+    @parameterized.expand(Sum.test_parameters_u55)
     def test_sum_u55_BI(self, test_data: tuple[exampledata_t]):
         self._test_sum_ethosu_BI_pipeline(
             self.Sum(),
@@ -120,6 +132,16 @@ def test_sum_u55_BI(self, test_data: tuple[exampledata_t]):
             common.get_u55_compile_spec(permute_memory_to_nhwc=False),
         )

+    # Expected to fail as this is not supported on u55.
+    @parameterized.expand(Sum.test_parameters_u55_xfails)
+    @unittest.expectedFailure
+    def test_sum_u55_BI_xfails(self, test_data: tuple[exampledata_t]):
+        self._test_sum_ethosu_BI_pipeline(
+            self.Sum(),
+            test_data,
+            common.get_u55_compile_spec(permute_memory_to_nhwc=False),
+        )
+
     @parameterized.expand(Sum.test_parameters)
     def test_sum_u85_BI(self, test_data: tuple[exampledata_t]):
         self._test_sum_ethosu_BI_pipeline(
diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py
new file mode 100644
index 0000000000..8499512e10
--- /dev/null
+++ b/backends/arm/test/ops/test_to_copy.py
@@ -0,0 +1,70 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#
+# Tests the _to_copy op which is interpreted as a cast for our purposes.
+#
+
+import unittest
+
+import torch
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+
+from parameterized import parameterized
+
+
+class Cast(torch.nn.Module):
+    def __init__(self, target_dtype):
+        super().__init__()
+        self.target_dtype = target_dtype
+
+    def forward(self, x: torch.Tensor):
+        return x.to(dtype=self.target_dtype)
+
+
+class TestToCopy(unittest.TestCase):
+    """
+    Tests the _to_copy operation.
+
+    Only test unquantized graphs as explicit casting of dtypes messes with the
+    quantization.
+
+    Note: This is also covered by test_scalars.py.
+    """
+
+    _TO_COPY_TEST_DATA = (
+        (torch.rand((1, 2, 3, 4), dtype=torch.float16), torch.float32),
+        (torch.rand((1, 2, 3, 4), dtype=torch.float32), torch.float16),
+        (torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8), torch.float32),
+        (torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8), torch.int32),
+        (torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32), torch.int8),
+    )
+
+    def _test_to_copy_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: torch.Tensor
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
+            )
+            .export()
+            .dump_artifact()
+            .check_count({"torch.ops.aten._to_copy.default": 1})
+            .to_edge()
+            .dump_artifact()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    @parameterized.expand(_TO_COPY_TEST_DATA)
+    def test_to_copy_tosa_MI(self, test_tensor: torch.Tensor, new_dtype):
+        self._test_to_copy_tosa_MI_pipeline(Cast(new_dtype), (test_tensor,))
diff --git a/backends/arm/test/ops/test_var.py b/backends/arm/test/ops/test_var.py
index 3a1285e6da..06671848cc 100644
--- a/backends/arm/test/ops/test_var.py
+++ b/backends/arm/test/ops/test_var.py
@@ -50,6 +50,16 @@ class VarDim(torch.nn.Module):
         (torch.rand(1, 50, 10, 20), -1, True, True),
     ]

+    test_parameters_u55 = [
+        (torch.randn(1, 50, 10, 20), 1, True, False),
+        (torch.randn(1, 30, 15, 20), -3, True, True),
+    ]
+
+    test_parameters_u55_xfails = [
+        (torch.rand(1, 50, 10), -2, True, False),
+        (torch.rand(1, 50, 10, 20), -1, True, True),
+    ]
+
     def forward(
         self,
         x: torch.Tensor,
@@ -148,8 +158,10 @@ def test_var_tosa_MI(self, test_tensor: torch.Tensor, keepdim, correction):
     def test_var_tosa_BI(self, test_tensor: torch.Tensor, keepdim, correction):
         self._test_var_tosa_BI_pipeline(self.Var(), (test_tensor, keepdim, correction))

+    # Expected to fail as this is not supported on u55.
     @parameterized.expand(Var.test_parameters)
-    def test_var_u55_BI(self, test_tensor: torch.Tensor, keepdim, correction):
+    @unittest.expectedFailure
+    def test_var_u55_BI_xfails(self, test_tensor: torch.Tensor, keepdim, correction):
         self._test_var_ethosu_BI_pipeline(
             self.Var(),
             common.get_u55_compile_spec(),
@@ -176,7 +188,7 @@ def test_var_dim_tosa_BI(self, test_tensor: torch.Tensor, dim, keepdim, correcti
             self.VarDim(), (test_tensor, dim, keepdim, correction)
         )

-    @parameterized.expand(VarDim.test_parameters)
+    @parameterized.expand(VarDim.test_parameters_u55)
     def test_var_dim_u55_BI(self, test_tensor: torch.Tensor, dim, keepdim, correction):
         self._test_var_ethosu_BI_pipeline(
             self.VarDim(),
@@ -184,6 +196,18 @@ def test_var_dim_u55_BI(self, test_tensor: torch.Tensor, dim, keepdim, correctio
             (test_tensor, dim, keepdim, correction),
         )

+    # Expected to fail as this is not supported on u55.
+    @parameterized.expand(VarDim.test_parameters_u55_xfails)
+    @unittest.expectedFailure
+    def test_var_dim_u55_BI_xfails(
+        self, test_tensor: torch.Tensor, dim, keepdim, correction
+    ):
+        self._test_var_ethosu_BI_pipeline(
+            self.VarDim(),
+            common.get_u55_compile_spec(),
+            (test_tensor, dim, keepdim, correction),
+        )
+
     @parameterized.expand(VarDim.test_parameters)
     def test_var_dim_u85_BI(self, test_tensor: torch.Tensor, dim, keepdim, correction):
         self._test_var_ethosu_BI_pipeline(
@@ -208,8 +232,10 @@ def test_var_correction_tosa_BI(
             self.VarCorrection(), (test_tensor, dim, keepdim, correction)
         )

+    # Expected to fail as this is not supported on u55.
     @parameterized.expand(VarCorrection.test_parameters)
-    def test_var_correction_u55_BI(
+    @unittest.expectedFailure
+    def test_var_correction_u55_BI_xfails(
         self, test_tensor: torch.Tensor, dim, keepdim, correction
     ):
         self._test_var_ethosu_BI_pipeline(
diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py
index ddd5fd6b0b..a16d947dd6 100644
--- a/examples/arm/aot_arm_compiler.py
+++ b/examples/arm/aot_arm_compiler.py
@@ -263,7 +263,7 @@ def get_compile_spec(
             target,
             system_config="Ethos_U55_High_End_Embedded",
             memory_mode="Shared_Sram",
-            extra_flags="--debug-force-regor --output-format=raw",
+            extra_flags="--debug-force-regor --output-format=raw --verbose-operators --verbose-cycle-estimate",
         )
         .set_permute_memory_format(True)
         .set_quantize_io(True)
@@ -276,7 +276,7 @@
             target,
             system_config="Ethos_U85_SYS_DRAM_Mid",
             memory_mode="Shared_Sram",
-            extra_flags="--output-format=raw",
+            extra_flags="--output-format=raw --verbose-operators --verbose-cycle-estimate",
         )
         .set_permute_memory_format(True)
         .set_quantize_io(True)
diff --git a/examples/arm/run.sh b/examples/arm/run.sh
index 0e5fa9db34..cbc96c4b11 100755
--- a/examples/arm/run.sh
+++ b/examples/arm/run.sh
@@ -213,9 +213,9 @@ function build_executorch_runner() {
     cmake --build ${executor_runner_path}/cmake-out --parallel -- arm_executor_runner
     echo "[${FUNCNAME[0]}] Generated baremetal elf file:"
     find ${executor_runner_path}/cmake-out -name "arm_executor_runner"
-    echo "executable_text: $(find ${executor_runner_path}/cmake-out -name arm_executor_runner -exec size {} \; | grep -v filename | awk '{print $1}') bytes"
-    echo "executable_data: $(find ${executor_runner_path}/cmake-out -name arm_executor_runner -exec size {} \; | grep -v filename | awk '{print $2}') bytes"
-    echo "executable_bss: $(find ${executor_runner_path}/cmake-out -name arm_executor_runner -exec size {} \; | grep -v filename | awk '{print $3}') bytes"
+    echo "executable_text: $(find ${executor_runner_path}/cmake-out -name arm_executor_runner -exec arm-none-eabi-size {} \; | grep -v filename | awk '{print $1}') bytes"
+    echo "executable_data: $(find ${executor_runner_path}/cmake-out -name arm_executor_runner -exec arm-none-eabi-size {} \; | grep -v filename | awk '{print $2}') bytes"
+    echo "executable_bss: $(find ${executor_runner_path}/cmake-out -name arm_executor_runner -exec arm-none-eabi-size {} \; | grep -v filename | awk '{print $3}') bytes"
 }

 # Execute the executor_runner on FVP Simulator
diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh
index 583237729d..6f619ef058 100755
--- a/examples/arm/setup.sh
+++ b/examples/arm/setup.sh
@@ -55,9 +55,9 @@ if [[ "${ARCH}" == "x86_64" ]]; then
    corstone320_md5_checksum="3deb3c68f9b2d145833f15374203514d"

    # toochain
-    toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/12.3.rel1/binrel/arm-gnu-toolchain-12.3.rel1-x86_64-arm-none-eabi.tar.xz"
-    toolchain_dir="arm-gnu-toolchain-12.3.rel1-x86_64-arm-none-eabi"
-    toolchain_md5_checksum="00ebb1b70b1f88906c61206457eacb61"
+    toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/13.3.rel1/binrel/arm-gnu-toolchain-13.3.rel1-x86_64-arm-none-eabi.tar.xz"
+    toolchain_dir="arm-gnu-toolchain-13.3.rel1-x86_64-arm-none-eabi"
+    toolchain_md5_checksum="0601a9588bc5b9c99ad2b56133b7f118"
 elif [[ "${ARCH}" == "aarch64" ]] || [[ "${ARCH}" == "arm64" ]]; then
    # FVPs
    corstone300_url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_11.22_20_Linux64_armv8l.tgz?rev=9cc6e9a32bb947ca9b21fa162144cb01&hash=7657A4CF27D42E892E3F08D452AAB073"
@@ -70,13 +70,13 @@ elif [[ "${ARCH}" == "aarch64" ]] || [[ "${ARCH}" == "arm64" ]]; then

    # toochain
    if [[ "${OS}" == "Darwin" ]]; then
-        toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/12.3.rel1/binrel/arm-gnu-toolchain-12.3.rel1-darwin-arm64-arm-none-eabi.tar.xz"
-        toolchain_dir="arm-gnu-toolchain-12.3.rel1-darwin-arm64-arm-none-eabi"
-        toolchain_md5_checksum="53d034e9423e7f470acc5ed2a066758e"
+        toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/13.3.rel1/binrel/arm-gnu-toolchain-13.3.rel1-darwin-arm64-arm-none-eabi.tar.xz"
+        toolchain_dir="arm-gnu-toolchain-13.3.rel1-darwin-arm64-arm-none-eabi"
+        toolchain_md5_checksum="f1c18320bb3121fa89dca11399273f4e"
    elif [[ "${OS}" == "Linux" ]]; then
-        toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/12.3.rel1/binrel/arm-gnu-toolchain-12.3.rel1-aarch64-arm-none-eabi.tar.xz"
-        toolchain_dir="arm-gnu-toolchain-12.3.rel1-aarch64-arm-none-eabi"
-        toolchain_md5_checksum="02c9b0d3bb1110575877d8eee1f223f2"
+        toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/13.3.rel1/binrel/arm-gnu-toolchain-13.3.rel1-aarch64-arm-none-eabi.tar.xz"
+        toolchain_dir="arm-gnu-toolchain-13.3.rel1-aarch64-arm-none-eabi"
+        toolchain_md5_checksum="303102d97b877ebbeb36b3158994b218"
    fi
 else
    echo "[main] Error: only x86-64 & aarch64/arm64 architecture is supported for now!"; exit 1;
@@ -89,7 +89,11 @@
 ethos_u_base_rev="24.08"

 # tosa reference model
 tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model"
 tosa_reference_model_rev="f9ea4ab7da19318fe36b1c34d68a3e40fd6e56c5"
-
+
+# vela
+vela_repo_url="https://review.mlplatform.org/ml/ethos-u/ethos-u-vela"
+vela_rev="a08fc18780827b5fefc814dd0162ee6317ce0ae7"
+
 ########
 ### Mandatory user args
 ########
@@ -174,15 +178,15 @@ function setup_fvp() {
 function setup_toolchain() {
    # Download and install the arm-none-eabi toolchain
    cd "${root_dir}"
-    if [[ ! -e gcc.tar.xz ]]; then
+    if [[ ! -e "${toolchain_dir}.tar.xz" ]]; then
        echo "[${FUNCNAME[0]}] Downloading toolchain ..."
-        curl --output gcc.tar.xz "${toolchain_url}"
-        verify_md5 ${toolchain_md5_checksum} gcc.tar.xz
+        curl --output "${toolchain_dir}.tar.xz" "${toolchain_url}"
+        verify_md5 ${toolchain_md5_checksum} "${toolchain_dir}.tar.xz"
    fi

    echo "[${FUNCNAME[0]}] Installing toolchain ..."
    rm -rf "${toolchain_dir}"
-    tar xf gcc.tar.xz
+    tar xf "${toolchain_dir}.tar.xz"
    toolchain_bin_path="$(cd ${toolchain_dir}/bin && pwd)"
    export PATH=${PATH}:${toolchain_bin_path}
    hash arm-none-eabi-gcc
@@ -198,6 +202,7 @@ function setup_ethos_u() {
    cd ethos-u
    git reset --hard ${ethos_u_base_rev}
    python3 ./fetch_externals.py -c ${ethos_u_base_rev}.json fetch
+    pip install pyelftools
    echo "[${FUNCNAME[0]}] Done @ $(git describe --all --long 3> /dev/null) in ${root_dir}/ethos-u dir."
 }

@@ -259,9 +264,9 @@ function setup_vela() {
    #
    cd "${root_dir}"
    if [[ ! -e ethos-u-vela ]]; then
-        git clone https://review.mlplatform.org/ml/ethos-u/ethos-u-vela
+        git clone ${vela_repo_url}
        repo_dir="${root_dir}/ethos-u-vela"
-        base_rev=57ce18c89ccc6f6309333dccb24ed30dc68b571f
+        base_rev=${vela_rev}
        patch_repo
    fi
    cd "${root_dir}/ethos-u-vela"
diff --git a/extension/pytree/aten_util/targets.bzl b/extension/pytree/aten_util/targets.bzl
index 5ba7e90596..e179308020 100644
--- a/extension/pytree/aten_util/targets.bzl
+++ b/extension/pytree/aten_util/targets.bzl
@@ -20,7 +20,13 @@ def define_common_targets():
             "//executorch/runtime/platform:platform",
         ],
         compiler_flags = ["-Wno-missing-prototypes"],
-        external_deps = [
-            "torch-core-cpp",
+        fbcode_deps = [
+            "//caffe2:ATen-core",
+            "//caffe2:ATen-cpu",
+            "//caffe2/c10:c10",
+        ],
+        xplat_deps = [
+            "//xplat/caffe2:torch_mobile_core",
+            "//xplat/caffe2/c10:c10",
         ],
     )
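For reference, the dtype gating that the new ToCopySupported check applies in to_copy_support.py can be exercised in isolation. The sketch below mirrors the table lookup with plain dicts, assuming the integer-only (BI profile) table; is_cast_supported is an illustrative helper name, not an API added by this patch. Unlike the in-tree check, which updates the selected class-level table in place, the sketch copies the table before applying POSSIBLE_TYPE_CONVERSIONS.

import torch

# Mirrors SUPPORTED_INT_TYPES / POSSIBLE_TYPE_CONVERSIONS in to_copy_support.py.
SUPPORTED_INT_TYPES = {
    torch.bool: [torch.int8, torch.int16, torch.int32],
    torch.int8: [torch.bool, torch.int16, torch.int32],
    torch.int16: [torch.bool, torch.int8, torch.int32],
    torch.int32: [torch.bool, torch.int8, torch.int16],
}
POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}


def is_cast_supported(input_dtype, output_dtype, table=SUPPORTED_INT_TYPES):
    # int64 is not directly supported, but it can be rewritten to int32,
    # so it inherits int32's set of legal output dtypes.
    supported = dict(table)  # copy so the shared table is not mutated
    supported.update(
        (k, supported[v])
        for k, v in POSSIBLE_TYPE_CONVERSIONS.items()
        if v in supported
    )
    return input_dtype in supported and output_dtype in supported[input_dtype]


assert is_cast_supported(torch.int8, torch.int32)
assert is_cast_supported(torch.int64, torch.bool)      # via the int64 -> int32 rewrite
assert not is_cast_supported(torch.int8, torch.int64)  # int64 outputs stay unsupported

With a float-capable TOSA spec (MI profile), the check instead consults the merged int and float tables, which is why the partitioner accepts the float16/float32 cases in test_to_copy.py but rejects anything involving an unsupported dtype such as int64 outputs.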