diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index ad2f713466..fabb60b5fd 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -41,7 +41,7 @@ case "${IMAGE_NAME}" in LINTRUNNER="" CLANG_VERSION=12 # From https://developer.android.com/ndk/downloads - ANDROID_NDK_VERSION=r26c + ANDROID_NDK_VERSION=r27b ;; *) echo "Invalid image name ${IMAGE_NAME}" diff --git a/.github/workflows/android-release-artifacts.yml b/.github/workflows/android-release-artifacts.yml index 3983885848..a10de79363 100644 --- a/.github/workflows/android-release-artifacts.yml +++ b/.github/workflows/android-release-artifacts.yml @@ -13,8 +13,24 @@ concurrency: cancel-in-progress: true jobs: + check-if-aar-exists: + name: check-if-aar-exists + runs-on: ubuntu-22.04 + timeout-minutes: 10 + steps: + - name: Check if this RC version is already in S3 + shell: bash + run: | + VERSION="${{ inputs.version }}" + if curl -I "https://ossci-android.s3.amazonaws.com/executorch/release/${VERSION}/executorch.aar" | grep "200 OK"; then + echo "AAR already exists at https://ossci-android.s3.amazonaws.com/executorch/release/${VERSION}/executorch.aar" + echo "Will skip build/upload" + exit 1 + fi + build-aar: name: build-aar + needs: check-if-aar-exists uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: runner: linux.2xlarge diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 559e28b441..40c70cc086 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -53,7 +53,7 @@ jobs: # NB: Use metal install for KVM support to run the emulator faster runs-on: linux.24xl.spr-metal env: - ANDROID_NDK_VERSION: r26c + ANDROID_NDK_VERSION: r27b API_LEVEL: 34 steps: - name: Setup SSH (Click me for login details) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index e0566438b7..c4e806a842 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -19,6 +19,9 @@ ConvertSplitToSlicePass, ) from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass +from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import ( + InsertSqueezeAfterSumPass, +) from executorch.backends.arm._passes.meandim_to_averagepool_pass import ( ConvertMeanDimToAveragePool, ) @@ -47,6 +50,7 @@ def transform_to_backend_pipeline( self.add_pass(ConvertExpandCopyToRepeatPass()) self.add_pass(ConvertMeanDimToAveragePool()) self.add_pass(DecomposeDivPass()) + self.add_pass(InsertSqueezeAfterSumPass()) self.add_pass(ConvertSplitToSlicePass()) for spec in compile_spec: if spec.key == "permute_memory_format": diff --git a/backends/arm/_passes/insert_squeeze_after_sum_pass.py b/backends/arm/_passes/insert_squeeze_after_sum_pass.py new file mode 100644 index 0000000000..152d5c95f6 --- /dev/null +++ b/backends/arm/_passes/insert_squeeze_after_sum_pass.py @@ -0,0 +1,69 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast + +import torch +import torch.fx +from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair + +from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class InsertSqueezeAfterSumPass(ExportPass): + """ + In Pytorch, the default behaviour of Tensor.sum is to squeeze + the dimension that is summed (keep_dim = False). + However, in TOSA, REDUCE_SUM always preserves the + rank of the input (keep_dim = True). + To get a 1-1 mapping in the sum lowering, normalize the + keep_dim = False case to keep_dim = True and add squeeze ops. + + Original: + sum(dims, keep_dim = False) + After pass: + sum(dims, keep_dim = True) + (q) + (dq) + squeeze(dim = dims) + """ + + def call(self, graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if node.target != exir_ops.edge.aten.sum.dim_IntList: + continue + sum_node = cast(torch.fx.Node, node) + keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False) + if keep_dim: + continue + + dim_list = cast(list[int], sum_node.args[1]) + quantized = is_quant_node(sum_node) + if quantized: + qparams = get_quant_node_args(sum_node.all_input_nodes[0]) + qparams = qparams + (torch.int8,) + else: + qparams = None + + # Add keep_dim = True arg to sum node. + sum_node.args = sum_node.args[0:2] + (True,) + + with graph_module.graph.inserting_after(sum_node): + squeeze_node = create_node( + graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, () + ) + sum_node.replace_all_uses_with(squeeze_node) + squeeze_node.args = (sum_node, dim_list) + if quantized: + sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, True) diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index 22fb5ac6ac..7db893694b 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -63,6 +63,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten._softmax.default, exir_ops.edge.aten.slice_copy.Tensor, exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.sum.dim_IntList, exir_ops.edge.aten.view_copy.default, exir_ops.edge.aten.clone.default, exir_ops.edge.aten.mean.dim, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 6d08290f03..855487cf7f 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -34,6 +34,7 @@ op_softmax, op_squeeze, op_sub, + op_sum, op_unsqueeze, op_view, ) diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py new file mode 100644 index 0000000000..b67f5f92db --- /dev/null +++ b/backends/arm/operators/op_sum.py @@ -0,0 +1,96 @@ +# Copyright 2023-2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast, List + +import executorch.backends.arm.tosa_quant_utils as tqutils +import executorch.backends.arm.tosa_utils as tutils + +import serializer.tosa_serializer as ts +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from serializer.tosa_serializer import TosaOp +from torch.fx import Node + + +@register_node_visitor +class AddVisitor(NodeVisitor): + target = "aten.sum.dim_IntList" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + input_node = inputs[0] + input_shape = list(input_node.shape) + dim_list = cast(list[int], inputs[1].special) + dim_list = [dim % len(input_node.shape) for dim in dim_list] + keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False) + assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass" + + if is_quant_node: + + # Rescale input to 32 bit + rescaled_inputs, scale = tqutils.rescale_nodes_to_int32( + [node.all_input_nodes[0]], tosa_graph + ) + + prev_node = rescaled_inputs[0] + reduced_shape = input_shape + + # Reduce all dims in dim_list one-by-one. + for dim in dim_list: + # When reduced, the size of the dim becomes 1. + reduced_shape[dim] = 1 + + attr = ts.TosaSerializerAttribute() + attr.AxisAttribute(input_node.dim_order.index(dim)) + + next_node = tosa_graph.addIntermediate( + tutils.tosa_shape(reduced_shape, input_node.dim_order), + dtype=ts.DType.INT32, + ) + + tosa_graph.addOperator( + TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr + ) + + prev_node = next_node + tqutils.rescale_node_back_to_int8(node, prev_node, scale, tosa_graph) + else: + input_name = input_node.name + reduced_shape = input_shape + + # Reduce all dims in dim_list one-by-one. + for dim in dim_list: + # When reduced, the size of the dim becomes 1 + reduced_shape[dim] = 1 + + attr = ts.TosaSerializerAttribute() + attr.AxisAttribute(input_node.dim_order.index(dim)) + + if dim == dim_list[-1]: + output_name = output.name + else: + output_name = tosa_graph.addIntermediate( + tutils.tosa_shape(reduced_shape, input_node.dim_order), + dtype=ts.DType.FP32, + ).name + + tosa_graph.addOperator( + TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr + ) + + input_name = output_name diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index b5fe4b76eb..6a68eb2eb9 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -267,11 +267,11 @@ class ArmQuantizer(Quantizer): "add", "sub", "mul", - "sigmoid", "mm", "cat", "one_to_one", "generic", + "sum", ] def __init__(self) -> None: diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py index 594911075f..bc3184298f 100644 --- a/backends/arm/quantizer/quantization_annotation/__init__.py +++ b/backends/arm/quantizer/quantization_annotation/__init__.py @@ -59,6 +59,6 @@ def decorator(annotator: AnnotatorType): mm_annotator, mul_annotator, one_to_one_annotator, - sigmoid_annotator, sub_annotator, + sum_annotator, ) diff --git a/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py b/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py index b5e4afa2e1..e2a1398018 100644 --- a/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py @@ -40,6 +40,7 @@ def _annotate_one_to_one( torch.ops.aten.log.default, torch.ops.aten.reciprocal.default, torch.ops.aten.rsqrt.default, + torch.ops.aten.sigmoid.default, ) for node in gm.graph.nodes: if node.op != "call_function" or node.target not in one_to_one_ops: diff --git a/backends/arm/quantizer/quantization_annotation/sigmoid_annotator.py b/backends/arm/quantizer/quantization_annotation/sigmoid_annotator.py deleted file mode 100644 index 3d24269483..0000000000 --- a/backends/arm/quantizer/quantization_annotation/sigmoid_annotator.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) -from torch.fx import Node - - -@register_annotator("sigmoid") -def _annotate_sigmoid( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - annotated_partitions = [] - - # input/ output range of sigmoid is always same -> quantize with fixed qspec. - # this configuration maps input: (-128, 127) -> (-6.0, 5.95). Outside these bounds, sigmoid ~= const. - # output: (-1,0.99) -> (-128, 127). Sigmoid has output value range (-1,1) - # Note that this exact choice is somewhat arbitrary. - - input_act_qspec = quantization_config.get_fixed_qspec(scale=6 / 128, zp=0) - output_act_qspec = quantization_config.get_fixed_qspec(scale=1 / 128, zp=0) - - for node in gm.graph.nodes: - if node.op != "call_function" or node.target != torch.ops.aten.sigmoid.default: - continue - if filter_fn and not filter_fn(node): - continue - input_node = node.args[0] - - if not arm_quantizer_utils.is_annotated(node): - _annotate_input_qspec_map( - node, - input_node, - input_act_qspec, - ) - _annotate_output_qspec(node, output_act_qspec) - - arm_quantizer_utils.mark_nodes_as_annotated([node]) - annotated_partitions.append([node]) - - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/sum_annotator.py b/backends/arm/quantizer/quantization_annotation/sum_annotator.py new file mode 100644 index 0000000000..1c0399dc34 --- /dev/null +++ b/backends/arm/quantizer/quantization_annotation/sum_annotator.py @@ -0,0 +1,57 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, cast, List, Optional + +import torch +from executorch.backends.arm.quantizer import arm_quantizer_utils +from executorch.backends.arm.quantizer.quantization_annotation import register_annotator +from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig + +from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + QuantizationSpecBase, + SharedQuantizationSpec, +) +from torch.fx import Node + + +@register_annotator("sum") +def _annotate_sum( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.target is not torch.ops.aten.sum.dim_IntList: + continue + if filter_fn and not filter_fn(node): + continue + + sum_node = node + if arm_quantizer_utils.is_annotated(sum_node): + continue + + input_act = sum_node.args[0] + + if not isinstance(input_act, Node): + continue + if not arm_quantizer_utils.is_input_ok_for_quantization(input_act, gm): + continue + + input_act_qspec = cast( + Optional[QuantizationSpecBase], quantization_config.get_input_act_qspec() + ) + input_qspec_map = {input_act: input_act_qspec} + shared_with_input0_qspec = SharedQuantizationSpec((input_act, sum_node)) + + sum_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_with_input0_qspec, + _annotated=True, + ) + annotated_partitions.append([sum_node]) + return annotated_partitions diff --git a/backends/arm/runtime/ArmBackendEthosU.cpp b/backends/arm/runtime/ArmBackendEthosU.cpp index b0452fb9e7..99ce0a9df2 100644 --- a/backends/arm/runtime/ArmBackendEthosU.cpp +++ b/backends/arm/runtime/ArmBackendEthosU.cpp @@ -115,7 +115,7 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface { ArmBackendExecuteCallbacks ArmBackend_execute_callbacks; // Command stream - we know at this point it's aligned char* data = (char*)execution_handle->processed->data(); - ET_LOG(Info, "ArmBackend::execute %p", data); + ET_LOG(Debug, "ArmBackend::execute %p", data); // Read key sections from the vela_bin_stream if (vela_bin_read(data, &handles, execution_handle->processed->size()) == @@ -295,7 +295,7 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface { tensor.size(1) == io->shape[3] && tensor.size(2) == io->shape[1] && tensor.size(3) == io->shape[2]; if (permuted_shape) { - ET_LOG(Info, "Tensor input/output %d will be permuted", index); + ET_LOG(Debug, "Tensor input/output %d will be permuted", index); } if (permuted_io_flag != permuted_shape) { ET_LOG( diff --git a/backends/arm/test/misc/test_model_evaluator.py b/backends/arm/test/misc/test_model_evaluator.py new file mode 100644 index 0000000000..ee7d001115 --- /dev/null +++ b/backends/arm/test/misc/test_model_evaluator.py @@ -0,0 +1,67 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import random +import tempfile +import unittest + +import torch +from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator + +random.seed(0) + +# Create an input that is hard to compress +COMPRESSION_RATIO_TEST = bytearray(random.getrandbits(8) for _ in range(1000000)) + + +def mocked_model_1(input: torch.Tensor) -> torch.Tensor: + return torch.tensor([1.0, 2.0, 3.0, 4.0]) + + +def mocked_model_2(input: torch.Tensor) -> torch.Tensor: + return torch.tensor([1.0, 2.0, 3.0, 3.0]) + + +class TestGenericModelEvaluator(unittest.TestCase): + """Tests the GenericModelEvaluator class.""" + + def test_get_model_error(self): + example_input = torch.tensor([[1.0, 2.0, 3.0, 4.0]]) + evaluator = GenericModelEvaluator( + "dummy_model", + mocked_model_1, + mocked_model_2, + example_input, + "tmp/output_tag0.tosa", + ) + max_error, max_absolute_error, max_percentage_error, mae = ( + evaluator.get_model_error() + ) + + self.assertEqual(max_error, 1.0) + self.assertEqual(max_absolute_error, 1.0) + self.assertEqual(max_percentage_error, 25.0) + self.assertEqual(mae, 0.25) + + def test_get_compression_ratio(self): + with tempfile.NamedTemporaryFile(delete=True) as temp_bin: + temp_bin.write(COMPRESSION_RATIO_TEST) + + # As the size of the file is quite small we need to call flush() + temp_bin.flush() + temp_bin_name = temp_bin.name + + example_input = torch.tensor([[1.0, 2.0, 3.0, 4.0]]) + evaluator = GenericModelEvaluator( + "dummy_model", + mocked_model_1, + mocked_model_2, + example_input, + temp_bin_name, + ) + + ratio = evaluator.get_compression_ratio() + self.assertAlmostEqual(ratio, 1.0, places=2) diff --git a/backends/arm/test/ops/test_sigmoid.py b/backends/arm/test/ops/test_sigmoid.py index ddc29a66a4..4d126b68e5 100644 --- a/backends/arm/test/ops/test_sigmoid.py +++ b/backends/arm/test/ops/test_sigmoid.py @@ -155,18 +155,12 @@ def test_sigmoid_tosa_BI(self, test_name: str, test_data: torch.Tensor): def test_add_sigmoid_tosa_MI(self): self._test_sigmoid_tosa_MI_pipeline(self.AddSigmoid(), (test_data_suite[0][1],)) - @unittest.skip( - reason="Started to fails when PyTorch 2.5->2.6 https://github.com/pytorch/executorch/issues/5832" - ) def test_add_sigmoid_tosa_BI(self): - self._test_sigmoid_tosa_BI_pipeline(self.AddSigmoid(), (test_data_suite[0][1],)) + self._test_sigmoid_tosa_BI_pipeline(self.AddSigmoid(), (test_data_suite[5][1],)) def test_sigmoid_add_tosa_MI(self): self._test_sigmoid_tosa_MI_pipeline(self.SigmoidAdd(), (test_data_suite[0][1],)) - @unittest.skip( - reason="Started to fails when PyTorch 2.5->2.6 https://github.com/pytorch/executorch/issues/5832" - ) def test_sigmoid_add_tosa_BI(self): self._test_sigmoid_tosa_BI_pipeline(self.SigmoidAdd(), (test_data_suite[0][1],)) @@ -175,9 +169,6 @@ def test_sigmoid_add_sigmoid_tosa_MI(self): self.SigmoidAddSigmoid(), (test_data_suite[4][1], test_data_suite[3][1]) ) - @unittest.skip( - reason="Started to fails when PyTorch 2.5->2.6 https://github.com/pytorch/executorch/issues/5832" - ) def test_sigmoid_add_sigmoid_tosa_BI(self): self._test_sigmoid_tosa_BI_pipeline( self.SigmoidAddSigmoid(), (test_data_suite[4][1], test_data_suite[3][1]) diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py new file mode 100644 index 0000000000..73860dfa4a --- /dev/null +++ b/backends/arm/test/ops/test_sum.py @@ -0,0 +1,129 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.compile_spec_schema import CompileSpec +from parameterized import parameterized + +exampledata_t = Tuple[torch.Tensor, int | list[int], bool] +"""(data, dim(s), keepdim)""" + + +class TestSum(unittest.TestCase): + """Tests sum which sums all elements along some specified dimensions. + keepdim specifies whether the dimension that is summed should + be squeezed or not. + """ + + class Sum(torch.nn.Module): + test_parameters: list[Tuple[exampledata_t]] = [ + ((torch.rand(10), 0, True),), + ((torch.rand(10, 10), 1, False),), + ((torch.rand(10, 10, 10), [-3, 1], True),), + ((torch.rand(2, 1, 5, 8), 1, False),), + ((torch.rand(1, 2, 3, 4), 3, True),), + ((torch.rand(1, 2, 8, 8), [2, 3, 0], True),), + ] + + def forward(self, x: torch.Tensor, dim: int, keepdim: bool): + return x.sum(dim=dim, keepdim=keepdim) + + _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( + _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. + ) + + def _test_sum_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: tuple[exampledata_t] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(), + ) + .export() + .check_count({"torch.ops.aten.sum.dim_IntList": 1}) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_sum_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: tuple[exampledata_t] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.sum.dim_IntList": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + + def _test_sum_ethosu_BI_pipeline( + self, + module: torch.nn.Module, + test_data: tuple[exampledata_t], + compile_spec: CompileSpec, + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize() + .export() + .check_count({"torch.ops.aten.sum.dim_IntList": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + ) + + @parameterized.expand(Sum.test_parameters) + def test_sum_tosa_MI(self, test_data: tuple[exampledata_t]): + self._test_sum_tosa_MI_pipeline(self.Sum(), test_data) + + @parameterized.expand(Sum.test_parameters) + def test_sum_tosa_BI(self, test_data: tuple[exampledata_t]): + self._test_sum_tosa_BI_pipeline(self.Sum(), test_data) + + @parameterized.expand(Sum.test_parameters) + def test_sum_u55_BI(self, test_data: tuple[exampledata_t]): + self._test_sum_ethosu_BI_pipeline( + self.Sum(), + test_data, + common.get_u55_compile_spec(permute_memory_to_nhwc=False), + ) + + @parameterized.expand(Sum.test_parameters) + def test_sum_u85_BI(self, test_data: tuple[exampledata_t]): + self._test_sum_ethosu_BI_pipeline( + self.Sum(), + test_data, + common.get_u85_compile_spec(permute_memory_to_nhwc=True), + ) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 5ca571f26d..167b99a328 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -439,6 +439,9 @@ def run_tosa_ref_model( if self.is_quantized: # Need to dequant back to FP32 for comparison with torch output + # Convert to int32 prior to dequantize the output + if tosa_ref_output.dtype == np.int8: + tosa_ref_output = tosa_ref_output.astype(np.int32) quant_param = self.qp_output assert ( quant_param is not None diff --git a/backends/arm/util/arm_model_evaluator.py b/backends/arm/util/arm_model_evaluator.py new file mode 100644 index 0000000000..4ffb80c2f0 --- /dev/null +++ b/backends/arm/util/arm_model_evaluator.py @@ -0,0 +1,89 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import tempfile +import zipfile +from typing import Optional, Tuple, Union + +import torch + + +class GenericModelEvaluator: + def __init__( + self, + model_name: str, + fp32_model: torch.nn.Module, + int8_model: torch.nn.Module, + example_input: Tuple[torch.Tensor], + tosa_output_path: Optional[str], + ) -> None: + self.model_name = model_name + + self.fp32_model = fp32_model + self.int8_model = int8_model + self.example_input = example_input + + if tosa_output_path: + self.tosa_output_path = tosa_output_path + else: + self.tosa_output_path = None + + def get_model_error(self) -> Union[float, float, float, float]: + """ + Returns the following metrics between the outputs of the FP32 and INT8 model: + - Maximum error + - Maximum absolute error + - Maximum percentage error + - Mean absolute error + """ + fp32_output = self.fp32_model(*self.example_input) + int8_output = self.int8_model(*self.example_input) + + difference = fp32_output - int8_output + percentage_error = torch.div(difference, fp32_output) * 100 + + max_error = torch.max(difference).item() + max_absolute_error = torch.max(torch.abs(difference)).item() + max_percentage_error = torch.max(percentage_error).item() + mean_absolute_error = torch.mean(torch.abs(difference).float()).item() + + return max_error, max_absolute_error, max_percentage_error, mean_absolute_error + + def get_compression_ratio(self) -> float: + """Compute the compression ratio of the outputted TOSA flatbuffer.""" + with tempfile.NamedTemporaryFile(delete=True, suffix=".zip") as temp_zip: + with zipfile.ZipFile( + temp_zip.name, "w", compression=zipfile.ZIP_DEFLATED + ) as f: + f.write(self.tosa_output_path) + + compression_ratio = os.path.getsize( + self.tosa_output_path + ) / os.path.getsize(temp_zip.name) + + return compression_ratio + + def evaluate(self) -> dict[any]: + max_error, max_absolute_error, max_percent_error, mean_absolute_error = ( + self.get_model_error() + ) + output_metrics = { + "name": self.model_name, + "metrics": { + "max_error": max_error, + "max_absolute_error": max_absolute_error, + "max_percentage_error": max_percent_error, + "mean_absolute_error": mean_absolute_error, + }, + } + + if self.tosa_output_path: + output_metrics["metrics"][ + "compression_ratio" + ] = self.get_compression_ratio() + + return output_metrics diff --git a/backends/qualcomm/README.md b/backends/qualcomm/README.md index d426aff16b..a0cb5a5a50 100644 --- a/backends/qualcomm/README.md +++ b/backends/qualcomm/README.md @@ -21,17 +21,8 @@ Please check `generate_qnn_executorch_compiler_spec()` in - Snapdragon 8 Gen 2 - Snapdragon 8 Gen 3 -### How to add more supported Chipset - -#### Step 1: Check SoC model of snapdragon device -Get SoC model which would like to be supported from the document of Qualcomm AI Engine Direct SDK. - -#### Step 2: Update schema of compiler option and SoC information in serialization -Add SoC model into QcomChipset enum in [schema](./serialization/schema.fbs) and [qnn_compile_spec_schema](./serialization/qnn_compile_spec_schema.py). -Insert new SoC information into _soc_info_table in [qnn_compile_spec_schema](./serialization/qnn_compile_spec_schema.py). - -#### Step 3: Recompile the .pte file -Follow [setup](../../docs/source/build-run-qualcomm-ai-engine-direct-backend.md) to setup environment and build runtime with new schema header. +### Adding more supported Chipset +Currently, users cannot add additional chipset models because the chipset ID is not accessible to community users. If you have specific chipset models you wish to add, please contact one of the authors in the `Code Reviews` section at the bottom of this page. ### Supported Inference Type - Quantized diff --git a/backends/qualcomm/runtime/targets.bzl b/backends/qualcomm/runtime/targets.bzl index f3b868892f..8c921b96ec 100644 --- a/backends/qualcomm/runtime/targets.bzl +++ b/backends/qualcomm/runtime/targets.bzl @@ -3,6 +3,7 @@ load( "ANDROID", ) load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbsource//xplat/executorch/backends/qualcomm/qnn_version.bzl", "get_qnn_library_verision") def define_common_targets(): """Defines targets that should be shared between fbcode and xplat. @@ -56,10 +57,10 @@ def define_common_targets(): platforms = [ANDROID], visibility = ["@EXECUTORCH_CLIENTS"], resources = { - "qnn_lib": "fbsource//third-party/qualcomm/qnn/qnn-2.25:qnn_offline_compile_libs", + "qnn_lib": "fbsource//third-party/qualcomm/qnn/qnn-{0}:qnn_offline_compile_libs".format(get_qnn_library_verision()), }, deps = [ - "fbsource//third-party/qualcomm/qnn:api", + "fbsource//third-party/qualcomm/qnn/qnn-{0}:api".format(get_qnn_library_verision()), ":logging", "//executorch/backends/qualcomm:schema", "//executorch/backends/qualcomm/aot/ir:qcir_utils", diff --git a/backends/qualcomm/targets.bzl b/backends/qualcomm/targets.bzl index 1435d41f8d..a201274402 100644 --- a/backends/qualcomm/targets.bzl +++ b/backends/qualcomm/targets.bzl @@ -92,3 +92,6 @@ def define_common_targets(): ":schema", ], ) + +def get_qnn_library_verision(): + return "2.26" diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 2e1bd0eff3..01b1014e4c 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -68,7 +68,7 @@ def setUp(self): TestQNN.rtol = 1e-1 backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, debug=False, saver=False, @@ -522,7 +522,7 @@ def setUp(self): TestQNN.rtol = 1e-1 backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, debug=False, saver=False, @@ -674,7 +674,7 @@ def setUp(self): TestQNN.rtol = 1 backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, debug=False, saver=False, @@ -1236,7 +1236,7 @@ def setUp(self): TestQNN.rtol = 1 backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, debug=False, saver=False, @@ -1444,7 +1444,7 @@ def setUp(self): TestQNN.rtol = 1e-1 backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, debug=False, saver=False, @@ -1453,7 +1453,7 @@ def setUp(self): def test_qnn_backend_dump_intermediate_outputs(self): backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, dump_intermediate_outputs=True, ) @@ -1498,7 +1498,7 @@ def test_qnn_backend_multi_contexts(self): use_multi_contexts=True, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) partitioner = QnnPartitioner(compiler_specs) @@ -1514,7 +1514,7 @@ def test_qnn_backend_multi_contexts_composite(self): use_multi_contexts=True, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) module = CompositeDelegateModule( # noqa: F405 @@ -1535,7 +1535,7 @@ def test_qnn_backend_profile_op(self): TestQNN.enable_profile = True backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, profile=True, ) @@ -1554,7 +1554,7 @@ def test_qnn_backend_shared_buffer(self): use_fp16=True, ) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, shared_buffer=True, ) @@ -1569,7 +1569,7 @@ def test_qnn_backend_shared_buffer(self): def test_qnn_backend_online_prepare(self): backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, online_prepare=True, ) @@ -1590,7 +1590,7 @@ def test_qnn_backend_context_direct(self): bundle_program = from_context_binary(ctx_path, "ctx_loader") backend_options = generate_htp_compiler_spec(use_fp16=True) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, is_from_context_binary=True, ) @@ -1614,7 +1614,7 @@ def setUp(self): TestQNN.rtol = 1 backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, debug=False, saver=False, @@ -1623,7 +1623,7 @@ def setUp(self): def test_qnn_backend_dump_intermediate_outputs(self): backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, dump_intermediate_outputs=True, ) @@ -1657,7 +1657,7 @@ def test_qnn_backend_skip_node_id_quantizer(self): use_fp16=False, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) partitioner = QnnPartitioner(compiler_specs) @@ -1704,7 +1704,7 @@ def test_qnn_backend_skip_node_op_quantizer(self): use_fp16=False, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) partitioner = QnnPartitioner(compiler_specs) @@ -1740,7 +1740,7 @@ def test_qnn_backend_graph_level_mixed_precision(self): use_fp16=False, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) partitioner = QnnPartitioner(compiler_specs) @@ -1781,7 +1781,7 @@ def test_qnn_backend_multi_contexts(self): use_multi_contexts=True, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) partitioner = QnnPartitioner(compiler_specs) @@ -1797,7 +1797,7 @@ def test_qnn_backend_multi_contexts_composite(self): use_multi_contexts=True, ) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) module = CompositeDelegateModule( # noqa: F405 @@ -1819,7 +1819,7 @@ def test_qnn_backend_profile_op(self): TestQNN.enable_profile = True backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, profile=True, ) @@ -1839,7 +1839,7 @@ def test_qnn_backend_shared_buffer(self): use_fp16=False, ) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, shared_buffer=True, ) @@ -1855,7 +1855,7 @@ def test_qnn_backend_shared_buffer(self): def test_qnn_backend_online_prepare(self): backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, online_prepare=True, ) @@ -1877,7 +1877,7 @@ def test_qnn_backend_context_direct(self): bundle_program = from_context_binary(ctx_path, "ctx_loader") backend_options = generate_htp_compiler_spec(use_fp16=False) compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=self.arch_table[TestQNN.model], + soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, is_from_context_binary=True, ) diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index c58da42e84..4e5ed04bc5 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -120,7 +120,7 @@ class TestQNN(unittest.TestCase): build_folder: Literal = "" model: QcomChipset = None compiler_specs: List[CompileSpec] = None - arch_table = get_soc_to_chipset_map() + chipset_table = get_soc_to_chipset_map() error_only = False ip = "localhost" port = 8080 diff --git a/backends/vulkan/README.md b/backends/vulkan/README.md index f520e9891b..b428333c91 100644 --- a/backends/vulkan/README.md +++ b/backends/vulkan/README.md @@ -150,7 +150,7 @@ when building with CMake. First, make sure that you have the Android NDK installed; any NDK version past NDK r19c should work. Note that the examples in this doc have been validated with -NDK r25. The Android SDK should also be installed so that you have access to `adb`. +NDK r27b. The Android SDK should also be installed so that you have access to `adb`. The instructions in this page assumes that the following environment variables are set. diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 2bfd8b44a7..109a61049d 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -144,9 +144,24 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool: return False - def is_valid_to_copy(self, node: torch.fx.node) -> bool: # pyre-ignore[11] - # lower only if floating point dtype conversion - return len(node.args) > 1 and node.args[1] in (torch.float32, torch.float16) + def is_valid_to_copy(self, node: torch.fx.Node) -> bool: + float_dtypes = [torch.float16, torch.float32] + + if len(node.args) != 1: + return False + + in_arg = node.args[0] + if not isinstance(in_arg, torch.fx.Node): + return False + + in_tensor = in_arg.meta.get("val", None) + out_tensor = node.meta.get("val", None) + + if isinstance(in_tensor, FakeTensor) and isinstance(out_tensor, FakeTensor): + if out_tensor.dtype in float_dtypes and in_tensor.dtype in float_dtypes: + return True + + return False def is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node @@ -174,13 +189,13 @@ def _is_node_supported( if target not in VulkanSupportedOperators._ops: return False - features = VulkanSupportedOperators._ops[target] - if target == exir_ops.edge.aten._to_copy.default and not self.is_valid_to_copy( node ): return False + features = VulkanSupportedOperators._ops[target] + if self.require_dynamic_shapes and not features.supports_dynamic_shape: return False diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 559c70bd8b..19737f5718 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -600,7 +600,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { } #ifdef ET_EVENT_TRACER_ENABLED - EventTracer* event_tracer = context.event_tracer(); + runtime::EventTracer* event_tracer = context.event_tracer(); compute_graph->context()->querypool().extract_results(); for (const auto& tup : compute_graph->context()->querypool().get_shader_timestamp_data()) { diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 206a4eafa3..838605f05f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -71,6 +71,17 @@ void add_q_8w_linear_node( const ValueRef q_mat2_data, const ValueRef scales_data, const ValueRef out) { + auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); + ValueRef mat1_W_packed = mat1; + ValueRef out_W_packed = out; + if (!graph.is_buffer_storage(out) && + graph.packed_dim_of(mat1) != WHCN::kWidthDim) { + // Ensure mat1 is width packed + mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked); + viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); + // Ensure out is packed correctly + out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked); + } ValueRef q_mat2 = prepack_if_tensor_ref(graph, q_mat2_data, utils::kWidthPacked); ValueRef scales = @@ -78,39 +89,45 @@ void add_q_8w_linear_node( std::string kernel_name = "q_8w_linear"; kernel_name.reserve(kShaderNameReserve); - add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1)); + add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed)); add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2)); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed)); vkapi::ParamsBindList ubos({}); - if (graph.is_buffer_storage(out)) { + if (graph.is_buffer_storage(out_W_packed)) { ubos.append( - {graph.sizes_ubo(out), - graph.strides_ubo(out), - graph.numel_ubo(out), - graph.sizes_ubo(mat1), + {graph.sizes_ubo(out_W_packed), + graph.strides_ubo(out_W_packed), + graph.numel_ubo(out_W_packed), + graph.sizes_ubo(mat1_W_packed), graph.strides_ubo(mat1), graph.strides_ubo(q_mat2), graph.strides_ubo(scales)}); } else { - ubos.append({graph.logical_limits_ubo(out), graph.sizes_ubo(mat1)}); + ubos.append( + {graph.logical_limits_ubo(out_W_packed), + graph.sizes_ubo(mat1_W_packed)}); } graph.execute_nodes().emplace_back(new DispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + graph.create_global_wg_size(out_W_packed), + graph.create_local_wg_size(out_W_packed), // Inputs and Outputs - {{out, vkapi::MemoryAccessType::WRITE}, - {{mat1, q_mat2, scales}, vkapi::MemoryAccessType::READ}}, + {{out_W_packed, vkapi::MemoryAccessType::WRITE}, + {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}}, // Shader params buffers ubos, // Specialization Constants {}, // Resizing Logic resize_qlinear_node)); + if (!graph.is_buffer_storage(out) && + graph.packed_dim_of(out) != WHCN::kWidthDim) { + viewFn(graph, {out_W_packed, graph.add_none(), out}); + } } void weight_int8pack_mm( diff --git a/build/build_android_llm_demo.sh b/build/build_android_llm_demo.sh index 681c442afc..f015b08e61 100644 --- a/build/build_android_llm_demo.sh +++ b/build/build_android_llm_demo.sh @@ -26,7 +26,6 @@ build_android_native_library() { EXECUTORCH_BUILD_QNN=OFF fi - cmake . -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \ -DANDROID_ABI="${ANDROID_ABI}" \ @@ -150,19 +149,27 @@ collect_artifacts_to_be_uploaded() { cp extension/benchmark/android/benchmark/app/build/outputs/apk/androidTest/debug/*.apk "${MINIBENCH_APP_DIR}" } -BUILD_AAR_DIR="$(mktemp -d)" -export BUILD_AAR_DIR -if [ -z "$ANDROID_ABIS" ]; then - ANDROID_ABIS=("arm64-v8a" "x86_64") -fi -export ANDROID_ABIS +main() { + BUILD_AAR_DIR="$(mktemp -d)" + export BUILD_AAR_DIR + if [ -z "$ANDROID_ABIS" ]; then + ANDROID_ABIS=("arm64-v8a" "x86_64") + fi + export ANDROID_ABIS + + ARTIFACTS_DIR_NAME="$1" -ARTIFACTS_DIR_NAME="$1" + build_jar + for ANDROID_ABI in "${ANDROID_ABIS[@]}"; do + build_android_native_library ${ANDROID_ABI} + done + build_aar + build_android_demo_apps + if [ -n "$ARTIFACTS_DIR_NAME" ]; then + collect_artifacts_to_be_uploaded ${ARTIFACTS_DIR_NAME} + fi +} -build_jar -for ANDROID_ABI in "${ANDROID_ABIS[@]}"; do - build_android_native_library ${ANDROID_ABI} -done -build_aar -build_android_demo_apps -collect_artifacts_to_be_uploaded ${ARTIFACTS_DIR_NAME} +if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then + main "$@" +fi diff --git a/docs/source/export-to-executorch-api-reference.rst b/docs/source/export-to-executorch-api-reference.rst index 5560e75e21..1ae563d842 100644 --- a/docs/source/export-to-executorch-api-reference.rst +++ b/docs/source/export-to-executorch-api-reference.rst @@ -6,6 +6,9 @@ For detailed information on how APIs evolve and the deprecation process, please .. automodule:: executorch.exir .. autofunction:: to_edge +.. automodule:: executorch.exir +.. autofunction:: to_edge_transform_and_lower + .. autoclass:: EdgeProgramManager :members: methods, config_methods, exported_program, transform, to_backend, to_executorch diff --git a/docs/source/getting-started-setup.md b/docs/source/getting-started-setup.md index e8a2b4530d..f2f98d6563 100644 --- a/docs/source/getting-started-setup.md +++ b/docs/source/getting-started-setup.md @@ -80,7 +80,9 @@ Alternatively, if you would like to experiment with ExecuTorch quickly and easil ```bash # Clone the ExecuTorch repo from GitHub - git clone https://github.com/pytorch/executorch.git + # 'main' branch is the primary development branch where you see the latest changes. + # 'viable/strict' contains all of the commits on main that pass all of the necessary CI checks. + git clone --branch viable/strict https://github.com/pytorch/executorch.git cd executorch # Update and pull submodules diff --git a/docs/source/index.rst b/docs/source/index.rst index 1e1060f70b..095489de35 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -4,11 +4,12 @@ Welcome to the ExecuTorch Documentation ======================================= .. important:: - This is an alpha release; the ExecuTorch APIs and the ``.pte`` binary format - may change in incompatible ways before stabilizing in a future beta release. - When deploying models, we currently recommend using a version of the runtime - built from the same git revision that was used to generate the ``.pte`` file. - Once the format has stabilized, this will no longer be necessary. + This is a beta release. As of this ExecuTorch beta release, the API + will follow the `lifecycle and deprecation policy `__ + and ``.pte`` binary format will comply with the *runtime compatibility policy* (TODO: add link). + This ensures that application developers can update to the latest version of ExecuTorch + without breaking existing integration code, in accordance with these policies. + If any issues arise or compatibility breaks occur, please `report them in GitHub `__. We welcome any feedback, suggestions, and bug reports from the community to help us improve the technology. Please use the `PyTorch Forums @@ -55,13 +56,12 @@ Topics in this section will help you get started with ExecuTorch. ExecuTorch. .. grid-item-card:: :octicon:`file-code;1em` - ExecuTorch Intermediate Representation API + ExecuTorch Llama :img-top: _static/img/card-background.svg - :link: ir-exir.html + :link: llm/llama.html :link-type: url - Learn about EXIR, a graph-based intermediate - representation (IR) of PyTorch programs. + Learn about running Llama models via ExecuTorch .. toctree:: :glob: @@ -117,10 +117,11 @@ Topics in this section will help you get started with ExecuTorch. :caption: Working with LLMs :hidden: - llm/getting-started - llm/llama-demo-android - llm/build-run-llama3-qualcomm-ai-engine-direct-backend - llm/llama-demo-ios + Llama + Llama on Android + Llama on iOS + Llama on Android via Qualcomm backend + Intro to LLMs in Executorch .. toctree:: :glob: diff --git a/docs/source/llm/getting-started.md b/docs/source/llm/getting-started.md index 2cdf13ca65..4cfebbf9e6 100644 --- a/docs/source/llm/getting-started.md +++ b/docs/source/llm/getting-started.md @@ -1,4 +1,4 @@ -# Getting Started with LLMs via ExecuTorch +# Intro to LLMs in Executorch Welcome to LLM Manual! This manual is designed to provide a practical example to leverage ExecuTorch in onboarding your own Large Language Models (LLMs). Our primary goal is to offer @@ -13,6 +13,8 @@ We encourage users to use this project as a starting point and adapt it to their which includes creating your own versions of the tokenizer, sampler, acceleration backends, and other components. We hope this project serves as a useful guide in your journey with LLMs and ExecuTorch. +For deploying Llama with optimal performance, please see [Llama guide](./llama.md). + ### Table Of Contents diff --git a/docs/source/llm/llama.md b/docs/source/llm/llama.md new file mode 100644 index 0000000000..2d266ba7ae --- /dev/null +++ b/docs/source/llm/llama.md @@ -0,0 +1,5 @@ +# Llama on ExecuTorch + +See +[Llama readme](https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md) +for detailed information about running Llama on ExecuTorch. diff --git a/docs/source/tutorial-xnnpack-delegate-lowering.md b/docs/source/tutorial-xnnpack-delegate-lowering.md index 666ee23aa3..4f0ba3bd1a 100644 --- a/docs/source/tutorial-xnnpack-delegate-lowering.md +++ b/docs/source/tutorial-xnnpack-delegate-lowering.md @@ -25,7 +25,7 @@ import torchvision.models as models from torch.export import export, ExportedProgram from torchvision.models.mobilenetv2 import MobileNet_V2_Weights from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner -from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge +from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge_transform_and_lower from executorch.exir.backend.backend_api import to_backend @@ -33,9 +33,10 @@ mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFA sample_inputs = (torch.randn(1, 3, 224, 224), ) exported_program: ExportedProgram = export(mobilenet_v2, sample_inputs) -edge: EdgeProgramManager = to_edge(exported_program) - -edge = edge.to_backend(XnnpackPartitioner()) +edge: EdgeProgramManager = to_edge_transform_and_lower( + exported_program, + partitioner=[XnnpackPartitioner()], +) ``` We will go through this example with the [MobileNetV2](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/) pretrained model downloaded from the TorchVision library. The flow of lowering a model starts after exporting the model `to_edge`. We call the `to_backend` api with the `XnnpackPartitioner`. The partitioner identifies the subgraphs suitable for XNNPACK backend delegate to consume. Afterwards, the identified subgraphs will be serialized with the XNNPACK Delegate flatbuffer schema and each subgraph will be replaced with a call to the XNNPACK Delegate. @@ -47,16 +48,18 @@ GraphModule( (lowered_module_1): LoweredBackendModule() ) -def forward(self, arg314_1): + + +def forward(self, b_features_0_1_num_batches_tracked, ..., x): lowered_module_0 = self.lowered_module_0 - executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, arg314_1); lowered_module_0 = arg314_1 = None - getitem = executorch_call_delegate[0]; executorch_call_delegate = None - aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem, [1, 1280]); getitem = None - aten_clone_default = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default); aten_view_copy_default = None lowered_module_1 = self.lowered_module_1 - executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, aten_clone_default); lowered_module_1 = aten_clone_default = None - getitem_1 = executorch_call_delegate_1[0]; executorch_call_delegate_1 = None - return (getitem_1,) + executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, x); lowered_module_1 = x = None + getitem_53 = executorch_call_delegate_1[0]; executorch_call_delegate_1 = None + aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem_53, [1, 1280]); getitem_53 = None + aten_clone_default = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default); aten_view_copy_default = None + executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, aten_clone_default); lowered_module_0 = aten_clone_default = None + getitem_52 = executorch_call_delegate[0]; executorch_call_delegate = None + return (getitem_52,) ``` We print the graph after lowering above to show the new nodes that were inserted to call the XNNPACK Delegate. The subgraphs which are being delegated to XNNPACK are the first argument at each call site. It can be observed that the majority of `convolution-relu-add` blocks and `linear` blocks were able to be delegated to XNNPACK. We can also see the operators which were not able to be lowered to the XNNPACK delegate, such as `clone` and `view_copy`. @@ -75,7 +78,7 @@ The XNNPACK delegate can also execute symmetrically quantized models. To underst ```python from torch.export import export_for_training -from executorch.exir import EdgeCompileConfig +from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() sample_inputs = (torch.randn(1, 3, 224, 224), ) @@ -111,9 +114,11 @@ Quantization requires a two stage export. First we use the `export_for_training` ```python # Continued from earlier... -edge = to_edge(export(quantized_mobilenetv2, sample_inputs), compile_config=EdgeCompileConfig(_check_ir_validity=False)) - -edge = edge.to_backend(XnnpackPartitioner()) +edge = to_edge_transform_and_lower( + export(quantized_mobilenetv2, sample_inputs), + compile_config=EdgeCompileConfig(_check_ir_validity=False), + partitioner=[XnnpackPartitioner()] +) exec_prog = edge.to_executorch() diff --git a/examples/apple/coreml/scripts/export.py b/examples/apple/coreml/scripts/export.py index 8aecfdabec..a83f1695f6 100644 --- a/examples/apple/coreml/scripts/export.py +++ b/examples/apple/coreml/scripts/export.py @@ -14,8 +14,10 @@ import torch +# pyre-fixme[21]: Could not find module `executorch.backends.apple.coreml.compiler`. from executorch.backends.apple.coreml.compiler import CoreMLBackend +# pyre-fixme[21]: Could not find module `executorch.backends.apple.coreml.partition`. from executorch.backends.apple.coreml.partition import CoreMLPartitioner from executorch.devtools.etrecord import generate_etrecord from executorch.exir import to_edge @@ -76,6 +78,7 @@ def parse_args() -> argparse.ArgumentParser: parser.add_argument("--save_processed_bytes", action=argparse.BooleanOptionalAction) args = parser.parse_args() + # pyre-fixme[7]: Expected `ArgumentParser` but got `Namespace`. return args diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 2f06076aad..29cc0c30c7 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -8,18 +8,21 @@ # Example script for exporting simple models to flatbuffer import argparse +import json import logging import os -from typing import Optional -import torch +from pathlib import Path +from typing import Optional, Tuple +import torch from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.arm_partitioner import ArmPartitioner from executorch.backends.arm.quantizer.arm_quantizer import ( ArmQuantizer, get_symmetric_quantization_config, ) +from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator from executorch.devtools.backend_debug import get_delegation_info from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig @@ -151,6 +154,8 @@ def forward(self, x): "softmax": SoftmaxModule, } +evaluators = {} + targets = [ "ethos-u55-32", "ethos-u55-64", @@ -202,6 +207,37 @@ def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder: return spec_builder.build() +def get_evaluator(model_name: str) -> GenericModelEvaluator: + if model_name not in evaluators: + return GenericModelEvaluator + else: + return evaluators[model_name] + + +def evaluate_model( + model_name: str, + intermediates: str, + model_fp32: torch.nn.Module, + model_int8: torch.nn.Module, + example_inputs: Tuple[torch.Tensor], +): + evaluator = get_evaluator(model_name) + + # Get the path of the TOSA flatbuffer that is dumped + intermediates_path = Path(intermediates) + tosa_paths = list(intermediates_path.glob("*.tosa")) + + init_evaluator = evaluator( + model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0]) + ) + + quant_metrics = init_evaluator.evaluate() + output_json_path = intermediates_path / "quant_metrics.json" + + with output_json_path.open("w") as json_file: + json.dump(quant_metrics, json_file) + + def dump_delegation_info(edge, intermediate_files_folder: Optional[str] = None): graph_module = edge.exported_program().graph_module delegation_info = get_delegation_info(graph_module) @@ -242,6 +278,14 @@ def get_args(): choices=targets, help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are {targets}", ) + parser.add_argument( + "-e", + "--evaluate", + action="store_true", + required=False, + default=False, + help="Flag for running evaluation of the model.", + ) parser.add_argument( "-q", "--quantize", @@ -275,11 +319,11 @@ def get_args(): help="Location for outputs, if not the default of cwd.", ) args = parser.parse_args() - return args - -if __name__ == "__main__": - args = get_args() + if args.evaluate and (args.quantize is None or args.intermediates is None): + raise RuntimeError( + "--evaluate requires --quantize and --intermediates to be enabled." + ) if args.debug: logging.basicConfig(level=logging.DEBUG, format=FORMAT, force=True) @@ -302,16 +346,26 @@ def get_args(): ): raise RuntimeError(f"Model {args.model_name} cannot be delegated.") + return args + + +if __name__ == "__main__": + args = get_args() + # Pick model from one of the supported lists model, example_inputs = get_model_and_inputs_from_name(args.model_name) model = model.eval() + model_fp32 = model + # pre-autograd export. eventually this will become torch.export model = torch.export.export_for_training(model, example_inputs).module() # Quantize if required + model_int8 = None if args.quantize: model = quantize(model, example_inputs) + model_int8 = model edge = export_to_edge( model, @@ -361,3 +415,8 @@ def get_args(): output_name = os.path.join(args.output, output_name) save_pte_program(exec_prog, output_name) + + if args.evaluate: + evaluate_model( + args.model_name, args.intermediates, model_fp32, model_int8, example_inputs + ) diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index ef4a10289b..ae335208cd 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -88,7 +88,7 @@ ethos_u_base_rev="24.08" # tosa reference model tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model" -tosa_reference_model_rev="444eb365d92774430006e56a8c20161be2f2674f" +tosa_reference_model_rev="f9ea4ab7da19318fe36b1c34d68a3e40fd6e56c5" ######## ### Mandatory user args diff --git a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/AndroidManifest.xml b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/AndroidManifest.xml index 9958047f4a..4c16e3a994 100644 --- a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/AndroidManifest.xml +++ b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/AndroidManifest.xml @@ -10,7 +10,7 @@ { + try { + InputMethodManager imm = (InputMethodManager) getSystemService(INPUT_METHOD_SERVICE); + imm.hideSoftInputFromWindow(getCurrentFocus().getWindowToken(), 0); + } catch (Exception e) { + ETLogging.getInstance().log("Keyboard dismissal error: " + e.getMessage()); + } addSelectedImagesToChatThread(mSelectedImageUri); String finalPrompt; String rawPrompt = mEditTextMessage.getText().toString(); diff --git a/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md b/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md index 5f850e31f3..54bf956176 100644 --- a/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md +++ b/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md @@ -6,13 +6,13 @@ More specifically, it covers: 2. Building and linking libraries that are required to inference on-device for Android platform using Qualcomm AI accelerators. 3. Building the Android demo app itself. -Verified on Linux CentOS, QNN SDK [v2.26](https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.26.0.240828.zip), python 3.10, Android SDK r26c. +Verified on Linux CentOS, QNN SDK [v2.26](https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.26.0.240828.zip), python 3.10, Android SDK r27b. Phone verified: OnePlus 12, Samsung 24+, Samsung 23 ## Prerequisites * Download and unzip QNN SDK [v2.26](https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.26.0.240828.zip) -* Download and unzip Android SDK [r26](https://developer.android.com/ndk/downloads) +* Download and unzip Android SDK [r27b](https://developer.android.com/ndk/downloads) * Android phone with Snapdragon8 Gen3 (SM8650) or Gen2 (SM8550). Gen 1 and lower SoC might be supported but not fully validated. * Desired Llama model weights in .PTH format. You can download them on HuggingFace ([Example](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)). @@ -40,7 +40,7 @@ Install dependencies ## Setup QNN ``` # Set these variables correctly for your environment -export ANDROID_NDK_ROOT=$HOME/android-ndk-r26 # Download android SDK and unzip to home directory +export ANDROID_NDK_ROOT=$HOME/android-ndk-r27b # Download android SDK and unzip to home directory export QNN_SDK_ROOT=$HOME/Your-SDK-Root #Folder contains lib export EXECUTORCH_ROOT=$HOME/repos/executorch export LD_LIBRARY_PATH=$QNN_SDK_ROOT/lib/x86_64-linux-clang/:$LD_LIBRARY_PATH diff --git a/examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh b/examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh index df70725942..53d248bc75 100644 --- a/examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh +++ b/examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh @@ -7,66 +7,19 @@ set -eu -CMAKE_OUT="${CMAKE_OUT:-cmake-out-android}" -# Note: Set up ANDROID_NDK and ANDROID_ABI -cmake . -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ - -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \ - -DANDROID_ABI="${ANDROID_ABI}" \ - -DEXECUTORCH_BUILD_XNNPACK=ON \ - -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ - -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ - -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ - -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ - -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ - -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ - -DEXECUTORCH_BUILD_QNN=ON \ - -DQNN_SDK_ROOT="${QNN_SDK_ROOT}" \ - -DCMAKE_BUILD_TYPE=Release \ - -B"${CMAKE_OUT}" - -if [ "$(uname)" == "Darwin" ]; then - CMAKE_JOBS=$(( $(sysctl -n hw.ncpu) - 1 )) -else - CMAKE_JOBS=$(( $(nproc) - 1 )) +if [ -z "$QNN_SDK_ROOT" ]; then + echo "You must specify QNN_SDK_ROOT" + exit 1 fi -cmake --build "${CMAKE_OUT}" -j "${CMAKE_JOBS}" --target install --config Release - -cmake extension/android \ - -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ - -DANDROID_ABI="${ANDROID_ABI}" \ - -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ - -DEXECUTORCH_BUILD_LLAMA_JNI=ON \ - -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ - -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ - -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ - -DCMAKE_BUILD_TYPE=Release \ - -B"${CMAKE_OUT}"/extension/android -cmake --build "${CMAKE_OUT}"/extension/android -j "${CMAKE_JOBS}" --config Release - -JNI_LIBS_PATH="examples/demo-apps/android/LlamaDemo/app/src/main/jniLibs" -mkdir -p "${JNI_LIBS_PATH}/${ANDROID_ABI}" +BASEDIR=$(dirname "$0") +source "$BASEDIR"/../../../../build/build_android_llm_demo.sh BUILD_AAR_DIR="$(mktemp -d)" -mkdir -p "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}" "${BUILD_AAR_DIR}/libs" -JNI_LIBS_PATH="${BUILD_AAR_DIR}/jni" -cp "${CMAKE_OUT}"/extension/android/libexecutorch_jni.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/libexecutorch.so" -cp "${CMAKE_OUT}"/lib/libqnn_executorch_backend.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" -cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtp.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" -cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnSystem.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" -cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtpV69Stub.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" -cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtpV73Stub.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" -cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtpV75Stub.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" -cp "${QNN_SDK_ROOT}"/lib/hexagon-v69/unsigned/libQnnHtpV69Skel.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" -cp "${QNN_SDK_ROOT}"/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" -cp "${QNN_SDK_ROOT}"/lib/hexagon-v75/unsigned/libQnnHtpV75Skel.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" -cp extension/android/build/libs/executorch.jar "${BUILD_AAR_DIR}/libs" -echo \ \ - \ \ - \ > "${BUILD_AAR_DIR}/AndroidManifest.xml" -pushd "${BUILD_AAR_DIR}" -zip -r executorch-llama.aar libs jni/${ANDROID_ABI} AndroidManifest.xml -popd -mkdir -p examples/demo-apps/android/LlamaDemo/app/libs -mv "${BUILD_AAR_DIR}/executorch-llama.aar" examples/demo-apps/android/LlamaDemo/app/libs +export BUILD_AAR_DIR + +build_jar +build_android_native_library "arm64-v8a" +build_aar +mkdir -p "$BASEDIR"/app/libs +cp "$BUILD_AAR_DIR/executorch.aar" "$BASEDIR"/app/libs/executorch-llama.aar diff --git a/examples/demo-apps/android/LlamaDemo/setup.sh b/examples/demo-apps/android/LlamaDemo/setup.sh index b89c182994..140434a86d 100644 --- a/examples/demo-apps/android/LlamaDemo/setup.sh +++ b/examples/demo-apps/android/LlamaDemo/setup.sh @@ -7,53 +7,15 @@ set -eu -CMAKE_OUT="${CMAKE_OUT:-cmake-out-android}" -# Note: Set up ANDROID_NDK and ANDROID_ABI -cmake . -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ - -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \ - -DANDROID_ABI="${ANDROID_ABI}" \ - -DANDROID_PLATFORM=android-23 \ - -DEXECUTORCH_BUILD_XNNPACK=ON \ - -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ - -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ - -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ - -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ - -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ - -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ - -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ - -DCMAKE_BUILD_TYPE=Release \ - -B"${CMAKE_OUT}" - -if [ "$(uname)" == "Darwin" ]; then - CMAKE_JOBS=$(( $(sysctl -n hw.ncpu) - 1 )) -else - CMAKE_JOBS=$(( $(nproc) - 1 )) -fi -cmake --build "${CMAKE_OUT}" -j "${CMAKE_JOBS}" --target install --config Release - -cmake extension/android \ - -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ - -DANDROID_ABI="${ANDROID_ABI}" \ - -DANDROID_PLATFORM=android-23 \ - -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ - -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ - -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ - -DEXECUTORCH_BUILD_LLAMA_JNI=ON \ - -DCMAKE_BUILD_TYPE=Release \ - -B"${CMAKE_OUT}"/extension/android - -cmake --build "${CMAKE_OUT}"/extension/android -j "${CMAKE_JOBS}" --config Release +BASEDIR=$(dirname "$0") +source "$BASEDIR"/../../../../build/build_android_llm_demo.sh BUILD_AAR_DIR="$(mktemp -d)" -mkdir -p "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}" "${BUILD_AAR_DIR}/libs" -cp "${CMAKE_OUT}"/extension/android/libexecutorch_jni.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/libexecutorch.so" -cp extension/android/build/libs/executorch.jar "${BUILD_AAR_DIR}/libs" -echo \ \ - \ \ - \ > "${BUILD_AAR_DIR}/AndroidManifest.xml" -pushd "${BUILD_AAR_DIR}" -zip -r executorch-llama.aar libs jni/${ANDROID_ABI} AndroidManifest.xml -popd -mkdir -p examples/demo-apps/android/LlamaDemo/app/libs -mv "${BUILD_AAR_DIR}/executorch-llama.aar" examples/demo-apps/android/LlamaDemo/app/libs +export BUILD_AAR_DIR + +build_jar +build_android_native_library "arm64-v8a" +build_android_native_library "x86_64" +build_aar +mkdir -p "$BASEDIR"/app/libs +cp "$BUILD_AAR_DIR/executorch.aar" "$BASEDIR"/app/libs/executorch-llama.aar diff --git a/examples/models/checkpoint.py b/examples/models/checkpoint.py index cd916142d9..592ebab145 100644 --- a/examples/models/checkpoint.py +++ b/examples/models/checkpoint.py @@ -67,7 +67,7 @@ def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]: if value.dtype != dtype ] if len(mismatched_dtypes) > 0: - raise ValueError( + print( f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}" ) return dtype diff --git a/examples/models/llama2/README.md b/examples/models/llama2/README.md index 1a6fe99fc4..2260c8f825 100644 --- a/examples/models/llama2/README.md +++ b/examples/models/llama2/README.md @@ -356,7 +356,7 @@ To build for CoreML backend and validate on Mac, replace `-DEXECUTORCH_BUILD_XNN **1. Build llama runner binary for Android** -*Pre-requisite*: Android NDK (tested with r26c) which can be downloaded from [here](https://developer.android.com/ndk/downloads). Note that the mac binary can be unpackaged and you can locate NDK folder from it. +*Pre-requisite*: Android NDK (tested with r27b) which can be downloaded from [here](https://developer.android.com/ndk/downloads). Note that the mac binary can be unpackaged and you can locate NDK folder from it. **1.1 Set Android NDK** ``` diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS index 13bf0a52c2..17597f3d50 100644 --- a/examples/models/llama2/TARGETS +++ b/examples/models/llama2/TARGETS @@ -3,6 +3,7 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load(":targets.bzl", "define_common_targets") +load("@fbsource//xplat/executorch/backends/qualcomm/qnn_version.bzl", "get_qnn_library_verision") oncall("executorch") @@ -69,7 +70,7 @@ runtime.python_binary( runtime.command_alias( name = "export_llama_qnn", env = { - "LD_LIBRARY_PATH": "$(location fbsource//third-party/qualcomm/qnn/qnn-2.25:qnn_offline_compile_libs)", + "LD_LIBRARY_PATH": "$(location fbsource//third-party/qualcomm/qnn/qnn-{0}:qnn_offline_compile_libs)".format(get_qnn_library_verision()), }, exe = ":export_llama", ) diff --git a/examples/models/llama2/eval_llama_lib.py b/examples/models/llama2/eval_llama_lib.py index 57258fdbc1..95b3ff0fb7 100644 --- a/examples/models/llama2/eval_llama_lib.py +++ b/examples/models/llama2/eval_llama_lib.py @@ -194,7 +194,7 @@ def gen_eval_wrapper( manager: LLMEdgeManager = _prepare_for_llama_export(model_name, args) if len(quantizers) != 0: - manager = manager.capture_pre_autograd_graph().pt2e_quantize(quantizers) + manager = manager.export().pt2e_quantize(quantizers) model = ( manager.pre_autograd_graph_module.to(device="cuda") # pyre-ignore if torch.cuda.is_available() @@ -209,7 +209,7 @@ def gen_eval_wrapper( ) else: # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch - # for quantizers. Currently capture_pre_autograd_graph only works with --kv_cache, but + # for quantizers. Currently export_for_training only works with --kv_cache, but # fails without the kv_cache mode model = ( manager.model.eval().to(device="cuda") diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index cdac837fd2..8cff6e8e11 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -581,7 +581,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 # export_to_edge builder_exported_to_edge = ( _prepare_for_llama_export(modelname, args) - .capture_pre_autograd_graph() + .export() .pt2e_quantize(quantizers) .export_to_edge() ) @@ -650,6 +650,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 if args.num_sharding > 0: model_sharding.split_graph( builder_exported_to_edge.edge_manager.exported_program(), + # pyre-fixme[16]: `Optional` has no attribute `__getitem__`. builder_exported_to_edge.metadata["get_n_layers"], shares=args.num_sharding, ) @@ -669,6 +670,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 if args.num_sharding > 0 and args.qnn: from executorch.backends.qualcomm.utils.utils import canonicalize_program + # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. canonicalize_program(builder.edge_manager.exported_program()) builder = builder.to_executorch() @@ -686,6 +688,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 if args.num_sharding > 0 and args.qnn: from executorch.backends.qualcomm.utils.utils import canonicalize_program + # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. canonicalize_program(builder.edge_manager.exported_program()) builder = builder.to_executorch() @@ -912,7 +915,6 @@ def _get_source_transforms( # noqa if args.use_kv_cache: if args.qnn: - # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` from executorch.backends.qualcomm.utils.utils import ( convert_linear_to_conv2d, ) @@ -924,6 +926,7 @@ def _get_source_transforms( # noqa if args.optimized_rotation_path: transforms.append(fuse_layer_norms) transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) + # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. transforms.append(convert_linear_to_conv2d) elif args.mps: diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py index 7ef51ac93c..3879c00fd9 100644 --- a/examples/models/llama2/source_transformation/quantize.py +++ b/examples/models/llama2/source_transformation/quantize.py @@ -494,6 +494,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel( group_size=group_size, dtype=child.weight.dtype, packed=packed, + bitwidth=bitwidth, ), ) else: @@ -519,14 +520,17 @@ def __init__( self.group_size = group_size self.bitwidth = bitwidth self.packed = packed - if (bitwidth != 4) and packed: - raise RuntimeError("pack only works with bitsize 4") + if (bitwidth not in [2, 4]) and packed: + raise RuntimeError("pack only works with bitsize 2, 4") @torch.no_grad() def create_quantized_state_dict(self, packed=False) -> Dict: cur_state_dict = self.mod.state_dict() - if self.bitwidth == 4: + if self.bitwidth == 2: + range_min = -2 + range_max = 1 + elif self.bitwidth == 4: range_min = -8 range_max = 7 elif self.bitwidth == 8: @@ -555,17 +559,30 @@ def create_quantized_state_dict(self, packed=False) -> Dict: ) if packed: - if weight.shape[-1] % 2 != 0: - raise RuntimeError("automatic padding not implemented yet") - - weight_range_shifted = weight.add(8).view(torch.uint8) - weight_view = weight_range_shifted.view( - weight.shape[0], weight.shape[1] // 2, 2 - ) - weight_even = weight_view[:, :, 0] * 16 # left shift 4 - weight_odd = weight_view[:, :, 1] - weight_packed = weight_even + weight_odd - weight = weight_packed + if self.bitwidth == 2: + if weight.shape[-1] % 4 != 0: + raise RuntimeError("automatic padding not implemented yet") + weight_range_shifted = weight.add(2).view(torch.uint8) + weight_view = weight_range_shifted.view( + weight.shape[0], weight.shape[1] // 4, 4 + ) + weight_0 = weight_view[:, :, 0] + weight_1 = weight_view[:, :, 1] << 2 + weight_2 = weight_view[:, :, 2] << 4 + weight_3 = weight_view[:, :, 3] << 6 + weight_packed = weight_0 + weight_1 + weight_2 + weight_3 + weight = weight_packed + elif self.bitwidth == 4: + if weight.shape[-1] % 2 != 0: + raise RuntimeError("automatic padding not implemented yet") + weight_range_shifted = weight.add(8).view(torch.uint8) + weight_view = weight_range_shifted.view( + weight.shape[0], weight.shape[1] // 2, 2 + ) + weight_even = weight_view[:, :, 0] * 16 # left shift 4 + weight_odd = weight_view[:, :, 1] + weight_packed = weight_even + weight_odd + weight = weight_packed weight = weight.to(device=self.device) scales = scales.to(device=self.device) @@ -598,6 +615,7 @@ def __init__( group_size: Optional[int] = None, dtype=torch.half, packed=False, + bitwidth: int = 8, ) -> None: super().__init__() if group_size is None or group_size == 0: @@ -605,6 +623,7 @@ def __init__( self.group_size = group_size self.dtype = dtype self.packed = packed + self.bitwidth = bitwidth if not packed: self.register_buffer( "weight", @@ -613,12 +632,25 @@ def __init__( ), ) else: # packed - self.register_buffer( - "weight", - torch.empty( - (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device - ), - ) + if bitwidth == 2: + self.register_buffer( + "weight", + torch.empty( + (vocab_size, embedding_dim // 4), + dtype=torch.uint8, + device=device, + ), + ) + elif bitwidth == 4: + self.register_buffer( + "weight", + torch.empty( + (vocab_size, embedding_dim // 2), + dtype=torch.uint8, + device=device, + ), + ) + groups_per_row = (embedding_dim + group_size - 1) // group_size if groups_per_row > 1: self.register_buffer( @@ -638,7 +670,14 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor: return torch.ops.quantized_decomposed.embedding_byte.dtype( self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype ) - else: # 4bit packed + else: # packed + if self.bitwidth == 2: + return torch.ops.quantized_decomposed.embedding_2bit.dtype( + self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype + ) + + # Remaining case (always return to make pyre happy) + assert self.bitwidth == 4 return torch.ops.quantized_decomposed.embedding_4bit.dtype( self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype ) @@ -658,7 +697,7 @@ def get_quant_embedding_transform(args): model, bitwidth=bitwidth, group_size=group_size, - packed=(bitwidth == 4), + packed=(bitwidth in [2, 4]), ).quantized_model() diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index 47a5407cf1..cc475f1b19 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -53,7 +53,7 @@ class LlavaEdgeManager(LLMEdgeManager): - def capture_pre_autograd_graph(self) -> "LlavaEdgeManager": + def export(self) -> "LlavaEdgeManager": dynamic_shape = self._get_dynamic_shape() # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up) @@ -107,7 +107,7 @@ def forward(self, input_pos, embeddings): text_model_em.set_output_dir("./") .to_dtype(dtype_override) .source_transform(source_transforms) - .capture_pre_autograd_graph() + .export() .pt2e_quantize(quantizers) ) @@ -148,7 +148,7 @@ def forward(self, images): dynamic_shapes=dynamic_shapes, args=None, ) - .capture_pre_autograd_graph() + .export() .pt2e_quantize([quantizer]) ) diff --git a/examples/models/phi-3-mini/phi_3_mini.py b/examples/models/phi-3-mini/phi_3_mini.py index a7c906f0cd..b8cd5ef384 100644 --- a/examples/models/phi-3-mini/phi_3_mini.py +++ b/examples/models/phi-3-mini/phi_3_mini.py @@ -17,17 +17,22 @@ def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int super().__init__() self.model = model self.cache = ETStaticCache( + # pyre-fixme[16]: `Phi3ForCausalLM` has no attribute `config`. config=model.config, max_batch_size=max_batch_size, max_cache_len=max_seq_len, + # pyre-fixme[16]: `Phi3ForCausalLM` has no attribute `device`. device=self.model.device, + # pyre-fixme[16]: `Phi3ForCausalLM` has no attribute `dtype`. dtype=self.model.dtype, ) def forward( self, + # pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`. input_ids: torch.LongTensor = None, ) -> torch.FloatTensor: + # pyre-fixme[16]: `Phi3ForCausalLM` has no attribute `forward`. return self.model.forward( input_ids=input_ids, use_cache=True, diff --git a/examples/models/phi-3-mini/static_cache.py b/examples/models/phi-3-mini/static_cache.py index 2469ae742d..baf66ac2d1 100644 --- a/examples/models/phi-3-mini/static_cache.py +++ b/examples/models/phi-3-mini/static_cache.py @@ -34,6 +34,7 @@ def __init__( ) def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + # pyre-fixme[16]: `ETStaticCache` has no attribute `key_cache`. return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum().item() def get_usable_length( diff --git a/examples/portable/scripts/export.py b/examples/portable/scripts/export.py index ec829aa2a7..353d8a034e 100644 --- a/examples/portable/scripts/export.py +++ b/examples/portable/scripts/export.py @@ -65,9 +65,7 @@ def main() -> None: backend_config = ExecutorchBackendConfig() if args.segment_alignment is not None: backend_config.segment_alignment = int(args.segment_alignment, 16) - if ( - dynamic_shapes is not None - ): # capture_pre_autograd_graph does not work with dynamic shapes + if dynamic_shapes is not None: edge_manager = export_to_edge( model, example_inputs, diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py index fb21da4e9c..5434efea06 100644 --- a/examples/qualcomm/oss_scripts/llama2/llama.py +++ b/examples/qualcomm/oss_scripts/llama2/llama.py @@ -394,7 +394,7 @@ def compile(args): end_quantize_ts = time.time() print("single_llama.quantize(quant_dtype)", end_quantize_ts - start_quantize_ts) single_llama.lowering_modules( - args.artifact, kv_type=kv_type, soc_model=get_soc_to_chipset_map[args.model] + args.artifact, kv_type=kv_type, soc_model=get_soc_to_chipset_map()[args.model] ) end_lowering_ts = time.time() print("Complete Compile", end_lowering_ts - end_quantize_ts) diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index f10babc5bb..27c9db2ffc 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -83,7 +83,7 @@ def __init__( self.dump_intermediate_outputs = dump_intermediate_outputs self.debug_output_path = f"{self.workspace}/debug_output.bin" self.output_folder = f"{self.workspace}/outputs" - self.soc_model = get_soc_to_arch_map()[soc_model] + self.htp_arch = get_soc_to_arch_map()[soc_model] self.error_only = error_only self.shared_buffer = shared_buffer self.runner = runner @@ -108,12 +108,12 @@ def push(self, inputs=None, input_list=None, files=None): *self.pte_path, f"{self.qnn_sdk}/lib/aarch64-android/libQnnHtp.so", ( - f"{self.qnn_sdk}/lib/hexagon-v{self.soc_model}/" - f"unsigned/libQnnHtpV{self.soc_model}Skel.so" + f"{self.qnn_sdk}/lib/hexagon-v{self.htp_arch}/" + f"unsigned/libQnnHtpV{self.htp_arch}Skel.so" ), ( f"{self.qnn_sdk}/lib/aarch64-android/" - f"libQnnHtpV{self.soc_model}Stub.so" + f"libQnnHtpV{self.htp_arch}Stub.so" ), f"{self.qnn_sdk}/lib/aarch64-android/libQnnHtpPrepare.so", f"{self.qnn_sdk}/lib/aarch64-android/libQnnSystem.so", diff --git a/exir/passes/_quant_patterns_and_replacements.py b/exir/passes/_quant_patterns_and_replacements.py index 88d082b87d..529b22b1d0 100644 --- a/exir/passes/_quant_patterns_and_replacements.py +++ b/exir/passes/_quant_patterns_and_replacements.py @@ -172,6 +172,138 @@ def embedding_byte_dtype_out_meta( ) +quantized_decomposed_lib.define( + "embedding_2bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, " + "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor", +) + +quantized_decomposed_lib.define( + "embedding_2bit.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, " + "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor", +) + +quantized_decomposed_lib.define( + "embedding_2bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, " + "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)", +) + +quantized_decomposed_lib.define( + "embedding_2bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, " + "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", +) + + +@impl(quantized_decomposed_lib, "embedding_2bit", "CompositeExplicitAutograd") +def embedding_2bit( + weight: torch.Tensor, + weight_scales: torch.Tensor, + weight_zero_points: Optional[torch.Tensor], + weight_quant_min: int, + weight_quant_max: int, + indices: torch.Tensor, +) -> torch.Tensor: + embedding_weight_checks(weight, weight_scales, weight_zero_points) + group_size = (4 * weight.size(1)) // ( + weight_scales.size(1) if weight_scales.dim() == 2 else 1 + ) + weight_0 = weight & 3 + weight_1 = (weight & 12) >> 2 + weight_2 = (weight & 48) >> 4 + weight_3 = (weight & 192) >> 6 + weight_unpacked = torch.stack((weight_0, weight_1, weight_2, weight_3), dim=-1) + weight = weight_unpacked.view(weight.shape[0], -1) + weight = weight.view(torch.int8).add(-2) + + weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default( + weight, + weight_scales, + weight_zero_points, + weight_quant_min, + weight_quant_max, + weight.dtype, + group_size, + weight_scales.dtype, + ) + return torch.ops.aten.embedding.default(weight, indices) + + +@register_fake("quantized_decomposed::embedding_2bit.out") +def embedding_2bit_out_meta( + weight: torch.Tensor, + weight_scales: torch.Tensor, + weight_zero_points: Optional[torch.Tensor], + weight_quant_min: int, + weight_quant_max: int, + indices: torch.Tensor, + out: torch.Tensor, +) -> torch.Tensor: + return embedding_2bit( + weight, + weight_scales, + weight_zero_points, + weight_quant_min, + weight_quant_max, + indices, + ) + + +@impl(quantized_decomposed_lib, "embedding_2bit.dtype", "CompositeExplicitAutograd") +def embedding_2bit_dtype( + weight: torch.Tensor, + weight_scales: torch.Tensor, + weight_zero_points: Optional[torch.Tensor], + weight_quant_min: int, + weight_quant_max: int, + indices: torch.Tensor, + dtype: Optional[torch.dtype], +) -> torch.Tensor: + embedding_weight_checks(weight, weight_scales, weight_zero_points) + group_size = (4 * weight.size(1)) // ( + weight_scales.size(1) if weight_scales.dim() == 2 else 1 + ) + weight_0 = weight & 3 + weight_1 = (weight & 12) >> 2 + weight_2 = (weight & 48) >> 4 + weight_3 = (weight & 192) >> 6 + weight_unpacked = torch.stack((weight_0, weight_1, weight_2, weight_3), dim=-1) + weight = weight_unpacked.view(weight.shape[0], -1) + weight = weight.view(torch.int8).add(-2) + + weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default( + weight, + weight_scales, + weight_zero_points, + weight_quant_min, + weight_quant_max, + weight.dtype, + group_size, + dtype, + ) + return torch.ops.aten.embedding.default(weight, indices) + + +@register_fake("quantized_decomposed::embedding_2bit.dtype_out") +def embedding_2bit_dtype_out_meta( + weight: torch.Tensor, + weight_scales: torch.Tensor, + weight_zero_points: Optional[torch.Tensor], + weight_quant_min: int, + weight_quant_max: int, + indices: torch.Tensor, + dtype: Optional[torch.dtype], + out: torch.Tensor, +) -> torch.Tensor: + return embedding_2bit_dtype( + weight, + weight_scales, + weight_zero_points, + weight_quant_min, + weight_quant_max, + indices, + dtype, + ) + + quantized_decomposed_lib.define( "embedding_4bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, " "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor", diff --git a/exir/passes/debug_handle_generator_pass.py b/exir/passes/debug_handle_generator_pass.py index bf2b99da9d..0502c47dbb 100644 --- a/exir/passes/debug_handle_generator_pass.py +++ b/exir/passes/debug_handle_generator_pass.py @@ -51,7 +51,7 @@ def get_control_flow_submodules_list(graph_module): for node in current_graph_module.graph.nodes: if "debug_handle" in node.meta: max_handle = max(max_handle, node.meta["debug_handle"]) - control_flow_submodules = get_control_flow_submodules_list(ep.graph_module) + control_flow_submodules = get_control_flow_submodules_list(current_graph_module) queue.extend(control_flow_submodules) queue = [ep.graph_module] @@ -61,5 +61,5 @@ def get_control_flow_submodules_list(graph_module): if node.meta.get("debug_handle", 0) in (0, None): node.meta["debug_handle"] = max_handle + 1 max_handle += 1 - control_flow_submodules = get_control_flow_submodules_list(ep.graph_module) + control_flow_submodules = get_control_flow_submodules_list(current_graph_module) queue.extend(control_flow_submodules) diff --git a/exir/program/_program.py b/exir/program/_program.py index 144cd0d0e8..fa1d84db7c 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -972,36 +972,39 @@ def to_edge_transform_and_lower( exported programs in ATen dialect. It differs fundamentally from to_edge in that it combines the conversion of the ATen dialect to the edge dialect program, then running the transformation passes and then subsequently lowering the programs to their - corresponding backends all in a single pass. + corresponding backends all into a single API. + This is fundamentally useful for lowering to backends that have ops registered that they do not want to be decomposed and thus rely on matching with these non-decomposed ops. For these sorts of backends this is the *only* API that should be used to lower to the edge dialect. Using a combination of to_edge(...) and to_backend(...) will result in inconsistent or wrong behavior. + This API is the primary recommended way to lower to the CPU based XNNPack backend. + Args: programs: Can be a single ExportedProgram or a dictionary mapping function names - to their corresponding ExportedPrograms. If only a single ExportedProgram is - provided it will be assigned the name "forward". + to their corresponding ExportedPrograms. If only a single ExportedProgram is + provided it will be assigned the name "forward". transform_passes: The passes can either be a list of passes, or a dictionary - mapping method names to lists of passes. If it is just a list of passes, all methods - in the given EdgeProgramManager will be transformed with the provided passes. If it - is a dictionary, only method names specified in the dictionary will be transformed - with their corresponding passes. + mapping method names to lists of passes. If it is just a list of passes, all methods + in the given EdgeProgramManager will be transformed with the provided passes. If it + is a dictionary, only method names specified in the dictionary will be transformed + with their corresponding passes. partitioner: The partitioner can either be a Partitioner subclass instance, or a - dictionary mapping method names to Partitioner subclass instance. If it is a - Partitioner subclass, all programs in the given EdgeProgramManager will be lowered - using the given partitioner. If it is a dictionary, only method names specified in - the dictionary will be lowered with the given partitioner. + dictionary mapping method names to Partitioner subclass instance. If it is a + Partitioner subclass, all programs in the given EdgeProgramManager will be lowered + using the given partitioner. If it is a dictionary, only method names specified in + the dictionary will be lowered with the given partitioner. constant_methods: An optional dictionary of method name to the constant value - returned by that method in eager mode. Often used to store config information on - Edge models. + returned by that method in eager mode. Often used to store config information on + Edge models. compile_config: An optional argument used to provide greater control over the - transformation to edge dialect process. + transformation to edge dialect process. Returns: EdgeProgramManager diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index be2bb4be33..a3aa1f3997 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1018,16 +1018,13 @@ def forward( torch.ones(2, 2), ) - graph_module = ( - to_edge( - export( - f, - inputs, - ) + ep = to_edge( + export( + f, + inputs, ) - .exported_program() - .graph_module - ) + ).exported_program() + graph_module = ep.graph_module def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None: queue = [graph_module] @@ -1045,6 +1042,7 @@ def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None: DebugHandleGeneratorPass()(graph_module) check_debug_handle_metadata(graph_module) + generate_missing_debug_handles(ep) # Check debug handle still preserved after ScalarToTensorPass ScalarToTensorPass()(graph_module) diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java index 2397bcfb85..e2d46f8e8d 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java @@ -42,7 +42,7 @@ protected void onCreate(Bundle savedInstanceState) { .findFirst() .get(); - int numIter = intent.getIntExtra("num_iter", 10); + int numIter = intent.getIntExtra("num_iter", 50); // TODO: Format the string with a parsable format Stats stats = new Stats(); diff --git a/extension/benchmark/apple/Benchmark/Tests/GenericTests.mm b/extension/benchmark/apple/Benchmark/Tests/GenericTests.mm index ce68533576..f6c6927e78 100644 --- a/extension/benchmark/apple/Benchmark/Tests/GenericTests.mm +++ b/extension/benchmark/apple/Benchmark/Tests/GenericTests.mm @@ -85,7 +85,10 @@ @implementation GenericTests XCTFail("Unsupported tag %i at input %d", *input_tag, index); } } + XCTMeasureOptions *options = [[XCTMeasureOptions alloc] init]; + options.iterationCount = 20; [testCase measureWithMetrics:@[ [XCTClockMetric new], [XCTMemoryMetric new] ] + options:options block:^{ XCTAssertEqual(module->forward().error(), Error::Ok); }]; diff --git a/extension/llm/README.md b/extension/llm/README.md index ddcf4c727d..7f4baed7d3 100644 --- a/extension/llm/README.md +++ b/extension/llm/README.md @@ -10,7 +10,7 @@ Commonly used methods in this class include: - _source_transform_: execute a series of source transform passes. Some transform passes include - weight only quantization, which can be done at source (eager mode) level. - replace some torch operators to a custom operator. For example, _replace_sdpa_with_custom_op_. -- _capture_pre_autograd_graph_: get a graph that is ready for pt2 graph-based quantization. +- _torch.export_for_training_: get a graph that is ready for pt2 graph-based quantization. - _pt2e_quantize_ with passed in quantizers. - util functions in _quantizer_lib.py_ can help to get different quantizers based on the needs. - _export_to_edge_: export to edge dialect diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 37d90f8595..d2a413fc79 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -82,7 +82,7 @@ def __init__( dynamic_shapes: Optional[Any] = None, ): self.model = model - # graph module returned from capture_pre_autograd_graph + # graph module returned from export() self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None self.modelname = modelname self.max_seq_len = max_seq_len @@ -176,15 +176,16 @@ def _get_edge_config(self) -> EdgeCompileConfig: ) return edge_config - def capture_pre_autograd_graph(self) -> "LLMEdgeManager": + def export(self) -> "LLMEdgeManager": dynamic_shape = self._get_dynamic_shape() # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up) with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): - # pyre-fixme[8] if hasattr(self.args, "qnn") and self.args.qnn: # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details + # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as + # `Module`. self.pre_autograd_graph_module = torch.export.export( self.model, self.example_inputs, @@ -193,6 +194,8 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager": strict=True, ).module() else: + # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as + # `Module`. self.pre_autograd_graph_module = export_for_training( self.model, self.example_inputs, @@ -293,7 +296,7 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage composed_quantizer = ComposableQuantizer(quantizers) assert ( self.pre_autograd_graph_module is not None - ), "Please run capture_pre_autograd_graph first" + ), "Please run export() first" m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer) logging.info( f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}" @@ -341,8 +344,8 @@ def export_to_edge(self) -> "LLMEdgeManager": # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up) with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): if self.pre_autograd_graph_module is None: - # Run capture_pre_autograd_graph if it didn't run - self.capture_pre_autograd_graph() + # Run export() if it didn't run + self.export() self.edge_manager = export_to_edge( self.pre_autograd_graph_module, # pyre-fixme[6] self.example_inputs, diff --git a/extension/llm/tokenizer/utils.py b/extension/llm/tokenizer/utils.py index 97aa4bf0c0..763febdf47 100644 --- a/extension/llm/tokenizer/utils.py +++ b/extension/llm/tokenizer/utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.examples.models.llama2.tokenizer.tiktoken`. from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken from executorch.extension.llm.tokenizer.tokenizer import ( Tokenizer as SentencePieceTokenizer, diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl index 367c23f081..d9721e5055 100644 --- a/kernels/optimized/lib_defs.bzl +++ b/kernels/optimized/lib_defs.bzl @@ -144,7 +144,7 @@ def define_libs(): ( "^android-arm64.*$", [ - "fbsource//third-party/openblas:openblas", + "fbsource//arvr/third-party/eigen:eigen3_blas", ], ), ], diff --git a/kernels/quantized/cpu/embeddingxb.cpp b/kernels/quantized/cpu/embeddingxb.cpp new file mode 100644 index 0000000000..2b149646e8 --- /dev/null +++ b/kernels/quantized/cpu/embeddingxb.cpp @@ -0,0 +1,396 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace executor { +namespace native { + +using Tensor = exec_aten::Tensor; +using Scalar = exec_aten::Scalar; +using ScalarType = exec_aten::ScalarType; + +namespace { + +static inline int32_t +weight_value(const unsigned char* w_data, int32_t index, int32_t weight_nbit) { + if (weight_nbit == 2) { + int32_t subbyte = index % 4; + index >>= 2; + switch (subbyte) { + case 0: + return (int32_t)(w_data[index] & 3) - 2; + case 1: + return (int32_t)((w_data[index] & 12) >> 2) - 2; + case 2: + return (int32_t)((w_data[index] & 48) >> 4) - 2; + case 3: + return (int32_t)((w_data[index] & 192) >> 6) - 2; + } + } else if (weight_nbit == 4) { + int32_t odd = index & 1; + index >>= 1; + if (odd) { + return (int32_t)(w_data[index] & 0x0F) - 8; + } else { + return (int32_t)((w_data[index] >> 4) & 0x0F) - 8; + } + } + + ET_CHECK_MSG(false, "invalid weight_nbit"); +} + +static inline int32_t get_embedding_dim( + int32_t packed_dim, + int32_t weight_nbit) { + ET_CHECK_MSG(8 % weight_nbit == 0, "invalid embedding dim"); + int packed_values_per_byte = 8 / weight_nbit; + return packed_dim * packed_values_per_byte; +} + +/** + * Asserts that the parameters are valid. + */ +void check_embedding_xbit_args( + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + const int64_t weight_quant_min, + const int64_t weight_quant_max, + const Tensor& indices, + exec_aten::optional out_dtype, + Tensor& out, + int weight_nbit) { + ET_CHECK_MSG(8 % weight_nbit == 0, "nbit must divide 8"); + + ET_CHECK_MSG( + weight.dim() == 2, "weight must be 2D but got() %zd dims", weight.dim()); + + ET_CHECK_MSG( + weight_scales.dim() == 1 || weight_scales.dim() == 2, + "weight_scales must be 1D or 2D but got() %zd dims", + weight_scales.dim()); + + ET_CHECK_MSG( + weight_scales.size(0) == weight.size(0), + "Number of scales must be == weight.size(0)=%zd" + ", but got %zd", + weight_scales.size(0), + weight.size(0)); + + if (weight_scales.dim() == 2) { + auto num_groups = weight_scales.size(1); + ET_CHECK_MSG( + // each 8b uint8 column is packed_values_per_byte columns + get_embedding_dim(weight.size(1), weight_nbit) % num_groups == 0, + "Number of groups must divide weight.size(1)=%zd" + ", but got # of groups = %zd", + weight.size(1), + num_groups); + } + + ET_CHECK_MSG( + weight.scalar_type() == ScalarType::Byte, + "weight.scalar_type() %" PRId8 " is not supported:", + static_cast(weight.scalar_type())); + + ET_CHECK_MSG( + out.scalar_type() == ScalarType::Float || + out.scalar_type() == ScalarType::Half, + "out.scalar_type() %" PRId8 " is not supported:", + static_cast(out.scalar_type())); + + ET_CHECK_MSG( + weight_scales.scalar_type() == ScalarType::Float || + weight_scales.scalar_type() == ScalarType::Half, + "weight_scales.scalar_type() %" PRId8 " is not supported:", + static_cast(weight_scales.scalar_type())); + + if (opt_weight_zero_points.has_value()) { + ET_CHECK_MSG( + opt_weight_zero_points.value().dim() == weight_scales.dim(), + "weight_zero_points's rank match that of weight_scales. " + "weight_zero_points rank: %" PRId8 ", weight_scales rank: %" PRId8, + static_cast(opt_weight_zero_points.value().dim()), + static_cast(weight_scales.dim())); + + ET_CHECK_MSG( + opt_weight_zero_points.value().scalar_type() == out.scalar_type(), + "weight zero points scalar type %" PRId8 + " does not match out.scalar_type()", + static_cast(opt_weight_zero_points.value().scalar_type())); + + for (int32_t i = 0; i < weight_scales.dim(); ++i) { + ET_CHECK_MSG( + opt_weight_zero_points.value().size(i) == weight_scales.size(i), + "Dimension size misatch at dim %" PRId8 + "Weight_zero_point size = %zd" + ", weight_scales size = %zd.", + i, + opt_weight_zero_points.value().size(i), + weight_scales.size(i)); + } + } + + ET_CHECK_MSG( + indices.scalar_type() == ScalarType::Long, + "indices.scalar_type() %" PRId8 " is not Long only Long is supported:", + static_cast(indices.scalar_type())); + + ET_CHECK_MSG( + weight_quant_min <= weight_quant_max, + "weight quant min: %" PRId64 + " is greater than weight quant max: %" PRId64, + weight_quant_min, + weight_quant_max); + + if (out_dtype.has_value()) { + ET_CHECK_MSG( + out.scalar_type() == out_dtype.value(), + "output_dtype must match the dtype of the out tensor"); + } +} + +/** + * Retrieves the embeddings specified by indices, dequantizes them, and stores + * them in out. Weight will always be uint8 + */ +template +void embedding_xbit_per_channel( + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + const Tensor& indices, + Tensor& out, + int weight_nbit) { + auto embedding_dim = get_embedding_dim(weight.size(1), weight_nbit); + + int32_t num_groups_per_channel = 1; + if (weight_scales.dim() == 2) { + num_groups_per_channel = weight_scales.size(1); + } + int32_t group_size = embedding_dim / num_groups_per_channel; + + CTYPE_OUT* out_data = out.mutable_data_ptr(); + const int64_t* indices_ptr = indices.const_data_ptr(); + + const CTYPE_PARAMS* scales = weight_scales.const_data_ptr(); + const CTYPE_PARAMS* zero_points = nullptr; + if (opt_weight_zero_points.has_value()) { + zero_points = opt_weight_zero_points.value().const_data_ptr(); + } + + for (int i = 0; i < indices.numel(); i++) { + int64_t index = indices_ptr[i]; + // If using groupwise embedding + int32_t qparams_index = index * num_groups_per_channel; + CTYPE_PARAMS zp = 0.0; + const CTYPE_PARAMS* scale_ptr = scales + qparams_index; + const CTYPE_PARAMS* zero_points_ptr = nullptr; + if (opt_weight_zero_points.has_value()) { + zero_points_ptr = zero_points + qparams_index; + } + + const uint8_t* w_data = + weight.const_data_ptr() + weight.size(1) * index; + + for (int j = 0; j < embedding_dim; ++j) { + int32_t group_id = j / group_size; + const CTYPE_PARAMS scale = scale_ptr[group_id]; + if (opt_weight_zero_points.has_value()) { + zp = zero_points_ptr[group_id]; + } + out_data[j] = static_cast( + (static_cast(weight_value(w_data, j, weight_nbit)) - + static_cast(zp)) * + static_cast(scale)); + } + out_data += embedding_dim; + } +} + +void resize_out_tensor( + const Tensor& weight, + const Tensor& indices, + Tensor& out, + int weight_nbit) { + exec_aten::SizesType expected_output_size[kTensorDimensionLimit]; + for (size_t i = 0; i < indices.dim(); i++) { + expected_output_size[i] = indices.size(i); + } + const size_t embedding_dim = get_embedding_dim(weight.size(1), weight_nbit); + expected_output_size[out.dim() - 1] = embedding_dim; + + exec_aten::ArrayRef output_size{ + expected_output_size, static_cast(out.dim())}; + + torch::executor::Error err = resize_tensor(out, output_size); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in quantized_embedding_xbit_out"); +} + +} // namespace + +/** + * Retrieves the embeddings specified by indices, dequantizes them, and stores + * them in out. The weight is quantized per channel, with a scale and zero_point + * for each embedding. + * + * Corresponds as the out variant to torch.ops.quantized.embedding_xbit + * + * NOTE: quant_min, quant_max, and Dtype are not used in computation, but rather + * metadata that is passed around which can be useful for pattern matching. See + * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more + * info. + */ +Tensor& quantized_embedding_xbit_out( + // TODO Evaluate whether this name is appropriate for an operator that takes + // non quant input and returns fp output + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + const int64_t weight_quant_min, + const int64_t weight_quant_max, + const Tensor& indices, + Tensor& out, + int weight_nbit) { + ScalarType out_type = out.scalar_type(); + + // TODO (jakeszwe): improve these to account for the size of out in relation + // to weight and indices accounting for a possible batch dimension + check_embedding_xbit_args( + weight, + weight_scales, + opt_weight_zero_points, + weight_quant_min, + weight_quant_max, + indices, + out_type, + out, + weight_nbit); + + constexpr auto name = "quantized_decomposed::embedding_xbit.out"; + ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() { + embedding_xbit_per_channel( + weight, + weight_scales, + opt_weight_zero_points, + indices, + out, + weight_nbit); + }); + + return out; +} + +Tensor& quantized_embedding_xbit_out( + KernelRuntimeContext& context, + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + int64_t weight_quant_min, + int64_t weight_quant_max, + const Tensor& indices, + Tensor& out, + int weight_nbit) { + // TODO(larryliu): Add a context arg to the real op function and remove this + // wrapper + (void)context; + resize_out_tensor(weight, indices, out, weight_nbit); + return quantized_embedding_xbit_out( + weight, + weight_scales, + opt_weight_zero_points, + weight_quant_min, + weight_quant_max, + indices, + out, + weight_nbit); +} + +Tensor& quantized_embedding_xbit_dtype_out( + // TODO Evaluate whether this name is appropriate for an operator that takes + // non quant input and returns fp output + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + const int64_t weight_quant_min, + const int64_t weight_quant_max, + const Tensor& indices, + exec_aten::optional out_dtype, + Tensor& out, + int weight_nbit) { + // TODO (jakeszwe): improve these to account for the size of out in relation + // to weight and indices accounting for a possible batch dimension + check_embedding_xbit_args( + weight, + weight_scales, + opt_weight_zero_points, + weight_quant_min, + weight_quant_max, + indices, + out_dtype, + out, + weight_nbit); + + ScalarType params_type = weight_scales.scalar_type(); + ScalarType out_type = out.scalar_type(); + + constexpr auto name = "quantized_decomposed::embedding_xbit.dtype_out"; + ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() { + ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() { + embedding_xbit_per_channel( + weight, + weight_scales, + opt_weight_zero_points, + indices, + out, + weight_nbit); + }); + }); + + return out; +} + +Tensor& quantized_embedding_xbit_dtype_out( + KernelRuntimeContext& context, + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + int64_t weight_quant_min, + int64_t weight_quant_max, + const Tensor& indices, + exec_aten::optional out_dtype, + Tensor& out, + int weight_nbit) { + // TODO(larryliu): Add a context arg to the real op function and remove this + // wrapper + (void)context; + resize_out_tensor(weight, indices, out, weight_nbit); + return quantized_embedding_xbit_dtype_out( + weight, + weight_scales, + opt_weight_zero_points, + weight_quant_min, + weight_quant_max, + indices, + out_dtype, + out, + weight_nbit); +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/quantized/cpu/embeddingxb.h b/kernels/quantized/cpu/embeddingxb.h new file mode 100644 index 0000000000..ae1fccc6c2 --- /dev/null +++ b/kernels/quantized/cpu/embeddingxb.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +namespace torch { +namespace executor { +namespace native { + +using Tensor = exec_aten::Tensor; +using Scalar = exec_aten::Scalar; +using ScalarType = exec_aten::ScalarType; + +Tensor& quantized_embedding_xbit_out( + // TODO Evaluate whether this name is appropriate for an operator that takes + // non quant input and returns fp output + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + const int64_t weight_quant_min, + const int64_t weight_quant_max, + const Tensor& indices, + Tensor& out, + int weight_nbit); + +Tensor& quantized_embedding_xbit_out( + KernelRuntimeContext& context, + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + int64_t weight_quant_min, + int64_t weight_quant_max, + const Tensor& indices, + Tensor& out, + int weight_nbit); + +Tensor& quantized_embedding_xbit_dtype_out( + // TODO Evaluate whether this name is appropriate for an operator that takes + // non quant input and returns fp output + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + const int64_t weight_quant_min, + const int64_t weight_quant_max, + const Tensor& indices, + exec_aten::optional out_dtype, + Tensor& out, + int weight_nbit); + +Tensor& quantized_embedding_xbit_dtype_out( + KernelRuntimeContext& context, + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + int64_t weight_quant_min, + int64_t weight_quant_max, + const Tensor& indices, + exec_aten::optional out_dtype, + Tensor& out, + int weight_nbit); + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/quantized/cpu/op_embedding2b.cpp b/kernels/quantized/cpu/op_embedding2b.cpp new file mode 100644 index 0000000000..0fdd7b731f --- /dev/null +++ b/kernels/quantized/cpu/op_embedding2b.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +namespace torch { +namespace executor { +namespace native { + +using Tensor = exec_aten::Tensor; +using Scalar = exec_aten::Scalar; +using ScalarType = exec_aten::ScalarType; + +/** + * Retrieves the embeddings specified by indices, dequantizes them, and stores + * them in out. The weight is quantized per channel, with a scale and zero_point + * for each embedding. + * + * Corresponds as the out variant to torch.ops.quantized.embedding_2bit + * + * NOTE: quant_min, quant_max, and Dtype are not used in computation, but rather + * metadata that is passed around which can be useful for pattern matching. See + * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more + * info. + */ +Tensor& quantized_embedding_2bit_out( + // TODO Evaluate whether this name is appropriate for an operator that takes + // non quant input and returns fp output + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + const int64_t weight_quant_min, + const int64_t weight_quant_max, + const Tensor& indices, + Tensor& out) { + return quantized_embedding_xbit_out( + weight, + weight_scales, + opt_weight_zero_points, + weight_quant_min, + weight_quant_max, + indices, + out, + 2); +} + +Tensor& quantized_embedding_2bit_out( + KernelRuntimeContext& context, + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + int64_t weight_quant_min, + int64_t weight_quant_max, + const Tensor& indices, + Tensor& out) { + return quantized_embedding_xbit_out( + context, + weight, + weight_scales, + opt_weight_zero_points, + weight_quant_min, + weight_quant_max, + indices, + out, + 2); +} + +Tensor& quantized_embedding_2bit_dtype_out( + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + int64_t weight_quant_min, + int64_t weight_quant_max, + const Tensor& indices, + exec_aten::optional out_dtype, + Tensor& out) { + return quantized_embedding_xbit_dtype_out( + weight, + weight_scales, + opt_weight_zero_points, + weight_quant_min, + weight_quant_max, + indices, + out_dtype, + out, + 2); +} + +Tensor& quantized_embedding_2bit_dtype_out( + KernelRuntimeContext& context, + const Tensor& weight, + const Tensor& weight_scales, + const optional& opt_weight_zero_points, + int64_t weight_quant_min, + int64_t weight_quant_max, + const Tensor& indices, + exec_aten::optional out_dtype, + Tensor& out) { + return quantized_embedding_xbit_dtype_out( + context, + weight, + weight_scales, + opt_weight_zero_points, + weight_quant_min, + weight_quant_max, + indices, + out_dtype, + out, + 2); +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/quantized/cpu/op_embedding4b.cpp b/kernels/quantized/cpu/op_embedding4b.cpp index df553d2594..8a99073cd0 100644 --- a/kernels/quantized/cpu/op_embedding4b.cpp +++ b/kernels/quantized/cpu/op_embedding4b.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +#include #include #include #include @@ -19,197 +20,6 @@ using Tensor = exec_aten::Tensor; using Scalar = exec_aten::Scalar; using ScalarType = exec_aten::ScalarType; -namespace { - -/** - * Asserts that the parameters are valid. - */ -void check_embedding_4bit_args( - const Tensor& weight, - const Tensor& weight_scales, - const optional& opt_weight_zero_points, - const int64_t weight_quant_min, - const int64_t weight_quant_max, - const Tensor& indices, - exec_aten::optional out_dtype, - Tensor& out) { - ET_CHECK_MSG( - weight.dim() == 2, "weight must be 2D but got() %zd dims", weight.dim()); - - ET_CHECK_MSG( - weight_scales.dim() == 1 || weight_scales.dim() == 2, - "weight_scales must be 1D or 2D but got() %zd dims", - weight_scales.dim()); - - ET_CHECK_MSG( - weight_scales.size(0) == weight.size(0), - "Number of scales must be == weight.size(0)=%zd" - ", but got %zd", - weight_scales.size(0), - weight.size(0)); - - if (weight_scales.dim() == 2) { - auto num_groups = weight_scales.size(1); - ET_CHECK_MSG( - // each 8b uint8 column is 2 columns - (2 * weight.size(1)) % num_groups == 0, - "Number of groups must divide weight.size(1)=%zd" - ", but got # of groups = %zd", - weight.size(1), - num_groups); - } - - ET_CHECK_MSG( - weight.scalar_type() == ScalarType::Byte, - "weight.scalar_type() %" PRId8 " is not supported:", - static_cast(weight.scalar_type())); - - ET_CHECK_MSG( - out.scalar_type() == ScalarType::Float || - out.scalar_type() == ScalarType::Half, - "out.scalar_type() %" PRId8 " is not supported:", - static_cast(out.scalar_type())); - - ET_CHECK_MSG( - weight_scales.scalar_type() == ScalarType::Float || - weight_scales.scalar_type() == ScalarType::Half, - "weight_scales.scalar_type() %" PRId8 " is not supported:", - static_cast(weight_scales.scalar_type())); - - if (opt_weight_zero_points.has_value()) { - ET_CHECK_MSG( - opt_weight_zero_points.value().dim() == weight_scales.dim(), - "weight_zero_points's rank match that of weight_scales. " - "weight_zero_points rank: %" PRId8 ", weight_scales rank: %" PRId8, - static_cast(opt_weight_zero_points.value().dim()), - static_cast(weight_scales.dim())); - - ET_CHECK_MSG( - opt_weight_zero_points.value().scalar_type() == out.scalar_type(), - "weight zero points scalar type %" PRId8 - " does not match out.scalar_type()", - static_cast(opt_weight_zero_points.value().scalar_type())); - - for (int32_t i = 0; i < weight_scales.dim(); ++i) { - ET_CHECK_MSG( - opt_weight_zero_points.value().size(i) == weight_scales.size(i), - "Dimension size misatch at dim %" PRId8 - "Weight_zero_point size = %zd" - ", weight_scales size = %zd.", - i, - opt_weight_zero_points.value().size(i), - weight_scales.size(i)); - } - } - - ET_CHECK_MSG( - indices.scalar_type() == ScalarType::Long, - "indices.scalar_type() %" PRId8 " is not Long only Long is supported:", - static_cast(indices.scalar_type())); - - ET_CHECK_MSG( - weight_quant_min <= weight_quant_max, - "weight quant min: %" PRId64 - " is greater than weight quant max: %" PRId64, - weight_quant_min, - weight_quant_max); - - if (out_dtype.has_value()) { - ET_CHECK_MSG( - out.scalar_type() == out_dtype.value(), - "output_dtype must match the dtype of the out tensor"); - } -} - -static inline int32_t weight_value(const unsigned char* w_data, int32_t index) { - int32_t odd = index & 1; - index >>= 1; - if (odd) { - return (int32_t)(w_data[index] & 0x0F) - 8; - } else { - return (int32_t)((w_data[index] >> 4) & 0x0F) - 8; - } -} - -/** - * Retrieves the embeddings specified by indices, dequantizes them, and stores - * them in out. Weight will always be uint8 - */ -template -void embedding_4bit_per_channel( - const Tensor& weight, - const Tensor& weight_scales, - const optional& opt_weight_zero_points, - const Tensor& indices, - Tensor& out) { - auto embedding_dim = weight.size(1) * 2; - - int32_t num_groups_per_channel = 1; - if (weight_scales.dim() == 2) { - num_groups_per_channel = weight_scales.size(1); - } - int32_t group_size = embedding_dim / num_groups_per_channel; - - CTYPE_OUT* out_data = out.mutable_data_ptr(); - const int64_t* indices_ptr = indices.const_data_ptr(); - - const CTYPE_PARAMS* scales = weight_scales.const_data_ptr(); - const CTYPE_PARAMS* zero_points = nullptr; - if (opt_weight_zero_points.has_value()) { - zero_points = opt_weight_zero_points.value().const_data_ptr(); - } - - for (int i = 0; i < indices.numel(); i++) { - int64_t index = indices_ptr[i]; - // If using groupwise embedding - int32_t qparams_index = index * num_groups_per_channel; - CTYPE_PARAMS zp = 0.0; - const CTYPE_PARAMS* scale_ptr = scales + qparams_index; - const CTYPE_PARAMS* zero_points_ptr = nullptr; - if (opt_weight_zero_points.has_value()) { - zero_points_ptr = zero_points + qparams_index; - } - - const uint8_t* w_data = - weight.const_data_ptr() + weight.size(1) * index; - - for (int j = 0; j < embedding_dim; ++j) { - int32_t group_id = j / group_size; - const CTYPE_PARAMS scale = scale_ptr[group_id]; - if (opt_weight_zero_points.has_value()) { - zp = zero_points_ptr[group_id]; - } - out_data[j] = static_cast( - (static_cast(weight_value(w_data, j)) - - static_cast(zp)) * - static_cast(scale)); - } - out_data += embedding_dim; - } -} - -void resize_out_tensor( - const Tensor& weight, - const Tensor& indices, - Tensor& out) { - exec_aten::SizesType expected_output_size[kTensorDimensionLimit]; - for (size_t i = 0; i < indices.dim(); i++) { - expected_output_size[i] = indices.size(i); - } - const size_t embedding_dim = weight.size(1) * 2; - expected_output_size[out.dim() - 1] = embedding_dim; - - exec_aten::ArrayRef output_size{ - expected_output_size, static_cast(out.dim())}; - - torch::executor::Error err = resize_tensor(out, output_size); - ET_CHECK_MSG( - err == torch::executor::Error::Ok, - "Failed to resize out Tensor in quantized_embedding_4bit_out"); -} - -} // namespace - /** * Retrieves the embeddings specified by indices, dequantizes them, and stores * them in out. The weight is quantized per channel, with a scale and zero_point @@ -232,27 +42,15 @@ Tensor& quantized_embedding_4bit_out( const int64_t weight_quant_max, const Tensor& indices, Tensor& out) { - ScalarType out_type = out.scalar_type(); - - // TODO (jakeszwe): improve these to account for the size of out in relation - // to weight and indices accounting for a possible batch dimension - check_embedding_4bit_args( + return quantized_embedding_xbit_out( weight, weight_scales, opt_weight_zero_points, weight_quant_min, weight_quant_max, indices, - out_type, - out); - - constexpr auto name = "quantized_decomposed::embedding_4bit.out"; - ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() { - embedding_4bit_per_channel( - weight, weight_scales, opt_weight_zero_points, indices, out); - }); - - return out; + out, + 4); } Tensor& quantized_embedding_4bit_out( @@ -264,18 +62,16 @@ Tensor& quantized_embedding_4bit_out( int64_t weight_quant_max, const Tensor& indices, Tensor& out) { - // TODO(larryliu): Add a context arg to the real op function and remove this - // wrapper - (void)context; - resize_out_tensor(weight, indices, out); - return quantized_embedding_4bit_out( + return quantized_embedding_xbit_out( + context, weight, weight_scales, opt_weight_zero_points, weight_quant_min, weight_quant_max, indices, - out); + out, + 4); } Tensor& quantized_embedding_4bit_dtype_out( @@ -289,9 +85,7 @@ Tensor& quantized_embedding_4bit_dtype_out( const Tensor& indices, exec_aten::optional out_dtype, Tensor& out) { - // TODO (jakeszwe): improve these to account for the size of out in relation - // to weight and indices accounting for a possible batch dimension - check_embedding_4bit_args( + return quantized_embedding_xbit_dtype_out( weight, weight_scales, opt_weight_zero_points, @@ -299,20 +93,8 @@ Tensor& quantized_embedding_4bit_dtype_out( weight_quant_max, indices, out_dtype, - out); - - ScalarType params_type = weight_scales.scalar_type(); - ScalarType out_type = out.scalar_type(); - - constexpr auto name = "quantized_decomposed::embedding_4bit.dtype_out"; - ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() { - ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() { - embedding_4bit_per_channel( - weight, weight_scales, opt_weight_zero_points, indices, out); - }); - }); - - return out; + out, + 4); } Tensor& quantized_embedding_4bit_dtype_out( @@ -325,11 +107,8 @@ Tensor& quantized_embedding_4bit_dtype_out( const Tensor& indices, exec_aten::optional out_dtype, Tensor& out) { - // TODO(larryliu): Add a context arg to the real op function and remove this - // wrapper - (void)context; - resize_out_tensor(weight, indices, out); - return quantized_embedding_4bit_dtype_out( + return quantized_embedding_xbit_dtype_out( + context, weight, weight_scales, opt_weight_zero_points, @@ -337,7 +116,8 @@ Tensor& quantized_embedding_4bit_dtype_out( weight_quant_max, indices, out_dtype, - out); + out, + 4); } } // namespace native diff --git a/kernels/quantized/cpu/targets.bzl b/kernels/quantized/cpu/targets.bzl index d2bbbfebe0..3ba9715506 100644 --- a/kernels/quantized/cpu/targets.bzl +++ b/kernels/quantized/cpu/targets.bzl @@ -23,8 +23,19 @@ _QUANT_OPS = ( op_target( name = "op_embedding", ), + op_target( + name = "op_embedding2b", + deps = ["//executorch/kernels/quantized/cpu:embeddingxb"], + _aten_mode_deps = [ + "//executorch/kernels/quantized/cpu:embeddingxb_aten", + ], + ), op_target( name = "op_embedding4b", + deps = ["//executorch/kernels/quantized/cpu:embeddingxb"], + _aten_mode_deps = [ + "//executorch/kernels/quantized/cpu:embeddingxb_aten", + ], ), op_target( name = "op_mixed_mm", @@ -65,6 +76,26 @@ def define_common_targets(): exported_deps = quant_op_targets, ) + runtime.cxx_library( + name = "embeddingxb", + srcs = ["embeddingxb.cpp"], + exported_headers = ["embeddingxb.h"], + visibility = [ + "//executorch/kernels/quantized/...", + ], + deps = ["//executorch/runtime/kernel:kernel_includes"], + ) + + runtime.cxx_library( + name = "embeddingxb_aten", + srcs = ["embeddingxb.cpp"], + exported_headers = ["embeddingxb.h"], + visibility = [ + "//executorch/kernels/quantized/...", + ], + deps = ["//executorch/runtime/kernel:kernel_includes_aten"], + ) + runtime.cxx_library( name = "quantized_cpu_aten", srcs = [], diff --git a/kernels/quantized/quantized.yaml b/kernels/quantized/quantized.yaml index eb7586bad7..f1012e7005 100644 --- a/kernels/quantized/quantized.yaml +++ b/kernels/quantized/quantized.yaml @@ -46,6 +46,18 @@ - arg_meta: null kernel_name: torch::executor::quantized_embedding_byte_dtype_out +- func: quantized_decomposed::embedding_2bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: torch::executor::quantized_embedding_2bit_out + +- func: quantized_decomposed::embedding_2bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: torch::executor::quantized_embedding_2bit_dtype_out + - func: quantized_decomposed::embedding_4bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: diff --git a/kernels/quantized/test/CMakeLists.txt b/kernels/quantized/test/CMakeLists.txt index 8df233d931..9605facbf6 100644 --- a/kernels/quantized/test/CMakeLists.txt +++ b/kernels/quantized/test/CMakeLists.txt @@ -25,6 +25,7 @@ set(_kernels_quantized_test_sources op_add_test.cpp op_choose_qparams_test.cpp op_dequantize_test.cpp + op_embedding2b_test.cpp op_embedding4b_test.cpp op_embedding_test.cpp op_mixed_linear_test.cpp diff --git a/kernels/quantized/test/op_embedding2b_test.cpp b/kernels/quantized/test/op_embedding2b_test.cpp new file mode 100644 index 0000000000..1e4633f30c --- /dev/null +++ b/kernels/quantized/test/op_embedding2b_test.cpp @@ -0,0 +1,190 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include // Declares the operator +#include +#include +#include +#include +#include + +#include +#include + +using namespace ::testing; +using exec_aten::ArrayRef; +using exec_aten::optional; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using executorch::runtime::KernelRuntimeContext; +using torch::executor::native::quantized_embedding_2bit_out; + +using torch::executor::testing::TensorFactory; + +TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbedding) { + et_pal_init(); + TensorFactory tfb; + TensorFactory tf; + TensorFactory tfl; + + int64_t quant_min = -2; + int64_t quant_max = 1; + + Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5}); + Tensor weight_zero_points = tf.make({3}, {1, -2, 0}); + + // -2, 1, 0, 1, -> 0, 3, 2, 3 -> (reverse) 11 10 11 00 -> 236 + // 0, -1, -2, 0, -> 2, 1, 0, 2 -> (reverse) 10 00 01 10 -> 134 + // -2, -1, 0, 1, -> 0, 1, 2, 3 -> (reverse) 11 10 01 00 -> 228 + + Tensor qweight = tfb.make({3, 1}, {236, 134, 228}); + + Tensor indices = tfl.make({3}, {0, 2, 1}); + + Tensor out = tf.zeros({3, 4}); + Tensor expected = tf.make( + {3, 4}, {-1.5, 0.0, -0.5, 0.0, -3.0, -1.5, 0.0, 1.5, 2.0, 1.0, 0.0, 2.0}); + + quantized_embedding_2bit_out( + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + out); + + EXPECT_TENSOR_EQ(out, expected); + + out = tf.zeros({3, 4}); + auto context = KernelRuntimeContext(); + torch::executor::native::quantized_embedding_2bit_out( + context, + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + out); + + EXPECT_TENSOR_EQ(out, expected); + + // Groupwise quantization. groupsize = 2 + + weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0}); + weight_zero_points = tf.make({3, 2}, {1, -2, 0, 1, -2, -1}); + + // -2, 1, 0, 1, -> 0, 3, 2, 3 -> (reverse) 11 10 11 00 -> 236 + // 0, -1, -2, 0, -> 2, 1, 0, 2 -> (reverse) 10 00 01 10 -> 134 + // -2, -1, 0, 1, -> 0, 1, 2, 3 -> (reverse) 11 10 01 00 -> 228 + + qweight = tfb.make({3, 1}, {236, 134, 228}); + + indices = tfl.make({3}, {0, 2, 1}); + + out = tf.zeros({3, 4}); + expected = tf.make( + {3, 4}, {-1.5, 0.0, 2.0, 3.0, 0.0, 2.5, 3.0, 6.0, 0.0, -1.5, -6.0, -2.0}); + + quantized_embedding_2bit_out( + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath1) { + et_pal_init(); + TensorFactory tfb; + TensorFactory tf; + TensorFactory tfl; + + int64_t quant_min = -2; + int64_t quant_max = 1; + + Tensor weight_scales = tf.make({4}, {0.5, 1.0, 1.5, 3.3}); + Tensor weight_zero_points = tf.make({4}, {1, -2, 1, 0}); + Tensor qweight = tfb.make({3, 1}, {236, 134, 228}); + Tensor indices = tfl.make({3}, {0, 2, 1}); + Tensor out = tf.zeros({3, 4}); + + // qvals are incompatible shape with scales/zeros + ET_EXPECT_DEATH( + quantized_embedding_2bit_out( + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + out), + ""); +} + +TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath2) { + et_pal_init(); + TensorFactory tfb; + TensorFactory tf; + TensorFactory tfl; + + int64_t quant_min = -2; + int64_t quant_max = 1; + + Tensor weight_scales = tf.make({2}, {0.5, 1.0}); + Tensor weight_zero_points = tf.make({2}, {1, -2}); + Tensor qweight = tfb.make({3, 1}, {236, 134, 228}); + Tensor indices = tfl.make({3}, {0, 2, 1}); + Tensor out = tf.zeros({3, 4}); + + // qvals are incompatible shape with scales/zeros + ET_EXPECT_DEATH( + quantized_embedding_2bit_out( + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + out), + ""); +} + +TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath3) { + et_pal_init(); + TensorFactory tfb; + TensorFactory tf; + TensorFactory tfl; + + int64_t quant_min = -2; + int64_t quant_max = 1; + + Tensor weight_scales = tf.make({2, 3}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + Tensor weight_zero_points = tf.make({2, 3}, {0, 0, 0, 0, 0, 0}); + Tensor qweight = tfb.make({2, 1}, {236, 134}); + Tensor indices = tfl.make({2}, {0, 2}); + Tensor out = tf.zeros({2, 8}); + + // scales/zeros imply 3 groups, which does not divide embed dimension from + // qvals (8) + ET_EXPECT_DEATH( + quantized_embedding_2bit_out( + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + out), + ""); +} diff --git a/kernels/quantized/test/targets.bzl b/kernels/quantized/test/targets.bzl index a4129ee22f..a107efc397 100644 --- a/kernels/quantized/test/targets.bzl +++ b/kernels/quantized/test/targets.bzl @@ -25,6 +25,7 @@ def define_common_targets(): "//executorch/kernels/portable/cpu:op_embedding", "//executorch/runtime/core/exec_aten/testing_util:tensor_util", ]) + op_test("op_embedding2b_test", kernel_name = "quantized") op_test("op_embedding4b_test", kernel_name = "quantized") op_test("op_mixed_mm_test", kernel_name = "quantized", deps = [ "//executorch/kernels/quantized/cpu:op_mixed_mm", diff --git a/shim/xplat/executorch/backends/qualcomm/qnn_version.bzl b/shim/xplat/executorch/backends/qualcomm/qnn_version.bzl new file mode 100644 index 0000000000..75019982af --- /dev/null +++ b/shim/xplat/executorch/backends/qualcomm/qnn_version.bzl @@ -0,0 +1,2 @@ +def get_qnn_library_verision(): + return "2.26"