From 717584cc8244e90ec6d911e996251b955512d525 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Fri, 13 Dec 2024 11:35:21 +0000 Subject: [PATCH] 2024-12-13 nightly release (3f7eb3b439cac329aa7d7362dc5317f927791058) --- .ci/docker/ubuntu/Dockerfile | 3 - .ci/scripts/build_llama_android.sh | 3 +- .ci/scripts/utils.sh | 9 +- .github/workflows/docker-builds.yml | 4 - .../annotate_channels_last_dim_order_pass.py | 154 +++++++++++++----- backends/arm/test/ops/test_layer_norm.py | 8 +- backends/arm/test/ops/test_to_copy.py | 1 - backends/arm/test/ops/test_view.py | 1 + .../cadence/fusion_g3/operators/op_add.cpp | 124 +++++++++++--- .../cadence/fusion_g3/operators/op_cat.cpp | 17 +- .../fusion_g3/operators/op_dequantize.cpp | 24 +-- .../cadence/fusion_g3/operators/op_mul.cpp | 17 +- .../operators/op_native_layer_norm.cpp | 20 ++- .../fusion_g3/operators/op_quantize.cpp | 43 ++--- .../fusion_g3/operators/op_softmax.cpp | 18 +- .../cadence/fusion_g3/operators/operators.h | 2 + .../cadence/fusion_g3/operators/targets.bzl | 65 +++++--- .../fusion_g3/operators/tests/test_op_add.cpp | 35 +++- backends/cadence/runtime/TARGETS | 3 + backends/cadence/runtime/et_pal.cpp | 90 ++++++++++ backends/cadence/runtime/targets.bzl | 15 ++ .../_passes/int4_weight_only_quantizer.py | 6 + backends/vulkan/docs/android_demo.md | 1 + devtools/bundled_program/core.py | 41 ++++- .../bundled_program/test/test_bundle_data.py | 44 +++++ devtools/inspector/_inspector_utils.py | 18 +- .../inspector/tests/inspector_utils_test.py | 29 ++++ docs/source/getting-started-setup.md | 7 +- .../runtime-build-and-cross-compilation.md | 6 +- examples/arm/setup.sh | 2 +- .../android/ExecuTorchDemo/README.md | 7 +- .../docs/delegates/xnnpack_README.md | 6 +- .../LLaMA/docs/delegates/xnnpack_README.md | 6 +- examples/devtools/build_example_runner.sh | 2 +- examples/devtools/test_example_runner.sh | 2 +- examples/models/llama/README.md | 3 + examples/models/llama/llama_transformer.py | 5 +- examples/models/llama/model.py | 9 + examples/models/llama/rope.py | 14 +- .../xnnpack/quantization/test_quantize.sh | 5 +- install_requirements.py | 2 +- kernels/test/targets.bzl | 25 ++- test/build_size_test.sh | 2 +- 43 files changed, 689 insertions(+), 209 deletions(-) create mode 100644 backends/cadence/runtime/et_pal.cpp create mode 100644 backends/cadence/runtime/targets.bzl diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index f5b1824a78..5efa0825b9 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -79,9 +79,6 @@ RUN if [ -n "${ANDROID_NDK_VERSION}" ]; then bash ./install_android.sh; fi RUN rm install_android.sh ARG ARM_SDK -COPY --chown=ci-user:ci-user ./arm /opt/arm -# Set up ARM SDK if needed -RUN if [ -n "${ARM_SDK}" ]; then git config --global user.email "ossci@example.com"; git config --global user.name "OSS CI"; bash /opt/arm/setup.sh --i-agree-to-the-contained-eula /opt/arm-sdk; chown -R ci-user:ci-user /opt/arm-sdk; fi ARG QNN_SDK diff --git a/.ci/scripts/build_llama_android.sh b/.ci/scripts/build_llama_android.sh index a08fd5499f..6b8f851d77 100644 --- a/.ci/scripts/build_llama_android.sh +++ b/.ci/scripts/build_llama_android.sh @@ -12,7 +12,8 @@ source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" install_executorch_and_backend_lib() { echo "Installing executorch and xnnpack backend" - rm -rf cmake-android-out && mkdir cmake-android-out + clean_executorch_install_folders + mkdir cmake-android-out ANDROID_NDK=/opt/ndk BUCK2=buck2 ANDROID_ABI=arm64-v8a diff --git a/.ci/scripts/utils.sh 
b/.ci/scripts/utils.sh index 4c8cd086e4..be927a3a3a 100644 --- a/.ci/scripts/utils.sh +++ b/.ci/scripts/utils.sh @@ -16,6 +16,10 @@ retry () { "$@" || (sleep 30 && reset_buck && "$@") || (sleep 60 && reset_buck && "$@") } +clean_executorch_install_folders() { + ./install_requirements.sh --clean +} + install_executorch() { which pip # Install executorch, this assumes that Executorch is checked out in the @@ -74,7 +78,8 @@ build_executorch_runner_buck2() { build_executorch_runner_cmake() { CMAKE_OUTPUT_DIR=cmake-out # Build executorch runtime using cmake - rm -rf "${CMAKE_OUTPUT_DIR}" && mkdir "${CMAKE_OUTPUT_DIR}" + clean_executorch_install_folders + mkdir "${CMAKE_OUTPUT_DIR}" pushd "${CMAKE_OUTPUT_DIR}" || return # This command uses buck2 to gather source files and buck2 could crash flakily @@ -103,7 +108,7 @@ build_executorch_runner() { cmake_install_executorch_lib() { echo "Installing libexecutorch.a and libportable_kernels.a" - rm -rf cmake-out + clean_executorch_install_folders retry cmake -DBUCK2="$BUCK" \ -DCMAKE_INSTALL_PREFIX=cmake-out \ -DCMAKE_BUILD_TYPE=Release \ diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index ea25a0a7d3..73af9842a2 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -7,8 +7,6 @@ on: - .ci/docker/** - .github/workflows/docker-builds.yml - requirements-lintrunner.txt - - examples/arm/setup.sh - - examples/arm/ethos-u-setup/** push: branches: - main @@ -17,8 +15,6 @@ on: - .ci/docker/** - .github/workflows/docker-builds.yml - requirements-lintrunner.txt - - examples/arm/setup.sh - - examples/arm/ethos-u-setup/** schedule: - cron: 1 3 * * 3 diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 786117e645..80c5f3c442 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -12,6 +12,7 @@ from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, + get_node_arg, insert_q_dq_pair, ) from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op @@ -83,14 +84,48 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node): return False - def insert_input_transpose(self, node, input_node, graph_module): + @staticmethod + def memory_format_differs(shape): + """Returns true if the shape will have a different memory layout in NCHW and NHWC format""" + if len(shape) >= 4: + C = shape[1] + H = shape[2] + W = shape[3] + elif len(shape) == 3: + C = shape[0] + H = shape[1] + W = shape[2] + if len(shape) <= 2: + return False + + return C > 1 and (H > 1 or W > 1) + + @staticmethod + def is_channel_reshape(input_shape, output_shape): + """Returns true if the reshape changes the channel dimension""" + if not len(input_shape) == len(output_shape) == 4: + return False + + C_old = input_shape[1] + C_new = output_shape[1] + + N_new = output_shape[0] + N_old = input_shape[0] + + return (N_old != N_new) or (C_old != C_new) + + @staticmethod + def insert_input_transpose(node, input_node, graph_module): quantize = input_node.target == dq_op q_params = input_node.args[1:] if quantize else None with graph_module.graph.inserting_before(node): permute_node = create_node( graph_module.graph, torch.ops.passthrough_to_tosa._transpose, - args=(input_node, list(self.NHWC_inverse_order)), + args=( + input_node, + list(AnnotateChannelsLastDimOrder.NHWC_inverse_order), 
+ ), quantize=quantize, q_params=q_params, ) @@ -100,14 +135,17 @@ def insert_input_transpose(self, node, input_node, graph_module): range(len(input_node.meta["val"].size())) ) - def insert_output_transpose(self, node, graph_module): + @staticmethod + def insert_output_transpose(node, graph_module): with graph_module.graph.inserting_after(node): permute_node = create_node( graph_module.graph, torch.ops.passthrough_to_tosa._transpose, - args=(node, list(self.NHWC_order)), + args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)), + ) + permute_node.meta["tosa_dim_order"] = ( + AnnotateChannelsLastDimOrder.NHWC_order ) - permute_node.meta["tosa_dim_order"] = self.NHWC_order node.meta["tosa_dim_order"] = (0, 1, 2, 3) users = [user for user in node.users if user != permute_node] for user in users: @@ -118,54 +156,96 @@ def insert_output_transpose(self, node, graph_module): q_params = node.args[0].args[1:] insert_q_dq_pair(graph_module.graph, node, q_params) + @staticmethod + def _insert_squeeze_transpose( + input_shape, output_shape, node, input_node, graph_module + ): + nhwc_to_nhwc = len(input_shape) == 4 and len(output_shape) <= 3 + + if nhwc_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs( + input_shape + ): + AnnotateChannelsLastDimOrder.insert_input_transpose( + node, input_node, graph_module + ) + + @staticmethod + def _insert_unsqueeze_transpose(input_shape, output_shape, node, graph_module): + nchw_to_nhwc = len(input_shape) == 3 and len(output_shape) == 4 + if nchw_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs( + output_shape + ): + AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module) + + @staticmethod + def _insert_view_transpose( + input_shape, output_shape, node, input_node, graph_module + ): + nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) == 4 + nhwc_to_nchw = len(input_shape) == 4 and len(output_shape) < 4 + channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape( + output_shape, input_shape + ) + + if ( + channel_reshape or nhwc_to_nchw + ) and AnnotateChannelsLastDimOrder.memory_format_differs(input_shape): + AnnotateChannelsLastDimOrder.insert_input_transpose( + node, input_node, graph_module + ) + if ( + channel_reshape or nchw_to_nhwc + ) and AnnotateChannelsLastDimOrder.memory_format_differs(output_shape): + AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module) + def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): """ - Reshape operations are not equivalent in NCHW and NHWC. - To get around this, transposes need to be added if the previous or new shape - fulfil the following condition: - C > 1 and (H or W > 1) - - This is relevant for the following operations; - squeeze: 4D -> 3D - unsqueeze: <4D -> 4D - view: <4D -> 4D - view: 4D -> <4D - view: 4D -> 4D - """ - - def transpose_condition(shape): - if len(shape) != 4: - return False - C = shape[1] - H = shape[2] - W = shape[3] - return C > 1 and (H > 1 or W > 1) + Transposes are needed for operators transforming the input to a different rank, as 4D-tensors are assumed to be in NHWC-format, whereas all other are in NCHW format. + This is relevant for the following cases: + - squeeze: 4D -> <4D + - unsqueeze: 3D -> 4D + - view: <4D -> 4D + - view: 4D -> <4D + Additionally, a 4D->4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leadning to one extra input and output transpose for this case. 
+ Transposes can be avoided for shapes where there is no difference in actual memory, e.g for + - H == W == 1 + - C == 1 + - 1D/2D tensors + """ for node in graph_module.graph.nodes: if node.op != "call_function": continue + if node.target == exir_ops.edge.aten.squeeze_copy.dims: input_node = node.args[0] input_shape = input_node.meta["val"].shape - if transpose_condition(input_shape): - self.insert_input_transpose(node, input_node, graph_module) + output_shape = node.meta["val"].shape + + self._insert_squeeze_transpose( + input_shape, output_shape, node, input_node, graph_module + ) elif node.target == exir_ops.edge.aten.unsqueeze_copy.default: + input_node = get_node_arg(node.args, 0, default_value=False) + if input_node: + input_shape = input_node.meta["val"].shape + else: + input_shape = () output_shape = node.meta["val"].shape - if transpose_condition(output_shape): - self.insert_output_transpose(node, graph_module) + + self._insert_unsqueeze_transpose( + input_shape, output_shape, node, graph_module + ) elif node.target == exir_ops.edge.aten.view_copy.default: input_node = node.args[0] + input_shape = input_node.meta["val"].shape + output_shape = node.meta["val"].shape - old_shape = input_node.meta["val"].shape - new_shape = node.meta["val"].shape - - if transpose_condition(old_shape): - self.insert_input_transpose(node, input_node, graph_module) - - if transpose_condition(new_shape): - self.insert_output_transpose(node, graph_module) + self._insert_view_transpose( + input_shape, output_shape, node, input_node, graph_module + ) def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: diff --git a/backends/arm/test/ops/test_layer_norm.py b/backends/arm/test/ops/test_layer_norm.py index a4d3bc5adf..502764fe05 100644 --- a/backends/arm/test/ops/test_layer_norm.py +++ b/backends/arm/test/ops/test_layer_norm.py @@ -157,9 +157,9 @@ def test_layer_norm_tosa_BI( # Numerical issues on FVP likely due to mul op, MLETORCH-521 # Skip tests that require transposes. 
- @parameterized.expand(test_data_suite[:-2]) + @parameterized.expand(test_data_suite) @unittest.expectedFailure - def test_layer_norm_u55_BI( + def test_layer_norm_u55_BI_xfails( self, test_name: str, test_data: torch.Tensor, @@ -171,7 +171,8 @@ def test_layer_norm_u55_BI( # Numerical issues on FVP likely due to mul op, MLETORCH-521 @parameterized.expand(test_data_suite[:-2]) - def test_layer_norm_u85_BI_fvp( + @unittest.expectedFailure + def test_layer_norm_u85_BI_xfails( self, test_name: str, test_data: torch.Tensor, @@ -182,7 +183,6 @@ def test_layer_norm_u85_BI_fvp( ) @parameterized.expand(test_data_suite[-2:]) - @unittest.skip # Flaky def test_layer_norm_u85_BI( self, test_name: str, diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py index 8499512e10..a9d827f903 100644 --- a/backends/arm/test/ops/test_to_copy.py +++ b/backends/arm/test/ops/test_to_copy.py @@ -56,7 +56,6 @@ def _test_to_copy_tosa_MI_pipeline( ) .export() .dump_artifact() - .check_count({"torch.ops.aten._to_copy.default": 1}) .to_edge() .dump_artifact() .partition() diff --git a/backends/arm/test/ops/test_view.py b/backends/arm/test/ops/test_view.py index 09a8f57bd3..07a32fe595 100644 --- a/backends/arm/test/ops/test_view.py +++ b/backends/arm/test/ops/test_view.py @@ -43,6 +43,7 @@ class View(torch.nn.Module): (torch.rand(1, 1, 5, 10), (1, 1, 50, 1)), (torch.rand(5, 10, 1, 1), (1, 25, 2)), (torch.rand(2, 50, 1, 1), (1, 100)), + (torch.rand(2, 3, 2, 3), (2, 3, 3, 2)), ] def forward(self, x: torch.Tensor, new_shape): diff --git a/backends/cadence/fusion_g3/operators/op_add.cpp b/backends/cadence/fusion_g3/operators/op_add.cpp index 0a7c7e7e03..a68cef54b4 100644 --- a/backends/cadence/fusion_g3/operators/op_add.cpp +++ b/backends/cadence/fusion_g3/operators/op_add.cpp @@ -6,25 +6,37 @@ * LICENSE file in the root directory of this source tree. */ +#include + +#include + #include #include #include #include #include -#include -using exec_aten::Scalar; -using exec_aten::ScalarType; -using exec_aten::Tensor; -using executorch::runtime::canCast; -using torch::executor::Error; -using torch::executor::KernelRuntimeContext; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::canCast; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; namespace cadence { namespace impl { namespace G3 { namespace native { +#define XT_KERNEL_CHECK(ctx, out, kernel, ...) 
\ + const auto ret = kernel(__VA_ARGS__); \ + ET_KERNEL_CHECK_MSG( \ + ctx, \ + ret == 0, \ + InvalidArgument, \ + out, \ + "Failed to run kernel: " #kernel "(" #__VA_ARGS__ ")"); + Tensor& add_out( KernelRuntimeContext& ctx, const Tensor& a, @@ -121,13 +133,30 @@ Tensor& add_out( torch::executor::native::utils::extract_scalar(alpha, &alpha_val); if ((a.numel() == 1) && (alpha_val == 1)) { - xa_nn_elm_add_scalar_32x32_32( - out_data, inp2_data, inp1_data[0], alpha_val, out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_add_scalar_32x32_32, + out_data, + inp2_data, + inp1_data[0], + alpha_val, + out.numel()); } else if (b.numel() == 1) { - xa_nn_elm_add_scalar_32x32_32( - out_data, inp1_data, inp2_data[0], alpha_val, out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_add_scalar_32x32_32, + out_data, + inp1_data, + inp2_data[0], + alpha_val, + out.numel()); } else if (broadcast) { - xa_nn_elm_add_broadcast_5D_32x32_32( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_add_broadcast_5D_32x32_32, out_data, out_shape, inp1_data, @@ -137,8 +166,15 @@ Tensor& add_out( max_dim, alpha_val); } else { - xa_nn_elm_add_32x32_32( - out_data, inp1_data, inp2_data, alpha_val, out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_add_32x32_32, + out_data, + inp1_data, + inp2_data, + alpha_val, + out.numel()); } } else if ((compute_type == ScalarType::Float) && (optimized)) { const float* const inp1_data = a.const_data_ptr(); @@ -149,13 +185,30 @@ Tensor& add_out( torch::executor::native::utils::extract_scalar(alpha, &alpha_val); if ((a.numel() == 1) && (alpha_val == 1.0)) { - xa_nn_elm_add_scalar_f32xf32_f32( - out_data, inp2_data, inp1_data[0], alpha_val, out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_add_scalar_f32xf32_f32, + out_data, + inp2_data, + inp1_data[0], + alpha_val, + out.numel()); } else if (b.numel() == 1) { - xa_nn_elm_add_scalar_f32xf32_f32( - out_data, inp1_data, inp2_data[0], alpha_val, out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_add_scalar_f32xf32_f32, + out_data, + inp1_data, + inp2_data[0], + alpha_val, + out.numel()); } else if (broadcast) { - xa_nn_elm_add_broadcast_5D_f32xf32_f32( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_add_broadcast_5D_f32xf32_f32, out_data, out_shape, inp1_data, @@ -165,8 +218,15 @@ Tensor& add_out( max_dim, alpha_val); } else { - xa_nn_elm_add_f32xf32_f32( - out_data, inp1_data, inp2_data, alpha_val, out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_add_f32xf32_f32, + out_data, + inp1_data, + inp2_data, + alpha_val, + out.numel()); } } else { ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { @@ -242,8 +302,15 @@ Tensor& add_scalar_out( int* const out_data = out.mutable_data_ptr(); - xa_nn_elm_add_scalar_32x32_32( - out_data, inp1_data, inp2_val, alpha_val, out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_add_scalar_32x32_32, + out_data, + inp1_data, + inp2_val, + alpha_val, + out.numel()); } else if (compute_type == ScalarType::Float) { const float* const inp1_data = a.const_data_ptr(); @@ -255,8 +322,15 @@ Tensor& add_scalar_out( float* const out_data = out.mutable_data_ptr(); - xa_nn_elm_add_scalar_f32xf32_f32( - out_data, inp1_data, inp2_val, alpha_val, out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_add_scalar_f32xf32_f32, + out_data, + inp1_data, + inp2_val, + alpha_val, + out.numel()); } else { ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { diff --git a/backends/cadence/fusion_g3/operators/op_cat.cpp 
b/backends/cadence/fusion_g3/operators/op_cat.cpp index 7fae3fa29c..f0f327c024 100644 --- a/backends/cadence/fusion_g3/operators/op_cat.cpp +++ b/backends/cadence/fusion_g3/operators/op_cat.cpp @@ -6,16 +6,17 @@ * LICENSE file in the root directory of this source tree. */ +#include + +#include + #include #include -#include -#include -using exec_aten::Scalar; -using exec_aten::ScalarType; -using exec_aten::Tensor; -using torch::executor::Error; -using torch::executor::KernelRuntimeContext; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; /* ScalarType in Executorch do not have support for below data types. * So, creating a placeholder for these data types. Once, ScalarTypes is @@ -194,4 +195,4 @@ Tensor& cat_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence \ No newline at end of file +} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_dequantize.cpp b/backends/cadence/fusion_g3/operators/op_dequantize.cpp index cd5d4a753e..ed5b3125ac 100644 --- a/backends/cadence/fusion_g3/operators/op_dequantize.cpp +++ b/backends/cadence/fusion_g3/operators/op_dequantize.cpp @@ -6,18 +6,20 @@ * LICENSE file in the root directory of this source tree. */ -#include -#include -#include #include #include #include -using exec_aten::Scalar; -using exec_aten::ScalarType; -using exec_aten::Tensor; -using torch::executor::Error; -using torch::executor::KernelRuntimeContext; +#include + +#include +#include + +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; template using optional = exec_aten::optional; @@ -185,7 +187,7 @@ void dequantize_impl( if (axis == NULL) { // calculate the dequantized output, cast scale to float to match fbgemm // behavior -#define ASYM_DEQUANTIZE_IMPL_TESNOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ +#define ASYM_DEQUANTIZE_IMPL_TENSOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ case ScalarType::out_dtype: { \ /* Hoist these function calls out of our inner loop because they might not \ * get inlined without LTO, particularly in ATen mode. */ \ @@ -201,7 +203,7 @@ void dequantize_impl( #define ASYM_CALCULATE_INT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ case ScalarType::in_dtype: \ switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_TESNOR); \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_TENSOR); \ default: \ ET_CHECK_MSG( \ false, \ @@ -219,7 +221,7 @@ void dequantize_impl( static_cast(input.scalar_type())); } #undef ASYM_CALCULATE_INT_TYPE_TENSOR -#undef ASYM_DEQUANTIZE_IMPL_TESNOR +#undef ASYM_DEQUANTIZE_IMPL_TENSOR } else { // a list contains all dimensions except axis int64_t dims[input.dim() - 1]; diff --git a/backends/cadence/fusion_g3/operators/op_mul.cpp b/backends/cadence/fusion_g3/operators/op_mul.cpp index 914ecf9d7e..840cb16c7c 100644 --- a/backends/cadence/fusion_g3/operators/op_mul.cpp +++ b/backends/cadence/fusion_g3/operators/op_mul.cpp @@ -6,18 +6,19 @@ * LICENSE file in the root directory of this source tree. 
*/ +#include + #include #include #include #include -#include -using exec_aten::Scalar; -using exec_aten::ScalarType; -using exec_aten::Tensor; -using executorch::runtime::canCast; -using torch::executor::Error; -using torch::executor::KernelRuntimeContext; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::canCast; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; namespace cadence { namespace impl { @@ -238,4 +239,4 @@ Tensor& mul_scalar_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence \ No newline at end of file +} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp index 68d111795c..a5fbe31eee 100644 --- a/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp +++ b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp @@ -6,18 +6,20 @@ * LICENSE file in the root directory of this source tree. */ +#include +#include + +#include + #include #include #include -#include -#include -#include -using Tensor = exec_aten::Tensor; -using ScalarType = exec_aten::ScalarType; -using IntArrayRef = exec_aten::ArrayRef; -using torch::executor::Error; -using torch::executor::KernelRuntimeContext; +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; namespace cadence { namespace impl { @@ -255,4 +257,4 @@ std::tuple native_layer_norm_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence \ No newline at end of file +} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_quantize.cpp b/backends/cadence/fusion_g3/operators/op_quantize.cpp index 399e1be25a..fc206b67cd 100644 --- a/backends/cadence/fusion_g3/operators/op_quantize.cpp +++ b/backends/cadence/fusion_g3/operators/op_quantize.cpp @@ -6,18 +6,21 @@ * LICENSE file in the root directory of this source tree. */ -#include -#include -#include #include #include #include -using exec_aten::Scalar; -using exec_aten::ScalarType; -using exec_aten::Tensor; -using torch::executor::Error; -using torch::executor::KernelRuntimeContext; +#include + +#include +#include + +using ::executorch::aten::ArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; /* ScalarType in Executorch do not have support for below data types. * So, creating a placeholder for these data types. 
Once, ScalarTypes is @@ -142,7 +145,7 @@ void quantize_impl( int* axis, int quant_min, int quant_max) { - const exec_aten::ArrayRef input_size = input.sizes(); + const ArrayRef input_size = input.sizes(); int kTensorDimensionLimit = 5; @@ -301,8 +304,8 @@ void quantize_impl( } } - exec_aten::optional> optional_dim_list{ - exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + optional> optional_dim_list{ + ArrayRef{dims, size_t(input.dim() - 1)}}; // Actual quantization logic // input, out are the input and output tensors @@ -487,8 +490,8 @@ void quantize_impl( } } - exec_aten::optional> optional_dim_list{ - exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + optional> optional_dim_list{ + ArrayRef{dims, size_t(input.dim() - 1)}}; // Actual quantization logic // input, out are the input and output tensors @@ -565,9 +568,9 @@ Tensor& quantize_per_tensor_out( int64_t quant_max, ScalarType dtype, Tensor& out) { - torch::executor::Error err = resize_tensor(out, input.sizes()); + Error err = resize_tensor(out, input.sizes()); ET_CHECK_MSG( - err == torch::executor::Error::Ok, + err == Error::Ok, "Failed to resize out Tensor in quantize_per_tensor_out"); // check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); @@ -600,7 +603,7 @@ Tensor& quantize_per_tensor_tensor_args_out( // after ET_KERNEL_CHECK is fully implemented and properly allows non fatal // failures. if (scale.scalar_type() != ScalarType::Double) { - context.fail(torch::executor::Error::InvalidArgument); + context.fail(Error::InvalidArgument); return out; } ET_CHECK_MSG( @@ -657,7 +660,7 @@ Tensor& quantize_per_channel_out( int64_t quant_max, ScalarType dtype, Tensor& out) { - torch::executor::Error err = resize_tensor(out, input.sizes()); + Error err = resize_tensor(out, input.sizes()); // normalize axis ET_CHECK_MSG( @@ -671,7 +674,7 @@ Tensor& quantize_per_channel_out( } ET_CHECK_MSG( - err == torch::executor::Error::Ok, + err == Error::Ok, "Failed to resize out Tensor in quantize_per_channel_out"); ET_CHECK_MSG( @@ -776,9 +779,9 @@ Tensor& quantize_per_token_out( input_strides.data(), executorch::runtime::TensorShapeDynamism::STATIC); Tensor reshaped_input(&reshaped_input_impl); - torch::executor::Error err = resize_tensor(out, input.sizes()); + Error err = resize_tensor(out, input.sizes()); ET_CHECK_MSG( - err == torch::executor::Error::Ok, + err == Error::Ok, "Failed to resize out Tensor in quantize_per_channel_out"); #endif diff --git a/backends/cadence/fusion_g3/operators/op_softmax.cpp b/backends/cadence/fusion_g3/operators/op_softmax.cpp index 9b51bdd22a..9f34348150 100644 --- a/backends/cadence/fusion_g3/operators/op_softmax.cpp +++ b/backends/cadence/fusion_g3/operators/op_softmax.cpp @@ -6,18 +6,20 @@ * LICENSE file in the root directory of this source tree. */ +#include + +#include + #include #include #include #include -#include -#include -using exec_aten::Scalar; -using exec_aten::ScalarType; -using exec_aten::Tensor; -using torch::executor::Error; -using torch::executor::KernelRuntimeContext; +using ::executorch::aten::ArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; namespace cadence { namespace impl { @@ -51,7 +53,7 @@ Tensor& _softmax_out( dim = dim < 0 ? 
dim + executorch::runtime::nonzero_dim(in) : dim; int inp_shapes[in.dim()]; - const exec_aten::ArrayRef in_size = in.sizes(); + const ArrayRef in_size = in.sizes(); for (int i = 0; i < in.dim(); i++) { inp_shapes[i] = in_size[i]; } diff --git a/backends/cadence/fusion_g3/operators/operators.h b/backends/cadence/fusion_g3/operators/operators.h index fc4d5ff625..9d7f7b9c30 100644 --- a/backends/cadence/fusion_g3/operators/operators.h +++ b/backends/cadence/fusion_g3/operators/operators.h @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. */ +#pragma once + #include #include diff --git a/backends/cadence/fusion_g3/operators/targets.bzl b/backends/cadence/fusion_g3/operators/targets.bzl index 9318b36937..47d035d420 100644 --- a/backends/cadence/fusion_g3/operators/targets.bzl +++ b/backends/cadence/fusion_g3/operators/targets.bzl @@ -1,6 +1,45 @@ load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +def define_operator(name: str, deps: list[str] | None = None) -> None: + op_name = "op_{}".format(name) + + # Deps used by all operators. + common_deps = [ + "//executorch/kernels/portable/cpu/util:all_deps", + "//executorch/kernels/portable/cpu/pattern:all_deps", + "//executorch/runtime/kernel:kernel_includes", + "//executorch/kernels/portable/cpu:scalar_utils", + "fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib_common", + "fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib", + ] + if deps == None: + deps = [] + + runtime.cxx_library( + name = op_name, + srcs = [op_name + ".cpp"], + platforms = CXX, + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + deps = deps + common_deps, + exported_deps = [ + ":operators_header", + ], + ) + +OPERATORS = [ + "add", + "cat", + "dequantize", + "mul", + "native_layer_norm", + "quantize", + "softmax", +] + def define_common_targets(): """Defines targets that should be shared between fbcode and xplat. @@ -11,28 +50,16 @@ def define_common_targets(): # Define build targets for all operators registered in the tables above. 
runtime.cxx_library( - name = "cadence_g3_ops", - srcs = glob([ - "*.cpp", - ]), - exported_headers = glob([ - "*.h", - ]), - platforms = CXX, - deps = [ - "//executorch/kernels/portable/cpu/util:all_deps", - "//executorch/kernels/portable/cpu/pattern:all_deps", - "//executorch/runtime/kernel:kernel_includes", - "//executorch/kernels/portable/cpu:scalar_utils", - "fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib_common", - "fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib", - ], + name = "operators_header", + exported_headers = ["operators.h"], visibility = [ "//executorch/backends/cadence/...", - "@EXECUTORCH_CLIENTS", ], exported_deps = [ - "fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib_common", - "fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", ], ) + + for op in OPERATORS: + define_operator(op) diff --git a/backends/cadence/fusion_g3/operators/tests/test_op_add.cpp b/backends/cadence/fusion_g3/operators/tests/test_op_add.cpp index 06bf4bf4ec..bba778035b 100644 --- a/backends/cadence/fusion_g3/operators/tests/test_op_add.cpp +++ b/backends/cadence/fusion_g3/operators/tests/test_op_add.cpp @@ -8,8 +8,12 @@ #include #include +#include +#include #include +#include +#include #include #include #include @@ -24,24 +28,19 @@ namespace { using ::executorch::aten::Scalar; using ::executorch::aten::ScalarType; using ::executorch::aten::Tensor; +using ::executorch::aten::TensorImpl; +using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; using ::executorch::runtime::runtime_init; using ::executorch::runtime::testing::TensorFactory; -using ::testing::Test; -class FusionG3OperatorTest : public Test { +class FusionG3OperatorTest : public OperatorTest { public: - void SetUp() override { - runtime_init(); - } - protected: Tensor& add_out(const Tensor& a, const Tensor& b, const Scalar& alpha, Tensor& out) { return cadence::impl::G3::native::add_out(context_, a, b, alpha, out); } - - KernelRuntimeContext context_; }; TEST_F(FusionG3OperatorTest, TwoDimFloatTensorAddTest) { @@ -77,6 +76,26 @@ TEST_F(FusionG3OperatorTest, AddWithBroadcastTest) { EXPECT_TENSOR_EQ(out, tf.full(size_a, 2)); } +TEST_F(FusionG3OperatorTest, KernelCheckTest) { + TensorFactory tf; + // Broadcast add. + const std::vector sizeOfA{1, 3, 2, 4}, sizeOfB{2, 4}; + const Tensor b = tf.ones(sizeOfB); + Tensor out = tf.zeros(sizeOfA); + // Create a null tensor to force kernel check failure. + TensorImpl nullTensorImpl( + b.scalar_type(), + b.dim(), + const_cast(b.sizes().data()), + // Use nullptr to force kernel check failure. 
+ /*data=*/nullptr, + const_cast(b.dim_order().data())); + Tensor nullTensor(&nullTensorImpl); + + ET_EXPECT_KERNEL_FAILURE( + context_, add_out(tf.ones(sizeOfA), nullTensor, 1, out)); +} + } // namespace } // namespace native } // namespace G3 diff --git a/backends/cadence/runtime/TARGETS b/backends/cadence/runtime/TARGETS index 95a7bdc369..4055f1922a 100644 --- a/backends/cadence/runtime/TARGETS +++ b/backends/cadence/runtime/TARGETS @@ -1,3 +1,4 @@ +load(":targets.bzl", "define_common_targets") load("@fbcode_macros//build_defs:python_library.bzl", "python_library") oncall("odai_jarvis") @@ -22,3 +23,5 @@ python_library( "//executorch/exir:lib", ], ) + +define_common_targets() diff --git a/backends/cadence/runtime/et_pal.cpp b/backends/cadence/runtime/et_pal.cpp new file mode 100644 index 0000000000..fdf058f05b --- /dev/null +++ b/backends/cadence/runtime/et_pal.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#if defined(XTENSA) + +#include +#include + +#include + +#include + +#define ET_LOG_OUTPUT_FILE stdout + +void et_pal_emit_log_message( + et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + ET_UNUSED const char* function, + size_t line, + const char* message, + ET_UNUSED size_t length) { + // Not all platforms have ticks == nanoseconds, but this one does. + timestamp /= 1000; // To microseconds + int us = timestamp % 1000000; + timestamp /= 1000000; // To seconds + int sec = timestamp % 60; + timestamp /= 60; // To minutes + int min = timestamp % 60; + timestamp /= 60; // To hours + int hour = timestamp; + + fprintf( + ET_LOG_OUTPUT_FILE, + "%c %02d:%02d:%02d.%06d executorch:%s:%d] %s\n", + static_cast(level), + hour, + min, + sec, + us, + filename, + static_cast(line), + message); + fflush(ET_LOG_OUTPUT_FILE); +} + +et_timestamp_t et_pal_current_ticks(void) { + struct tms curr_time; + times(&curr_time); + return curr_time.tms_utime; +} + +void et_pal_init(void) { + xt_iss_client_command("all", "enable"); +} + +#else + +#include + +#include +#include + +#include + +#define ET_LOG_OUTPUT_FILE stderr + +#define NSEC_PER_USEC 1000UL +#define USEC_IN_SEC 1000000UL +#define NSEC_IN_USEC 1000UL +#define NSEC_IN_SEC (NSEC_IN_USEC * USEC_IN_SEC) + +et_timestamp_t et_pal_current_ticks(void) { + struct timespec ts; + auto ret = clock_gettime(CLOCK_REALTIME, &ts); + if (ret != 0) { + fprintf(ET_LOG_OUTPUT_FILE, "Could not get time\n"); + fflush(ET_LOG_OUTPUT_FILE); + std::abort(); + } + + return ((ts.tv_sec * NSEC_IN_SEC) + (ts.tv_nsec)); +} + +#endif diff --git a/backends/cadence/runtime/targets.bzl b/backends/cadence/runtime/targets.bzl new file mode 100644 index 0000000000..dabe42ad82 --- /dev/null +++ b/backends/cadence/runtime/targets.bzl @@ -0,0 +1,15 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + runtime.cxx_library( + name = "et_pal", + srcs = ["et_pal.cpp"], + link_whole = True, + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS" + ], + exported_deps = [ + "//executorch/runtime/platform:platform", + ], + ) diff --git a/backends/vulkan/_passes/int4_weight_only_quantizer.py b/backends/vulkan/_passes/int4_weight_only_quantizer.py index 895e1c2ad8..4821b61340 100644 --- a/backends/vulkan/_passes/int4_weight_only_quantizer.py +++ 
b/backends/vulkan/_passes/int4_weight_only_quantizer.py @@ -226,6 +226,12 @@ def _create_quantized_state_dict( self.groupsize, self.precision, # dtype for scales_and_zeros ) + # If the packing of 2 4-bit values into a single 8-bit value was not + # performed in the previous function call, then do it manually now. + if w_int4x8.shape == weight.shape: + w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to( + torch.uint8 + ) # In the original implementation, w_int4x8 is packed via calling the # _convert_weight_to_int4pack operator before storing the weight. However # the Vulkan implementation does not expect the weights to be packed, so diff --git a/backends/vulkan/docs/android_demo.md b/backends/vulkan/docs/android_demo.md index 2a4faacc0c..9b45fcde9a 100644 --- a/backends/vulkan/docs/android_demo.md +++ b/backends/vulkan/docs/android_demo.md @@ -59,6 +59,7 @@ partially lower the Llama model to Vulkan. # The files will usually be downloaded to ~/.llama python -m examples.models.llama.export_llama \ --disable_dynamic_shape --vulkan -kv --use_sdpa_with_kv_cache -d fp32 \ + --model "llama3_2" \ -c ~/.llama/checkpoints/Llama3.2-1B/consolidated.00.pth \ -p ~/.llama/checkpoints/Llama3.2-1B/params.json \ --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' diff --git a/devtools/bundled_program/core.py b/devtools/bundled_program/core.py index c775fb1510..2c930f06b7 100644 --- a/devtools/bundled_program/core.py +++ b/devtools/bundled_program/core.py @@ -19,7 +19,7 @@ from executorch.devtools.bundled_program.version import BUNDLED_PROGRAM_SCHEMA_VERSION from executorch.exir import ExecutorchProgram, ExecutorchProgramManager -from executorch.exir._serialize import _serialize_pte_binary +from executorch.exir._serialize import _deserialize_pte_binary, _serialize_pte_binary from executorch.exir.tensor import get_scalar_type, scalar_type_enum, TensorSpec # pyre-ignore @@ -43,23 +43,43 @@ class BundledProgram: def __init__( self, - executorch_program: Union[ - ExecutorchProgram, - ExecutorchProgramManager, + executorch_program: Optional[ + Union[ + ExecutorchProgram, + ExecutorchProgramManager, + ] ], method_test_suites: Sequence[MethodTestSuite], + pte_file_path: Optional[str] = None, ): """Create BundledProgram by bundling the given program and method_test_suites together. Args: executorch_program: The program to be bundled. method_test_suites: The testcases for certain methods to be bundled. + pte_file_path: The path to pte file to deserialize program if executorch_program is not provided. """ + if not executorch_program and not pte_file_path: + raise RuntimeError( + "Either executorch_program or pte_file_path must be provided" + ) + + if executorch_program and pte_file_path: + raise RuntimeError( + "Only one of executorch_program or pte_file_path can be used" + ) method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name) - self._assert_valid_bundle(executorch_program, method_test_suites) + if executorch_program: + self._assert_valid_bundle(executorch_program, method_test_suites) + self.executorch_program: Optional[ + Union[ + ExecutorchProgram, + ExecutorchProgramManager, + ] + ] = executorch_program + self._pte_file_path: Optional[str] = pte_file_path - self.executorch_program = executorch_program self.method_test_suites = method_test_suites # This is the cache for bundled program in schema type. 
@@ -72,7 +92,14 @@ def serialize_to_schema(self) -> bp_schema.BundledProgram: if self._bundled_program_in_schema is not None: return self._bundled_program_in_schema - program = self._extract_program(self.executorch_program) + if self.executorch_program: + program = self._extract_program(self.executorch_program) + else: + assert self._pte_file_path is not None + with open(self._pte_file_path, "rb") as f: + p_bytes = f.read() + program = _deserialize_pte_binary(p_bytes) + bundled_method_test_suites: List[bp_schema.BundledMethodTestSuite] = [] # Emit data and metadata of bundled tensor diff --git a/devtools/bundled_program/test/test_bundle_data.py b/devtools/bundled_program/test/test_bundle_data.py index 565539cbf1..b833903c2f 100644 --- a/devtools/bundled_program/test/test_bundle_data.py +++ b/devtools/bundled_program/test/test_bundle_data.py @@ -6,6 +6,7 @@ # pyre-strict +import tempfile import unittest from typing import List @@ -74,6 +75,49 @@ def test_bundled_program(self) -> None: bytes(_serialize_pte_binary(executorch_program.executorch_program)), ) + def test_bundled_program_from_pte(self) -> None: + executorch_program, method_test_suites = get_common_executorch_program() + + with tempfile.TemporaryDirectory() as tmp_dir: + executorch_model_path = f"{tmp_dir}/executorch_model.pte" + with open(executorch_model_path, "wb") as f: + f.write(executorch_program.buffer) + + bundled_program = BundledProgram( + executorch_program=None, + method_test_suites=method_test_suites, + pte_file_path=executorch_model_path, + ) + + method_test_suites = sorted(method_test_suites, key=lambda t: t.method_name) + + for plan_id in range( + len(executorch_program.executorch_program.execution_plan) + ): + bundled_plan_test = ( + bundled_program.serialize_to_schema().method_test_suites[plan_id] + ) + method_test_suite = method_test_suites[plan_id] + + self.assertEqual( + len(bundled_plan_test.test_cases), len(method_test_suite.test_cases) + ) + for bundled_program_ioset, method_test_case in zip( + bundled_plan_test.test_cases, method_test_suite.test_cases + ): + self.assertIOsetDataEqual( + bundled_program_ioset.inputs, method_test_case.inputs + ) + self.assertIOsetDataEqual( + bundled_program_ioset.expected_outputs, + method_test_case.expected_outputs, + ) + + self.assertEqual( + bundled_program.serialize_to_schema().program, + bytes(_serialize_pte_binary(executorch_program.executorch_program)), + ) + def test_bundled_miss_methods(self) -> None: executorch_program, method_test_suites = get_common_executorch_program() diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 4944141631..3faf62e008 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -372,13 +372,15 @@ def plot_metric(result: List[float], metric_name: str): def calculate_mse(ref_values: ProgramOutput, values: ProgramOutput): def mean_squared_error(a: torch.Tensor, b: torch.Tensor): - return round((torch.pow((a - b).to(torch.float32), 2)).mean().item(), 2) + return round((torch.pow((a - b), 2)).mean().item(), 2) results = [] for ref_value, value in zip(ref_values, values): # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor): - results.append(mean_squared_error(ref_value, value)) + results.append( + mean_squared_error(ref_value.to(torch.float32), value.to(torch.float32)) + ) else: results.append(None) @@ -387,8 +389,6 @@ def 
mean_squared_error(a: torch.Tensor, b: torch.Tensor): def calculate_snr(ref_values: ProgramOutput, values: ProgramOutput): def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor): - signal = signal.type(torch.float32) - noise = noise.type(torch.float32) signal_power = torch.mean(torch.pow(signal, 2)) noise_power = torch.mean(torch.pow(noise, 2)) snr = 10 * torch.log10(signal_power / noise_power) @@ -398,8 +398,10 @@ def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor): for ref_value, value in zip(ref_values, values): # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor): - diff = ref_value - value - snr = signal_to_noise(ref_value, diff) + ref_value_fp = ref_value.to(torch.float32) + value_fp = value.to(torch.float32) + diff = ref_value_fp - value_fp + snr = signal_to_noise(ref_value_fp, diff) results.append(snr) else: results.append(None) @@ -429,7 +431,9 @@ def cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor): for ref_value, value in zip(ref_values, values): # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor): - results.append(cosine_similarity(ref_value, value)) + results.append( + cosine_similarity(ref_value.to(torch.float32), value.to(torch.float32)) + ) else: results.append(None) diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 73511f5fcd..5e224415bb 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -25,6 +25,9 @@ from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord from executorch.devtools.inspector._inspector_utils import ( + calculate_cosine_similarity, + calculate_mse, + calculate_snr, calculate_time_scale_factor, create_debug_handle_to_op_node_mapping, EDGE_DIALECT_GRAPH_KEY, @@ -188,6 +191,32 @@ def test_calculate_time_scale_factor_cycles(self): calculate_time_scale_factor(TimeScale.CYCLES, TimeScale.CYCLES), 1 ) + def test_compare_results(self): + a = torch.rand(4, 4) + + # Create tensor b which has very close value to tensor a + b = a.clone() + b[0, 0] += 1e-2 + b[1, 0] += 1e-2 + b[1, 3] -= 1e-2 + + self.assertLess(calculate_mse([a], [b])[0], 0.5) + self.assertGreater(calculate_snr([a], [b])[0], 30.0) + self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0) + + def test_compare_results_uint8(self): + a = torch.randint(0, 255, (4, 4), dtype=torch.uint8) + + # Create tensor b which has very close value to tensor a + b = a.clone() + b[0, 0] += 1 + b[1, 0] += 1 + b[1, 3] -= 1 + + self.assertLess(calculate_mse([a], [b])[0], 0.5) + self.assertGreater(calculate_snr([a], [b])[0], 30.0) + self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0) + def gen_mock_operator_graph_with_expected_map() -> ( Tuple[OperatorGraph, Dict[int, OperatorNode]] diff --git a/docs/source/getting-started-setup.md b/docs/source/getting-started-setup.md index f2f98d6563..66ef2233af 100644 --- a/docs/source/getting-started-setup.md +++ b/docs/source/getting-started-setup.md @@ -113,7 +113,7 @@ to ExecuTorch. 
> > ```bash > # From the root of the executorch repo: -> rm -rf cmake-out pip-out +> ./install_requirements.sh --clean > git submodule sync > git submodule update --init > ``` @@ -196,7 +196,8 @@ The ExecuTorch repo uses CMake to build its C++ code. Here, we'll configure it t ```bash # Clean and configure the CMake build system. Compiled programs will # appear in the executorch/cmake-out directory we create here. - (rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..) + ./install_requirements.sh --clean + (mkdir cmake-out && cd cmake-out && cmake ..) # Build the executor_runner target cmake --build cmake-out --target executor_runner -j9 @@ -213,7 +214,7 @@ The ExecuTorch repo uses CMake to build its C++ code. Here, we'll configure it t > > ```bash > # From the root of the executorch repo: -> rm -rf cmake-out pip-out +> ./install_requirements.sh --clean > git submodule sync > git submodule update --init > ``` diff --git a/docs/source/runtime-build-and-cross-compilation.md b/docs/source/runtime-build-and-cross-compilation.md index 4d42357618..f30d2d28d1 100644 --- a/docs/source/runtime-build-and-cross-compilation.md +++ b/docs/source/runtime-build-and-cross-compilation.md @@ -45,7 +45,8 @@ cd executorch # Clean and configure the CMake build system. It's good practice to do this # whenever cloning or pulling the upstream repo. -(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..) +./install_requirements.sh --clean +(mkdir cmake-out && cd cmake-out && cmake ..) ``` Once this is done, you don't need to do it again until you pull from the upstream repo again, or if you modify any CMake-related files. @@ -121,7 +122,8 @@ Following are instruction on how to perform cross compilation for Android and iO Assuming Android NDK is available, run: ```bash # Run the following lines from the `executorch/` folder -rm -rf cmake-android-out && mkdir cmake-android-out && cd cmake-android-out +./install_requirements.sh --clean +mkdir cmake-android-out && cd cmake-android-out # point -DCMAKE_TOOLCHAIN_FILE to the location where ndk is installed cmake -DCMAKE_TOOLCHAIN_FILE=/Users/{user_name}/Library/Android/sdk/ndk/25.2.9519653/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a .. diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index 3d875e656b..9e20d6043a 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -92,7 +92,7 @@ tosa_reference_model_rev="c5570b79e90c3a36ab8c4ddb8ee3fbc2cd3f7c38" # vela vela_repo_url="https://review.mlplatform.org/ml/ethos-u/ethos-u-vela" -vela_rev="a08fc18780827b5fefc814dd0162ee6317ce0ae7" +vela_rev="5427dc7e9c1a4c7d554163290faeea75f168772d" ######## ### Mandatory user args diff --git a/examples/demo-apps/android/ExecuTorchDemo/README.md b/examples/demo-apps/android/ExecuTorchDemo/README.md index a60307dd90..33107cbe5e 100644 --- a/examples/demo-apps/android/ExecuTorchDemo/README.md +++ b/examples/demo-apps/android/ExecuTorchDemo/README.md @@ -69,7 +69,9 @@ We build the required ExecuTorch runtime library to run the model. export ANDROID_NDK= export ANDROID_ABI=arm64-v8a -rm -rf cmake-android-out && mkdir cmake-android-out +# Run the following lines from the `executorch/` folder +./install_requirements.sh --clean +mkdir cmake-android-out # Build the core executorch library cmake . 
-DCMAKE_INSTALL_PREFIX=cmake-android-out \ @@ -112,7 +114,8 @@ export ANDROID_NDK= export ANDROID_ABI=arm64-v8a export QNN_SDK_ROOT= -rm -rf cmake-android-out && mkdir cmake-android-out && cd cmake-android-out +./install_requirements.sh --clean +mkdir cmake-android-out && cd cmake-android-out cmake . -DCMAKE_INSTALL_PREFIX=cmake-android-out \ -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \ -DANDROID_ABI="${ANDROID_ABI}" \ diff --git a/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md b/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md index 2a6ddbbfe0..087bd24260 100644 --- a/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md +++ b/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md @@ -56,14 +56,14 @@ In this demo app, we support text-only inference with up-to-date Llama models an Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend. * Export Llama model and generate .pte file as below: ``` -python -m examples.models.llama.export_llama --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte" +python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte" ``` ### For Llama 3.2 1B and 3B QAT+LoRA models Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend. 
* Export Llama model and generate .pte file as below: ``` -python -m examples.models.llama.export_llama --checkpoint --params -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte" +python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte" ``` ### For Llama 3.2 1B and 3B BF16 models @@ -72,7 +72,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B * Export Llama model and generate .pte file as below: ``` -python -m examples.models.llama.export_llama --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte" +python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte" ``` For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-). diff --git a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md index 1968c55076..83b1ba76f8 100644 --- a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md +++ b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md @@ -48,14 +48,14 @@ sh examples/models/llama/install_requirements.sh Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend. * Export Llama model and generate .pte file as below: ``` -python -m examples.models.llama.export_llama --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte" +python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte" ``` ### For Llama 3.2 1B and 3B QAT+LoRA models Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend. 
diff --git a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md
index 1968c55076..83b1ba76f8 100644
--- a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md
+++ b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md
@@ -48,14 +48,14 @@ sh examples/models/llama/install_requirements.sh
 Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
 ```
 ### For Llama 3.2 1B and 3B QAT+LoRA models
 Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint --params -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
 ```
 ### For Llama 3.2 1B and 3B BF16 models
@@ -64,7 +64,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
 ```
 For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).
diff --git a/examples/devtools/build_example_runner.sh b/examples/devtools/build_example_runner.sh
index 9f35abb1a3..693996940d 100755
--- a/examples/devtools/build_example_runner.sh
+++ b/examples/devtools/build_example_runner.sh
@@ -37,7 +37,7 @@ done
 main() {
   cd "${EXECUTORCH_ROOT}"
-  rm -rf cmake-out
+  ./install_requirements.sh --clean
   if [[ "${BUILD_COREML}" == "ON" ]]; then
     cmake -DCMAKE_INSTALL_PREFIX=cmake-out \
diff --git a/examples/devtools/test_example_runner.sh b/examples/devtools/test_example_runner.sh
index 9c9ed782cb..b16d4f3e04 100644
--- a/examples/devtools/test_example_runner.sh
+++ b/examples/devtools/test_example_runner.sh
@@ -16,7 +16,7 @@ source "$(dirname "${BASH_SOURCE[0]}")/../../.ci/scripts/utils.sh"
 cmake_install_executorch_devtools_lib() {
   echo "Installing libexecutorch.a, libportable_kernels.a, libetdump.a, libbundled_program.a"
-  rm -rf cmake-out
+  clean_executorch_install_folders
   retry cmake -DCMAKE_INSTALL_PREFIX=cmake-out \
     -DCMAKE_BUILD_TYPE=Release \
diff --git a/examples/models/llama/README.md b/examples/models/llama/README.md
index cfa0fe04b1..e621ce5d49 100644
--- a/examples/models/llama/README.md
+++ b/examples/models/llama/README.md
@@ -168,6 +168,7 @@ LLAMA_CHECKPOINT=path/to/checkpoint.pth
 LLAMA_PARAMS=path/to/params.json
 python -m examples.models.llama.export_llama \
+  --model "llama3_2" \
   --checkpoint "${LLAMA_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   -kv \
@@ -189,6 +190,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
 LLAMA_PARAMS=path/to/spinquant/params.json
 python -m examples.models.llama.export_llama \
+  --model "llama3_2" \
   --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   --use_sdpa_with_kv_cache \
@@ -214,6 +216,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/checkpoint.pth
 LLAMA_PARAMS=path/to/qlora/params.json
 python -m examples.models.llama.export_llama \
+  --model "llama3_2" \
   --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   -qat \
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 10d660d37a..aaef3cd980 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -113,6 +113,7 @@ class ModelArgs:
     )
     rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
     use_scaled_rope: bool = False  # Use scaled RoPE, introduced in llama3.1.
+    rope_scale_factor: int = 8
     # Additional Model Metadata needed at runtime
     bos_idx: int = 1
     eos_idx: int = 3
@@ -155,7 +156,9 @@ def __init__(self, params: ModelArgs):
             self.precompute_freqs_cis = hf_precompute_freqs_cis
         else:
             self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+                precompute_freqs_cis,
+                use_scaled=self.params.use_scaled_rope,
+                scale_factor=self.params.rope_scale_factor,
             )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 2385aba6d5..9f7994916a 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -145,6 +145,15 @@ def __init__(self, **kwargs):
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
+
+        if model_args.use_scaled_rope:
+            # Older models don't have use_scaled_rope configuration
+            assert self.args.model not in ["llama2", "stories110m"]
+
+            # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
+            if self.args.model not in ["llama3", "llama3_1"]:
+                model_args.rope_scale_factor = 32
+
         if kwargs.get("verbose", False):
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")
diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index 1445787f5e..cd3ddb0d3b 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -8,16 +8,15 @@
 # Different RoPE implementations
 import math
-from typing import Tuple
+from typing import Optional, Tuple
 import torch
 # ======================== Stock Implementation ========================
-def apply_scaling(freqs: torch.Tensor):
+def apply_scaling(freqs: torch.Tensor, scale_factor: int):
     # Values obtained from grid search
-    scale_factor = 8
     low_freq_factor = 1
     high_freq_factor = 4
     old_context_len = 8192  # original llama3 length
@@ -41,14 +40,19 @@ def apply_scaling(freqs: torch.Tensor):
 def precompute_freqs_cis(
-    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    use_scaled: bool = False,
+    scale_factor: Optional[int] = None,
 ):
     freqs = 1.0 / (
         theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
     )
     t = torch.arange(end, device=freqs.device)  # pyre-ignore
     if use_scaled:
-        freqs = apply_scaling(freqs)  # pyre-ignore
+        assert scale_factor is not None
+        freqs = apply_scaling(freqs, scale_factor)  # pyre-ignore
     freqs = torch.outer(t, freqs).float()
     freqs_cos = torch.cos(freqs)
     freqs_sin = torch.sin(freqs)
diff --git a/examples/xnnpack/quantization/test_quantize.sh b/examples/xnnpack/quantization/test_quantize.sh
index d870fb6faf..d439fde6cb 100644
--- a/examples/xnnpack/quantization/test_quantize.sh
+++ b/examples/xnnpack/quantization/test_quantize.sh
@@ -47,8 +47,9 @@ test_cmake_quantization() {
   SITE_PACKAGES="$(${PYTHON_EXECUTABLE} -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
   CMAKE_PREFIX_PATH="${SITE_PACKAGES}/torch"
-  (rm -rf cmake-out \
-    && mkdir cmake-out \
+  clean_executorch_install_folders
+
+  (mkdir cmake-out \
     && cd cmake-out \
     && retry cmake \
       -DCMAKE_BUILD_TYPE=Release \
diff --git a/install_requirements.py b/install_requirements.py
index 90e1037329..ace2f34b70 100644
--- a/install_requirements.py
+++ b/install_requirements.py
@@ -85,7 +85,7 @@ def python_is_compatible():
     print("Cleaning build artifacts...")
     print("Cleaning pip-out/...")
     shutil.rmtree("pip-out/", ignore_errors=True)
-    dirs = glob.glob("cmake-out*/")
+    dirs = glob.glob("cmake-out*/") + glob.glob("cmake-android-out/")
     for d in dirs:
         print(f"Cleaning {d}...")
         shutil.rmtree(d, ignore_errors=True)
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 2dd019e1b3..18fa646aec 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -41,6 +41,29 @@ def define_common_targets():
     for aten_kernel in (True, False):
         aten_suffix = "_aten" if aten_kernel else ""
+        runtime.cxx_library(
+            name = "gtest_utils" + aten_suffix,
+            exported_headers=[
+                "TestUtil.h",
+            ],
+            visibility = [
+                "//executorch/kernels/...",
+                "@EXECUTORCH_CLIENTS",
+            ],
+            preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_kernel else [],
+            exported_deps = [
+                "//executorch/runtime/core:core",
+                "//executorch/runtime/kernel:kernel_includes",
+                "//executorch/test/utils:utils" + aten_suffix,
+                "//executorch/runtime/platform:pal_interface",
+            ],
+            fbcode_exported_deps = [
+                "//common/gtest:gtest",
+            ],
+            xplat_exported_deps = [
+                "//third-party/googletest:gtest_main",
+            ],
+        )
         runtime.cxx_library(
             name = "test_util" + aten_suffix,
             srcs = [
@@ -49,7 +72,6 @@ def define_common_targets():
             ],
             exported_headers = [
                 "BinaryLogicalOpTest.h",
-                "TestUtil.h",
                 "UnaryUfuncRealHBBF16ToFloatHBF16Test.h",
             ],
             visibility = [
@@ -59,6 +81,7 @@ def define_common_targets():
             preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_kernel else [],
             exported_deps = [
                 ":supported_features_header",
+                ":gtest_utils",
                 "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
                 "//executorch/runtime/core/exec_aten/testing_util:tensor_util" + aten_suffix,
                 "//executorch/runtime/kernel:kernel_includes",
diff --git a/test/build_size_test.sh b/test/build_size_test.sh
index 428e351cf0..d970a023b5 100644
--- a/test/build_size_test.sh
+++ b/test/build_size_test.sh
@@ -13,7 +13,7 @@ source "$(dirname "${BASH_SOURCE[0]}")/../.ci/scripts/utils.sh"
 cmake_install_executorch_lib() {
   echo "Installing libexecutorch.a"
-  rm -rf cmake-out
+  clean_executorch_install_folders
   retry cmake -DBUCK2="$BUCK2" \
     -DCMAKE_CXX_STANDARD_REQUIRED=ON \