diff --git a/.ci/docker/ci_commit_pins/buck2.txt b/.ci/docker/ci_commit_pins/buck2.txt index 8c22caa617..1b22c8ffc0 100644 --- a/.ci/docker/ci_commit_pins/buck2.txt +++ b/.ci/docker/ci_commit_pins/buck2.txt @@ -1 +1 @@ -2024-05-15 +2024-12-16 diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt index 9c24945f7d..9182b03d38 100644 --- a/.ci/docker/ci_commit_pins/pytorch.txt +++ b/.ci/docker/ci_commit_pins/pytorch.txt @@ -1 +1 @@ -2ea4b56ec872424e486c4fe2d55da061067a2ed3 +0a94bb432ed75cc2d950d81b2921363218a7e459 diff --git a/.ci/docker/conda-env-ci.txt b/.ci/docker/conda-env-ci.txt index 9ebfe654da..8f2e65dae7 100644 --- a/.ci/docker/conda-env-ci.txt +++ b/.ci/docker/conda-env-ci.txt @@ -1,2 +1,4 @@ cmake=3.22.1 ninja=1.10.2 +libuv +pkg-config diff --git a/.ci/scripts/setup-arm-baremetal-tools.sh b/.ci/scripts/setup-arm-baremetal-tools.sh new file mode 100755 index 0000000000..454b9f336e --- /dev/null +++ b/.ci/scripts/setup-arm-baremetal-tools.sh @@ -0,0 +1,11 @@ +#!/bin/bash +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# NB: This function could be used to install Arm dependencies +# Setup arm example environment (including TOSA tools) +git config --global user.email "github_executorch@arm.com" +git config --global user.name "Github Executorch" +bash examples/arm/setup.sh --i-agree-to-the-contained-eula diff --git a/.ci/scripts/setup-macos.sh b/.ci/scripts/setup-macos.sh index b1a8ff14b5..395f0c1767 100755 --- a/.ci/scripts/setup-macos.sh +++ b/.ci/scripts/setup-macos.sh @@ -131,5 +131,9 @@ if [[ -z "${GITHUB_RUNNER:-}" ]]; then fi print_cmake_info -install_executorch +install_pytorch_and_domains +# We build PyTorch from source here instead of using nightly. 
This allows CI to test against +# the pinned commit from PyTorch +install_executorch "use-pt-pinned-commit" build_executorch_runner "${BUILD_TOOL}" +do_not_use_nightly_on_ci diff --git a/.ci/scripts/utils.sh b/.ci/scripts/utils.sh index be927a3a3a..ebed4a3150 100644 --- a/.ci/scripts/utils.sh +++ b/.ci/scripts/utils.sh @@ -40,6 +40,42 @@ install_pip_dependencies() { popd || return } +install_domains() { + echo "Install torchvision and torchaudio" + pip install --no-use-pep517 --user "git+https://github.com/pytorch/audio.git@${TORCHAUDIO_VERSION}" + pip install --no-use-pep517 --user "git+https://github.com/pytorch/vision.git@${TORCHVISION_VERSION}" +} + +install_pytorch_and_domains() { + pushd .ci/docker || return + TORCH_VERSION=$(cat ci_commit_pins/pytorch.txt) + popd || return + + git clone https://github.com/pytorch/pytorch.git + + # Fetch the target commit + pushd pytorch || return + git checkout "${TORCH_VERSION}" + git submodule update --init --recursive + + export USE_DISTRIBUTED=1 + # Then build and install PyTorch + python setup.py bdist_wheel + pip install "$(echo dist/*.whl)" + + # Grab the pinned audio and vision commits from PyTorch + TORCHAUDIO_VERSION=$(cat .github/ci_commit_pins/audio.txt) + export TORCHAUDIO_VERSION + TORCHVISION_VERSION=$(cat .github/ci_commit_pins/vision.txt) + export TORCHVISION_VERSION + + install_domains + + popd || return + # Print sccache stats for debugging + sccache --show-stats || true +} + install_flatc_from_source() { # NB: This function could be used to install flatbuffer from source pushd third-party/flatbuffers || return @@ -59,17 +95,6 @@ install_flatc_from_source() { popd || return } -install_arm() { - # NB: This function could be used to install Arm dependencies - # Setup arm example environment (including TOSA tools) - git config --global user.email "github_executorch@arm.com" - git config --global user.name "Github Executorch" - bash examples/arm/setup.sh --i-agree-to-the-contained-eula - - # Test tosa_reference flow - source examples/arm/ethos-u-scratch/setup_path.sh -} - build_executorch_runner_buck2() { # Build executorch runtime with retry as this step is flaky on macos CI retry buck2 build //examples/portable/executor_runner:executor_runner diff --git a/.github/workflows/apple-perf.yml b/.github/workflows/apple-perf.yml index f4fbc0128b..4fb3e9711d 100644 --- a/.github/workflows/apple-perf.yml +++ b/.github/workflows/apple-perf.yml @@ -410,7 +410,7 @@ jobs: runs-on: linux.2xlarge steps: - name: Download the apps from GitHub - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: # The name here needs to match the name of the upload-artifact parameter name: ios-apps diff --git a/.github/workflows/apple.yml b/.github/workflows/apple.yml index f284d466bf..2e85eeec76 100644 --- a/.github/workflows/apple.yml +++ b/.github/workflows/apple.yml @@ -53,7 +53,7 @@ jobs: ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} timeout: 90 secrets-env: BUILD_CERTIFICATE_BASE64 EXECUTORCH_DEMO_BUILD_PROVISION_PROFILE_BASE64 KEYCHAIN_PASSWORD - upload-artifact: ios-apps + upload-artifact: ios-demo-app script: | set -eux @@ -83,10 +83,10 @@ jobs: runs-on: linux.2xlarge steps: - name: Download the artifacts from GitHub - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: # The name here needs to match the name of the upload-artifact parameter - name: ios-apps + name: ios-demo-app path: ${{ runner.temp }}/artifacts/ - name: Verify the artifacts @@ -216,7 
+216,7 @@ jobs: role-to-assume: arn:aws:iam::308535385114:role/gha_executorch_upload-frameworks-ios aws-region: us-east-1 - name: Download the artifact - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: # NB: The name here needs to match the upload-artifact name from build-frameworks-ios job name: executorch-frameworks-ios @@ -291,7 +291,7 @@ jobs: python-version: '3.11' submodules: 'true' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - upload-artifact: ios-apps + upload-artifact: ios-benchmark-app secrets-env: BUILD_CERTIFICATE_BASE64 EXECUTORCH_BENCHMARK_BUILD_PROVISION_PROFILE_BASE64 KEYCHAIN_PASSWORD timeout: 90 script: | diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 5941ab52e7..d1b64e7598 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -354,13 +354,11 @@ jobs: EXECUTORCH_BUILD_ARM_BAREMETAL=ON \ .ci/scripts/setup-linux.sh "${BUILD_TOOL}" - source .ci/scripts/utils.sh # Install Arm dependencies - install_arm - - # Run pytest with coverage - pytest -c /dev/null -v -n auto --cov=./ --cov-report=xml backends/arm/test + .ci/scripts/setup-arm-baremetal-tools.sh + # Run pytest without simulator + backends/arm/test/test_arm_baremetal.sh test_pytest test-llama-runner-qnn-linux: name: test-llama-runner-qnn-linux diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 7972269e92..90bd0eb6ef 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -146,14 +146,15 @@ jobs: source .ci/scripts/utils.sh install_executorch - install_arm + .ci/scripts/setup-arm-baremetal-tools.sh # Increase number of files user can monitor to bypass buck failures. # Hopefully this is high enough for this setup. sudo sysctl fs.inotify.max_user_watches=1048576 # 1024 * 1024 # Test ethos-u delegate examples with run.sh - PYTHON_EXECUTABLE=python bash examples/arm/run.sh examples/arm/ethos-u-scratch/ + backends/arm/test/test_arm_baremetal.sh test_run_ethosu_fvp + test-arm-reference-delegation: name: test-arm-reference-delegation @@ -172,10 +173,10 @@ jobs: source .ci/scripts/utils.sh install_executorch - install_arm + .ci/scripts/setup-arm-baremetal-tools.sh - # Run arm unit tests - pytest -c /dev/null -v -n auto --cov=./ --cov-report=xml backends/arm/test + # Run arm unit tests using the simulator + backends/arm/test/test_arm_baremetal.sh test_pytest_ethosu_fvp test-coreml-delegate: name: test-coreml-delegate diff --git a/.gitmodules b/.gitmodules index 58f2133ed6..afae765e2b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -66,7 +66,7 @@ url = https://github.com/pybind/pybind11.git [submodule "backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3"] path = backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3 - url = https://github.com/foss-xtensa/nnlib-FusionG3/ + url = https://github.com/foss-xtensa/nnlib-FusionG3.git [submodule "third-party/ao"] path = third-party/ao url = https://github.com/pytorch/ao.git diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6af6939471..bd943c587b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -197,7 +197,7 @@ If it's not clear how to add a test for your PR, take a look at the blame for the code you're modifying and find an author who has more context. Ask them for their help in the PR comments. -TODO: Explain how to run tests locally without needing to push and wait for CI. +The `test/run_oss_cpp_tests.sh` script will build and run C++ tests locally. 
### Continuous Integration See https://hud.pytorch.org/hud/pytorch/executorch/main for the current state of diff --git a/backends/apple/mps/mps_preprocess.py b/backends/apple/mps/mps_preprocess.py index 8362774fa9..749f32a04e 100644 --- a/backends/apple/mps/mps_preprocess.py +++ b/backends/apple/mps/mps_preprocess.py @@ -32,6 +32,9 @@ CompileSpec, PreprocessResult, ) + +from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass +from executorch.exir.program._program import _transform from torch.export.exported_program import ExportedProgram FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" @@ -83,6 +86,9 @@ def preprocess( # FlatBuffer graph, process the `output` nodes and add their id to # the `output_ids` array in the schema. + # TODO: Remove this once we have a better support for the dim-order ops. + edge_program = _transform(edge_program, DimOrderOpsRevertPass()) + mps_graph = MPSGraph( version="0", mps_nodes=[], diff --git a/backends/apple/mps/operators/constant_ops.py b/backends/apple/mps/operators/constant_ops.py index dacb09215c..f8dcfc66a0 100644 --- a/backends/apple/mps/operators/constant_ops.py +++ b/backends/apple/mps/operators/constant_ops.py @@ -79,6 +79,25 @@ def define_node( ) +@register_node_visitor +class ToDimOrderEmptyVisitor(NodeVisitor): + target = ["dim_order_ops._empty_dim_order.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + # We should never get here, because DimOrderOpsRevertPass replaces this with an aten.empty.memory_format op + # But if we do, we can't handle it ATM, so raise an exception + raise NotImplementedError( + "dim_order_ops._empty_dim_order.default is not supported yet" + ) + + @register_node_visitor class FullLikeVisitor(NodeVisitor): target = "aten.full_like.default" diff --git a/backends/apple/mps/operators/op_clone.py b/backends/apple/mps/operators/op_clone.py index 2310ae02da..0d0b7f5363 100644 --- a/backends/apple/mps/operators/op_clone.py +++ b/backends/apple/mps/operators/op_clone.py @@ -33,3 +33,22 @@ def define_node( ) input_id = self.define_tensor(get_input_node(node, 0), mps_graph) self.tensor_to_id[node] = input_id + + +@register_node_visitor +class ToDimOrderCopyVisitor(NodeVisitor): + target = ["dim_order_ops._to_dim_order_copy.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + # We should never get here, because DimOrderOpsRevertPass replaces this with an aten._to_copy op + # But if we do, we can't handle it ATM, so raise an exception + raise NotImplementedError( + "dim_order_ops._to_dim_order_copy.default is not supported yet" + ) diff --git a/backends/apple/mps/test/test_mps.py b/backends/apple/mps/test/test_mps.py index fe64a30f3c..a981d7ab8e 100644 --- a/backends/apple/mps/test/test_mps.py +++ b/backends/apple/mps/test/test_mps.py @@ -1829,6 +1829,21 @@ def forward(self, x): Clone(), model_inputs, func_name=inspect.stack()[0].function[5:] ) + def test_mps_backend_to_copy(self): + class Copy(torch.nn.Module): + def forward(self, x): + return ( + torch.ops.aten._to_copy.default( + x + 2, memory_format=torch.contiguous_format + ) + + x + ) + + model_inputs = (torch.randn(1, 3, 3),) + self.lower_and_test_with_partitioner( + Copy(), model_inputs, func_name=inspect.stack()[0].function[5:] + ) + def test_mps_backend_floor(self): class Floor(torch.nn.Module): def 
forward(self, x): diff --git a/backends/apple/mps/test/test_mps_utils.py b/backends/apple/mps/test/test_mps_utils.py index 43ae9aa0f0..a39583c66c 100644 --- a/backends/apple/mps/test/test_mps_utils.py +++ b/backends/apple/mps/test/test_mps_utils.py @@ -26,10 +26,7 @@ # Config for Capturing the weights, will be moved in the future -# TODO(T182928844): Delegate dim order op to backend. -_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig( - _check_ir_validity=False, _skip_dim_order=True -) +_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(_check_ir_validity=False) class ansi_colors: @@ -219,7 +216,6 @@ def lower_module_and_test_output( dynamic_shapes=dynamic_shapes, edge_compile_config=EdgeCompileConfig( _check_ir_validity=False, - _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. ), ) @@ -250,7 +246,6 @@ def lower_module_and_test_output( export(delegated_program, sample_inputs, strict=True), compile_config=exir.EdgeCompileConfig( _check_ir_validity=False, - _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. ), ).to_executorch( config=ExecutorchBackendConfig(extract_delegate_segments=False) diff --git a/backends/arm/README.md b/backends/arm/README.md index a7458db07c..2457ec074b 100644 --- a/backends/arm/README.md +++ b/backends/arm/README.md @@ -39,6 +39,28 @@ Other: - `third-party/` - Dependencies on other code - in particular the TOSA serialization_lib for compiling to TOSA and the ethos-u-core-driver for the bare-metal backend supporting Ethos-U - `test/` - Unit test and test support functions +## Testing + +After setup you can run unit tests with the test_arm_baremetal.sh script. + +To run the pytest suite, run + +``` +backends/arm/test/test_arm_baremetal.sh test_pytest +``` + +To run the unit test suite with Corstone-3x0 FVP simulator support, use + +``` +backends/arm/test/test_arm_baremetal.sh test_pytest_ethosu_fvp +``` + +You can test running some models with the run.sh flow + +``` +backends/arm/test/test_arm_baremetal.sh test_run_ethosu_fvp +``` + ## Unit tests This is the structure of the test directory @@ -51,6 +73,8 @@ test # Root test folder ├── tester # Arm Tester class ├── tosautil # Utility functions for TOSA artifacts ├ common.py # Common functions and definitions used by many tests +├ setup_testing.sh # Script to prepare testing for using the Corstone 3x0 FVP +├ test_arm_baremetal.sh # Help script to trigger testing ``` Some example commands to run these tests follow. @@ -59,6 +83,12 @@ Some example commands to run these tests follow.
Run a single test: python -m unittest backends.arm.test.ops.test_add.TestSimpleAdd -k test_add2_tosa_BI ``` +or with pytest + +``` +pytest -c /dev/null -v -n auto backends/arm/test/ops/test_add.py -k test_add2_tosa_BI +``` + Or all tests in "TestSimpleAdd": ``` @@ -71,6 +101,50 @@ Or discover and run many tests: python -m unittest discover -s backends/arm/test/ops/ ``` +or with pytest + +``` +pytest -c /dev/null -v -n auto backends/arm/test/ops/ +``` + + +You can run tests using Corstone-3x0 simulators to see how it would work on something more target-like. +First you need to build and prepare some required target libs + +``` +examples/arm/run.sh --model_name=add --build_only +backends/arm/test/setup_testing.sh +``` + +Then you can run the tests with + +``` +pytest -c /dev/null -v -n auto backends/arm/test --arm_quantize_io --arm_run_corstoneFVP +``` + +### Code coverage + +To get code coverage: + +``` +coverage run --source=<SRC> --rcfile=backends/arm/test/.coveragerc -m pytest \ +--config-file=/dev/null backends/arm/test/ +``` + +All files in `SRC` and its child directories will be analysed for code coverage, +unless explicitly excluded in the .coveragerc file. If using venv, this might be +under `env/lib/python/site-packages/executorch/`. To get the +absolute path, run: + +``` +python -c "import executorch; print(executorch.__path__)" +``` + +This contains a list of paths where the source directory is located. Pick the +one that is located in `env/lib`. If that does not work, try the others. Add +`backends/arm` to the path in `--source` to only get code coverage for the Arm +backend. + ### A note on unit tests There are currently 3 ways we unit test our code. diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index 034939945f..0846d97372 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -36,7 +36,6 @@ def call(self, graph_module: GraphModule) -> PassResult: itertools.chain.from_iterable(matmul_partitions.values()) ) matmul_targets = { - exir_ops.edge.aten.mm.default, exir_ops.edge.aten.bmm.default, } for partition in matmul_partitions: diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 6d747d8129..9c2a074372 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -1,5 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved.
# # This source code is licensed under the BSD-style license found in the @@ -37,6 +37,9 @@ QuantizeFullArgument, RetraceFoldedDtypesPass, ) +from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( + FuseQuantizedActivationPass, +) from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import ( KeepDimsFalseToSqueezePass, @@ -45,6 +48,7 @@ from executorch.backends.arm._passes.meandim_to_averagepool_pass import ( ConvertMeanDimToAveragePool, ) +from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass from executorch.backends.arm._passes.scalars_to_attribute_pass import ( ScalarsToAttributePass, @@ -72,6 +76,7 @@ def transform_to_backend_pipeline( self, exported_program: ExportedProgram, compile_spec: list[CompileSpec] ): """Apply passes before transforming program to backend""" + self.add_pass(FuseQuantizedActivationPass()) self.add_pass(DecomposeLinearPass()) self.add_pass(RemoveGetItemPass()) self.add_pass(DecomposeLayerNormPass()) @@ -79,6 +84,7 @@ def transform_to_backend_pipeline( self.add_pass(ConvertMeanDimToAveragePool()) self.add_pass(DecomposeMeanDimPass()) self.add_pass(ConvertSplitToSlicePass()) + self.add_pass(ConvertMmToBmmPass()) # TODO MLETORCH-558 self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeFullArgument()) @@ -99,7 +105,6 @@ def transform_to_backend_pipeline( exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.log.default, exir_ops.edge.aten.max_pool2d.default, - exir_ops.edge.aten.mm.default, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.permute_copy.default, exir_ops.edge.aten.reciprocal.default, diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index 71dbdb5b85..aee68f74eb 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -8,7 +8,6 @@ from typing import cast -from executorch.backends.arm.tosa_mapping import extract_tensor_meta from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -25,14 +24,14 @@ def call_operator(self, op, args, kwargs, meta): if op != self.expand_copy: return super().call_operator(op, args, kwargs, meta) - _, shape, _ = extract_tensor_meta(meta.data) + input_shape = args[0].data.shape multiples = cast(list[int], args[1]) expanded_rank = len(multiples) - # Expanded shape is 'shape' front-padded with ones. - padding = expanded_rank - len(shape) + # Expanded shape is 'input_shape' front-padded with ones. 
+ padding = expanded_rank - len(input_shape) extended_shape = [ - shape[i] if i >= 0 else 1 for i in range(-padding, len(shape)) + input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape)) ] # To convert expand arg to repeat arg, non-repeated dims should have diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index 283760e423..73747d8313 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -83,7 +83,7 @@ def call_operator(self, op, args, kwargs, meta): sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta) full = super().call_operator( full_op, - ([1 for _ in shape], 1 / max(0, N - correction)), + ([], 1 / max(0, N - correction)), {"dtype": dtype}, meta, ) diff --git a/backends/arm/_passes/fuse_quantized_activation_pass.py b/backends/arm/_passes/fuse_quantized_activation_pass.py new file mode 100644 index 0000000000..86836842bb --- /dev/null +++ b/backends/arm/_passes/fuse_quantized_activation_pass.py @@ -0,0 +1,60 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.tosa_quant_utils import q_op +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import Node + + +class FuseQuantizedActivationPass(ExportPass): + def _is_fuseable_quantized_activation(self, node: Node): + """Fuse activations that have a 0 lower bound and quantized with a qmin zero-point""" + is_fuseable = node.target == exir_ops.edge.aten.relu.default + if node.target == exir_ops.edge.aten.hardtanh.default: + min_val = node.args[1] + is_fuseable = min_val == 0 + + is_quantized = len(node.users) == 1 and next(iter(node.users)).target == q_op + if is_quantized: + quant_node = next(iter(node.users)) + zp = quant_node.args[2] + qmin = quant_node.args[3] + + return is_fuseable and is_quantized and zp == qmin + + def _is_fuseable_input(self, node: Node): + return ( + node.target + in ( + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.linear.default, + ) + and len(node.users) == 1 + ) + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + + if not self._is_fuseable_quantized_activation(node): + continue + + input_node = node.args[0] + if not self._is_fuseable_input(input_node): + continue + + node.replace_all_uses_with(input_node) + graph_module.graph.erase_node(node) + modified = True + + if modified: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index e230fde1a0..941d20c95a 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -90,7 +90,7 @@ def call(self, graph_module: GraphModule) -> PassResult: continue # Calculate max rank of all inputs to node - max_rank = 1 + max_rank = 0 for arg in node.args: if isinstance(arg, Node): shape = get_first_fake_tensor(arg).shape diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py new file mode 100644 index 0000000000..c0c0d67c0b --- /dev/null +++ b/backends/arm/_passes/mm_to_bmm_pass.py @@ 
-0,0 +1,98 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, + insert_q_dq_pair, +) +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import Node + + +class ConvertMmToBmmPass(ExportPass): + """ + This pass converts a MM node to a BMM one and turns input and output tensors + from rank 2 to rank 3. The TOSA specification requires rank 3. The graph is + modified to do the following: + 1) Unsqueeze input tensors to rank 3. + 2) Convert MM node to BMM. + 3) Squeeze output tensor to rank 2. + """ + + def call(self, graph_module: torch.fx.GraphModule): + modified_graph = False + graph = graph_module.graph + node_list = graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.mm.default + ) + for node in node_list: + # Unsqueeze input tensors to rank 3 + for input_node in node.args: + if not isinstance(input_node, Node): + continue + + shape = get_first_fake_tensor(input_node).shape + rank = len(shape) + if rank != 2: + raise RuntimeError(f"Input tensor has rank {rank}, must be 2") + + with graph.inserting_before(node): + unsqueeze_before = create_node( + graph, exir_ops.edge.aten.unsqueeze_copy.default + ) + unsqueeze_before.args = ( + input_node, # Input is node's original input + 0, + ) + node.replace_input_with(input_node, unsqueeze_before) + + # If quantized, we must insert unsqueeze --> q --> dq --> node + if input_node.target == dq_op: + q_params = input_node.args[1:] + insert_q_dq_pair(graph, unsqueeze_before, q_params) + + # Replace mm node with bmm + with graph.inserting_before(node): + bmm_node = create_node( + graph, + exir_ops.edge.aten.bmm.default, + ) + bmm_node.args = node.args + node.replace_all_uses_with(bmm_node) + graph.erase_node(node) + + # Squeeze output tensor back to rank 2 + with graph.inserting_after(bmm_node): + squeeze_after = create_node( + graph, + exir_ops.edge.aten.squeeze_copy.dims, + ) + squeeze_after.args = ( + bmm_node, + [0], + ) + original_users = [ + user for user in bmm_node.users if user != squeeze_after + ] + for user in original_users: + user.replace_input_with(bmm_node, squeeze_after) + + # If quantized, insert bmm --> q --> dq --> squeeze + if all(original_user.target == q_op for original_user in original_users): + q_params = original_users[0].args[1:] + insert_q_dq_pair(graph, bmm_node, q_params) + + modified_graph = True + + if modified_graph: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified_graph) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 38b73fe0f6..ee5f2807a9 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 Arm Limited and/or its affiliates. +# Copyright 2023-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree.
@@ -22,7 +22,6 @@ op_max, op_max_pool2d, op_min, - op_mm, op_mul, op_permute, op_quant, diff --git a/backends/arm/operators/op_mm.py b/backends/arm/operators/op_mm.py deleted file mode 100644 index 9266b5a03c..0000000000 --- a/backends/arm/operators/op_mm.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -from typing import List - -import serializer.tosa_serializer as ts -import torch - -# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, - get_output_qparams, -) -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale -from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.backends.arm.tosa_utils import build_reshape, expand_dims -from serializer.tosa_serializer import TosaOp - - -@register_node_visitor -class MMVisitor_080_BI(NodeVisitor): - target = "aten.mm.default" - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+BI"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - # For atem.mm, the two inputs are of rank 2 - # For TOSA it needs to be rank 3 - # So they need to be reshaped from (H, W) to (1, H, W) - reshape_dtype = output.dtype - input0_reshaped = expand_dims(tosa_graph, inputs[0], reshape_dtype, 0) - input1_reshaped = expand_dims(tosa_graph, inputs[1], reshape_dtype, 0) - - # The output also needs to be rank 3 - output_new_shape = (1, output.shape[0], output.shape[1]) - - input_qparams = get_input_qparams(node) # pyre-ignore[16] - assert len(input_qparams) == 2 - input0_qparams = input_qparams[0] - input1_qparams = input_qparams[1] - - mat_mul_result = tosa_graph.addIntermediate(output_new_shape, ts.DType.INT32) - - attr = ts.TosaSerializerAttribute() - attr.MatMulAttribute(A_zp=input0_qparams.zp, B_zp=input1_qparams.zp) - - tosa_graph.addOperator( - TosaOp.Op().MATMUL, - [input0_reshaped.name, input1_reshaped.name], - [mat_mul_result.name], - attr, - ) - - reshape_intermediate = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) - reshape_output_name = reshape_intermediate.name - - # Reshape the final output back to rank 2 - build_reshape( - tosa_graph, mat_mul_result.name, output.shape, reshape_output_name - ) - - # As INT8 accumulates into INT32, we need to rescale it back to INT8 - output_qparams = get_output_qparams(node) # pyre-ignore[16] - assert len(output_qparams) == 1 - - final_output_scale = ( - input0_qparams.scale * input1_qparams.scale - ) / output_qparams[0].scale - - # As the input will be INT32, the input_zp must be set to 0 - build_rescale( - tosa_fb=tosa_graph, - scale=final_output_scale, - input_node=reshape_intermediate, - output_name=output.name, - output_type=output.dtype, - output_shape=reshape_intermediate.shape, - input_zp=0, - output_zp=output_qparams[0].zp, - is_double_round=False, - ) - - -@register_node_visitor -class MMVisitor_080_MI(MMVisitor_080_BI): - # inheriting 
'target' from BI class - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+MI"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - if inputs[0].dtype == ts.DType.INT8: - return super().define_node(node, tosa_graph, inputs, output) - reshape_dtype = output.dtype - # For atem.mm, the two inputs are of rank 2 - # For TOSA it needs to be rank 3 - # So they need to be reshaped from (H, W) to (1, H, W) - input0_reshaped = expand_dims(tosa_graph, inputs[0], reshape_dtype, 0) - input1_reshaped = expand_dims(tosa_graph, inputs[1], reshape_dtype, 0) - - # The output also needs to be rank 3 - output_new_shape = (1, output.shape[0], output.shape[1]) - - # Set zps to 0 - input0_zp, input1_zp = 0, 0 - attr = ts.TosaSerializerAttribute() - attr.MatMulAttribute(A_zp=input0_zp, B_zp=input1_zp) - mat_mul_result = tosa_graph.addIntermediate(output_new_shape, output.dtype) - reshape_output_name = output.name - - tosa_graph.addOperator( - TosaOp.Op().MATMUL, - [input0_reshaped.name, input1_reshaped.name], - [mat_mul_result.name], - attr, - ) - # Reshape the final output back to rank 2 - build_reshape( - tosa_graph, mat_mul_result.name, output.shape, reshape_output_name - ) diff --git a/backends/arm/quantizer/TARGETS b/backends/arm/quantizer/TARGETS index a2445f26c0..bbd7322daf 100644 --- a/backends/arm/quantizer/TARGETS +++ b/backends/arm/quantizer/TARGETS @@ -5,12 +5,22 @@ python_library( srcs = ["arm_quantizer.py"], deps = [ ":arm_quantizer_utils", + ":quantization_annotator", "//caffe2:torch", - "//executorch/backends/arm/quantizer/quantization_annotation:quantization_annotation", "//executorch/exir:lib", ], ) +python_library( + name = "quantization_annotator", + srcs = ["quantization_annotator.py"], + deps = [ + ":arm_quantizer_utils", + ":quantization_config", + "//caffe2:torch", + ], +) + python_library( name = "quantization_config", srcs = ["quantization_config.py"], diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 8815d40b0b..fe104db972 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -13,24 +13,16 @@ from __future__ import annotations -import copy import functools -from typing import Any, Callable, Dict, List, Optional, Set +from typing import Any, Callable, Dict, List, Optional import torch -import torch.nn.functional as F from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.arm_quantizer_utils import ( - mark_nodes_as_annotated, - propagate_annotation, -) -from executorch.backends.arm.quantizer.quantization_annotation import ( - OP_TO_ANNOTATOR, - OperatorConfig, - OperatorPatternType, -) +from executorch.backends.arm.quantizer.arm_quantizer_utils import mark_node_as_annotated +from executorch.backends.arm.quantizer.quantization_annotator import annotate_graph + from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from torch.ao.quantization.fake_quantize import ( FakeQuantize, @@ -58,44 +50,6 @@ ] -def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]: - supported_operators: Dict[str, List[OperatorPatternType]] = { - # Both conv and linear should be able to handle relu + hardtanh fusion since - # those are clamp ops - "conv2d": [ - 
[torch.nn.Conv2d, torch.nn.ReLU], - [torch.nn.Conv2d, F.relu], - [F.conv2d, torch.nn.ReLU], - [F.conv2d, F.relu], - ], - "linear": [[torch.nn.Linear], [F.linear]], - "add": [[torch.add]], - "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]], - "adaptive_avg_pool2d": [ - [torch.nn.AdaptiveAvgPool2d], - [F.adaptive_avg_pool2d], - ], - "mul": [[torch.mul]], - "sub": [[torch.sub]], - "min_max": [[torch.min], [torch.max]], - } - return copy.deepcopy(supported_operators) - - -def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]: - supported_config_and_operators: List[OperatorConfig] = [] - for quantization_config in [ - get_symmetric_quantization_config(), - get_symmetric_quantization_config(is_per_channel=True), - ]: - ops = _supported_symmetric_quantized_operators() - for pattern_list in ops.values(): - supported_config_and_operators.append( - OperatorConfig(quantization_config, pattern_list) - ) - return copy.deepcopy(supported_config_and_operators) - - @functools.lru_cache def get_symmetric_quantization_config( is_per_channel: bool = False, @@ -180,10 +134,6 @@ def get_symmetric_quantization_config( return quantization_config -def _get_supported_config_and_operators() -> List[OperatorConfig]: - return _get_supported_symmetric_config_and_operators() - - NodeFilterType = Callable[[Node], bool] """Type for a Node Filter used by annotators. A Node filter is a function that takes a Node and returns whether the node should be annotated or not. @@ -255,26 +205,6 @@ def not_module_type_or_name_filter(n: Node) -> bool: class ArmQuantizer(Quantizer): - supported_config_and_operators = _get_supported_config_and_operators() - - # A list of supported static quantization annotators, in order of application. - # For example, fusions come before singular ops. - # The name must match the name used when registering the annotator. - STATIC_ANNOTATION_ORDER = [ - "linear", - "conv", - "adaptive_avg_pool2d", - "max_pool2d", - "add", - "sub", - "mul", - "min_max", - "mm", - "one_to_one", - "generic", - "upsample_nearest2d", - ] - def __init__(self) -> None: super().__init__() self.global_config: Optional[QuantizationConfig] = None @@ -331,7 +261,6 @@ def annotate(self, model: GraphModule) -> GraphModule: The annotated model. """ model = self._annotate_for_static_quantization_config(model) - propagate_annotation(model) return model def _annotate_all_static_patterns( @@ -353,8 +282,7 @@ def _annotate_all_static_patterns( if quantization_config is None: return model - for op in self.STATIC_ANNOTATION_ORDER: - OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + annotate_graph(model, quantization_config, filter_fn) return model def _annotate_for_static_quantization_config( @@ -363,6 +291,9 @@ def _annotate_for_static_quantization_config( """Matches the correct QuantizationConfig with the correct module using a filter when running _annotate_all_static_patterns. 
""" + if self.io_config: + self._annotate_io(model, self.io_config) + module_name_list = list(self.module_name_config.keys()) for module_name, config in self.module_name_config.items(): self._annotate_all_static_patterns( @@ -381,9 +312,6 @@ def _annotate_for_static_quantization_config( _get_not_module_type_or_name_filter(tp_list, module_name_list), ) - if self.io_config: - self._annotate_io(model, self.io_config) - return model def _annotate_io( @@ -399,44 +327,13 @@ def _annotate_io( node, quantization_config.get_output_act_qspec(), ) - mark_nodes_as_annotated([node]) + mark_node_as_annotated(node) if node.op == "output": parent = node.all_input_nodes[0] _annotate_input_qspec_map( node, parent, quantization_config.get_input_act_qspec() ) - mark_nodes_as_annotated([node]) + mark_node_as_annotated(node) def validate(self, model: GraphModule) -> None: pass - - @classmethod - def get_supported_operators(cls) -> List[OperatorConfig]: - return cls.supported_config_and_operators - - @classmethod - def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: - op_configs: Set[QuantizationConfig] = set({}) - for spec, _ in cls.supported_config_and_operators: - op_configs.add(spec) - return list(op_configs) - - @classmethod - def get_supported_operator_for_quantization_config( - cls, quantization_config: Optional[QuantizationConfig] - ) -> List[OperatorPatternType]: - if quantization_config is None: - all_ops = [] - for _, ops in cls.supported_config_and_operators: - all_ops.extend(ops) - return all_ops - - for config, ops in cls.supported_config_and_operators: - # note: this assumes each entry in cls.supported_spec_and_operators - # corresponds to one spec, e.g. we don't have - # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] - # where the first and second entry have the same spec but did not - # merge the op list - if config == quantization_config: - return ops - return [] diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index 4d52b7ddf1..7b460ccae7 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -11,17 +11,12 @@ # Utility functions for ArmQuantizer # -import operator -from typing import Callable, cast, List +from typing import cast import torch -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from torch._subclasses import FakeTensor -from torch.ao.quantization.quantizer import ( - QuantizationAnnotation, - SharedQuantizationSpec, -) +from torch.ao.quantization.quantizer import QuantizationAnnotation from torch.fx import GraphModule, Node @@ -35,72 +30,30 @@ def is_annotated(node: Node) -> bool: ) -def are_annotated(nodes: List[Node]) -> bool: - """Given a list of nodes (that represents an operator pattern), - return True if any of the nodes - is annotated, otherwise return False. - """ - for node in nodes: - if is_annotated(node): - return True - return False +def is_output_annotated(node: Node) -> bool: + """Given a node, return whether the output of the node is annotated.""" + if "quantization_annotation" in node.meta: + annotation = cast(QuantizationAnnotation, node.meta["quantization_annotation"]) + return annotation._annotated and annotation.output_qspec is not None + else: + return False -def mark_nodes_as_annotated(nodes: List[Node]) -> None: - """Marks all nodes in list 'nodes' as annotated. If needed, an empty - QuantizationAnnotation is added to the quantization_annotation node meta entry. 
- """ - for node in nodes: - if node is not None: - if "quantization_annotation" not in node.meta: - node.meta["quantization_annotation"] = QuantizationAnnotation() - node.meta["quantization_annotation"]._annotated = True - - -def get_shared_qspec( - node: Node, gm: GraphModule, quantization_config: QuantizationConfig -): - """Returns a Quantization constallation with a SharedQuantizationSpec for the inputs - and output to the parameter 'node'. - Parameters: - node: a node with two inputs that should share Quantization parameters. - gm: The GraphModule containing the node. Used to inspect global graph features. - quantization_config : a QuantizationConfig with the input QuantizationSpec to share - Returns: - input_qspec_map: a dict[node, QuantizationSpec] that maps the inputs to 'node' to - the correct QuantizationSpec. - shared_with_input0_spec: The SharedQuantizationSpec to be used as output QuantizationSpec. - - Both outputs are None if one of the inputs is a node that can't be quantized. +def mark_node_as_annotated(node: Node) -> None: + """Marks node as annotated. If needed, an empty QuantizationAnnotation is added + to the quantization_annotation node meta entry. """ - input_act0 = cast(Node, node.args[0]) - input_act1 = node.args[1] - - input_act_qspec = quantization_config.get_input_act_qspec() - shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node)) - - input_qspec_map = {} - if isinstance(input_act0, Node): - if not is_input_ok_for_quantization(input_act0, gm): - return None, None - input_qspec_map[input_act0] = input_act_qspec - - if isinstance(input_act1, Node): - if not is_input_ok_for_quantization(input_act1, gm): - return None, None - if input_act0 is not input_act1: - input_qspec_map[input_act1] = shared_with_input0_qspec - return input_qspec_map, shared_with_input0_qspec + if "quantization_annotation" not in node.meta: + node.meta["quantization_annotation"] = QuantizationAnnotation() + node.meta["quantization_annotation"]._annotated = True -def is_input_ok_for_quantization(input_act: Node, gm: GraphModule): - """Check if an input can be quantized. The input can not be quantized if: +def is_ok_for_quantization(node: Node, gm: GraphModule): + """Check if an node can be quantized. The node can not be quantized if: - The node does not output a float tensor or, - The node outputs a large scalar. """ - return not ( - is_input_non_float_tensor(input_act) or is_input_large_scalar(input_act, gm) - ) + return not (is_non_float_tensor(node) or is_large_scalar(node, gm)) def get_node_target(module: torch.nn.Module | GraphModule, target_str: str): @@ -110,7 +63,7 @@ def get_node_target(module: torch.nn.Module | GraphModule, target_str: str): return getattr(module, targets[-1]) -def is_input_large_scalar(node: Node, gm: GraphModule): +def is_large_scalar(node: Node, gm: GraphModule): """Check if input is a large scalar value. 
So that we can skip quantization for the node since histc op (in HistogramObserver) only works for values up to certain upper bound """ @@ -122,72 +75,10 @@ def is_input_large_scalar(node: Node, gm: GraphModule): return False -def is_input_non_float_tensor(node: Node) -> bool: +def is_non_float_tensor(node: Node) -> bool: """Check if the input is not a float tensor, so that we can skip quantization for the node since observers only works with float Tensors """ if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): return True return node.meta["val"].dtype != torch.float32 - - -def is_share_obs_or_fq_op(op: Callable) -> bool: - """Returns whether the the operation 'op' can be quantized using a shared observer or - fake quantizer. This means that the operation can inherit it's quantization spec - from parent nodes. - """ - return op in [ - torch.ops.aten.hardtanh.default, - torch.ops.aten.hardtanh_.default, - torch.ops.aten.relu.default, - torch.ops.aten.mean.default, - torch.ops.aten.mean.dim, - torch.ops.aten.permute.default, - torch.ops.aten.permute_copy.default, - # TODO: remove? - torch.ops.aten.adaptive_avg_pool2d.default, - torch.ops.aten.avg_pool2d.default, - torch.ops.aten.max_pool2d.default, - torch.ops.aten.full.default, - torch.ops.aten.flatten.using_ints, - torch.ops.aten.dropout.default, - operator.getitem, - ] - - -def propagate_annotation(model: GraphModule) -> None: - """For unannotated ops that can share observer or have fake quantizers, - annotate with a SharedQuantizationSpec, where the shared spec is the - output spec of the parent node. - This propagates output qspecs downward in the graph until - an op that is already annotated or can't share qspec is encountered. - """ - for n in model.graph.nodes: - n = cast(Node, n) - if is_annotated(n): - continue - if n.op != "call_function" or not is_share_obs_or_fq_op( - cast(Callable, n.target) - ): - continue - - prev_node = n.args[0] - if not isinstance(prev_node, Node): - continue - - quantization_annotation = cast( - QuantizationAnnotation | None, - prev_node.meta.get("quantization_annotation", None), - ) - if not quantization_annotation or not quantization_annotation.output_qspec: - continue - - # propagate the previous output_qspec to the current node - shared_qspec = SharedQuantizationSpec(prev_node) - n.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map={ - prev_node: shared_qspec, - }, - output_qspec=shared_qspec, - _annotated=True, - ) diff --git a/backends/arm/quantizer/quantization_annotation/TARGETS b/backends/arm/quantizer/quantization_annotation/TARGETS deleted file mode 100644 index 4ce8b5cad2..0000000000 --- a/backends/arm/quantizer/quantization_annotation/TARGETS +++ /dev/null @@ -1,12 +0,0 @@ -load("@fbcode_macros//build_defs:python_library.bzl", "python_library") - -python_library( - name = "quantization_annotation", - srcs = glob(["*.py"]), - typing = True, - deps = [ - "//caffe2:torch", - "//executorch/backends/arm/quantizer:arm_quantizer_utils", - "//executorch/backends/arm/quantizer:quantization_config", - ], -) diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py deleted file mode 100644 index d9d27cee2a..0000000000 --- a/backends/arm/quantizer/quantization_annotation/__init__.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - - -from typing import Callable, Dict, List, NamedTuple, Optional - -import torch -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.fx import Node - -OperatorPatternType = List[Callable] -OperatorPatternType.__module__ = "executorch.backends.arm.quantizer.arm_quantizer_utils" - - -class OperatorConfig(NamedTuple): - # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]] - # Basically we are mapping a quantization config to some list of patterns. - # a pattern is defined as a list of nn module, function or builtin function names - # e.g. [nn.Conv2d, torch.relu, torch.add] - # We have not resolved whether fusion can be considered internal details of the - # quantizer hence it does not need communication to user. - # Note this pattern is not really informative since it does not really - # tell us the graph structure resulting from the list of ops. - config: QuantizationConfig - operators: List[OperatorPatternType] - - -AnnotatorType = Callable[ - [ - torch.fx.GraphModule, - QuantizationConfig, - Optional[Callable[[Node], bool]], - ], - Optional[List[List[Node]]], -] -OP_TO_ANNOTATOR: Dict[str, AnnotatorType] = {} - - -def register_annotator(op: str): - def decorator(annotator: AnnotatorType): - OP_TO_ANNOTATOR[op] = annotator - - return decorator - - -from . import ( # noqa - adaptive_ang_pool2d_annotator, - add_annotator, - conv_annotator, - generic_annotator, - linear_annotator, - max_pool2d_annotator, - min_max_annotator, - mm_annotator, - mul_annotator, - one_to_one_annotator, - sub_annotator, - upsample_nearest2d_annotator, -) diff --git a/backends/arm/quantizer/quantization_annotation/adaptive_ang_pool2d_annotator.py b/backends/arm/quantizer/quantization_annotation/adaptive_ang_pool2d_annotator.py deleted file mode 100644 index 723a48f664..0000000000 --- a/backends/arm/quantizer/quantization_annotation/adaptive_ang_pool2d_annotator.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -import itertools -from typing import Callable, List, Optional - -import torch -import torch.nn.functional as F -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import ( - QuantizationAnnotation, - SharedQuantizationSpec, -) -from torch.fx import Node -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -@register_annotator("adaptive_avg_pool2d") -def _annotate_adaptive_avg_pool2d( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - """Always annotate adaptive_avg_pool2d op""" - module_partitions = get_source_partitions( - gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn - ) - partitions = list(itertools.chain.from_iterable(module_partitions.values())) - annotated_partitions = [] - for partition in partitions: - pool_node = partition.output_nodes[0] - if ( - pool_node.op != "call_function" - or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default - ): - raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator") - - if arm_quantizer_utils.is_annotated(pool_node): - continue - - annotated_partitions.append(partition.nodes) - input_act = pool_node.args[0] - assert isinstance(input_act, Node) - - # only annotate input output sharing operator - # when the output of the input node is annotated - if ( - "quantization_annotation" not in input_act.meta - or not input_act.meta["quantization_annotation"]._annotated - or input_act.meta["quantization_annotation"].output_qspec is None - ): - input_act_qspec = quantization_config.get_input_act_qspec() - else: - input_act_qspec = SharedQuantizationSpec(input_act) - - # output sharing with input - output_act_qspec = SharedQuantizationSpec((input_act, pool_node)) - pool_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map={ - input_act: input_act_qspec, - }, - output_qspec=output_act_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/add_annotator.py b/backends/arm/quantizer/quantization_annotation/add_annotator.py deleted file mode 100644 index 600c5d31f6..0000000000 --- a/backends/arm/quantizer/quantization_annotation/add_annotator.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.fx import Node - - -@register_annotator("add") -def _annotate_add( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - annotated_partitions = [] - for node in gm.graph.nodes: - if node.target not in ( - torch.ops.aten.add.Tensor, - torch.ops.aten.add_.Tensor, - ): - continue - annotated_partitions.append(node) - add_node = node - if arm_quantizer_utils.is_annotated(add_node): - continue - - input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec( - add_node, gm, quantization_config - ) - if input_qspec_map is not None: - add_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=output_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/conv_annotator.py b/backends/arm/quantizer/quantization_annotation/conv_annotator.py deleted file mode 100644 index 4ff7dd9e80..0000000000 --- a/backends/arm/quantizer/quantization_annotation/conv_annotator.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree.f - -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import QuantizationAnnotation - -from torch.fx import Node - - -@register_annotator("conv") -def _annotate_conv( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - annotated_partitions = [] - for n in gm.graph.nodes: - if n.op != "call_function" or n.target not in [ - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, - ]: - continue - conv_node = n - - input_qspec_map = {} - input_act = conv_node.args[0] - assert isinstance(input_act, Node) - input_qspec_map[input_act] = quantization_config.get_input_act_qspec() - - weight = conv_node.args[1] - assert isinstance(weight, Node) - input_qspec_map[weight] = quantization_config.get_weight_qspec() - - # adding weight node to the partition as well - partition_nodes = [conv_node, conv_node.args[1]] - - bias = conv_node.args[2] if len(conv_node.args) > 2 else None - if isinstance(bias, Node): - input_qspec_map[bias] = quantization_config.get_bias_qspec() - partition_nodes.append(bias) - - if arm_quantizer_utils.are_annotated(partition_nodes): - continue - - if filter_fn and any(not filter_fn(n) for n in partition_nodes): - continue - - conv_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.get_output_act_qspec(), - _annotated=True, - ) - 
arm_quantizer_utils.mark_nodes_as_annotated(partition_nodes) - annotated_partitions.append(partition_nodes) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/generic_annotator.py b/backends/arm/quantizer/quantization_annotation/generic_annotator.py deleted file mode 100644 index 5db29c95e8..0000000000 --- a/backends/arm/quantizer/quantization_annotation/generic_annotator.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -from typing import Callable, List, Optional - -import torch -import torch.fx -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import SharedQuantizationSpec -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) -from torch.fx import Node - - -_SUPPORTED_OPS = [ - # DATA LAYOUT OPS - torch.ops.aten.squeeze.default, - torch.ops.aten.squeeze_copy.default, - torch.ops.aten.squeeze_copy.dim, - torch.ops.aten.squeeze.dim, - torch.ops.aten.squeeze.dims, - torch.ops.aten.unsqueeze.default, - torch.ops.aten.unsqueeze_copy.default, - torch.ops.aten.reshape.default, - torch.ops.aten.repeat.default, - torch.ops.aten.expand_copy.default, - torch.ops.aten.expand.default, - # Disabling these as there seems to be an issue with support for complex - # datatypes in torch: - # torch.ops.aten.view_as_complex.default, - # torch.ops.aten.view_as_complex_copy.default, - # torch.ops.aten.view_as_real.default, - # torch.ops.aten.view_as_real_copy.default, - torch.ops.aten.view.default, - torch.ops.aten.view_as.default, - torch.ops.aten.view_copy.default, - torch.ops.aten.select.int, - torch.ops.aten.select_copy.int, - torch.ops.aten.slice.Tensor, - torch.ops.aten.slice_copy.Tensor, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes.default, - torch.ops.aten.transpose.Dimname, - torch.ops.aten.transpose.int, - torch.ops.aten.transpose_copy.int, - torch.ops.aten.tile.default, - torch.ops.aten.flip.default, - torch.ops.aten.cat.default, - torch.ops.aten.concatenate.default, - torch.ops.aten.stack.default, - torch.ops.aten.chunk.default, - torch.ops.aten.contiguous.default, -] - - -@register_annotator("generic") -def _annotate_generic( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - """Propagate qspecs to generic ops like unsqueeze, reshape etc.""" - annotated_partitions = [] - - for node in gm.graph.nodes: - if node.op != "call_function" or node.target not in _SUPPORTED_OPS: - continue - if filter_fn and not filter_fn(node): - continue - if arm_quantizer_utils.is_annotated(node): - continue - - input_acts = node.args[0] - - # Check to see if there are multiple inputs. - # this allows for stack/cat ops to be annotated - # in a similar way. - has_multi_inputs = isinstance(input_acts, list) - - input_act0 = input_acts[0] if has_multi_inputs else input_acts - - # Using a non-shared quantization spec here as a SharedQuantizationSpec - # can lead to a recursion. 
- _annotate_input_qspec_map( - node, input_act0, quantization_config.get_input_act_qspec() - ) - shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node)) - - if has_multi_inputs: - # For the rest of the inputs, share qspec with first. - for input_act in input_acts[1:]: - if input_act is not input_act0: - node.meta["quantization_annotation"].input_qspec_map[ - input_act - ] = shared_with_input0_qspec - - _annotate_output_qspec(node, shared_with_input0_qspec) - arm_quantizer_utils.mark_nodes_as_annotated([node]) - annotated_partitions.append([node]) - - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/linear_annotator.py b/backends/arm/quantizer/quantization_annotation/linear_annotator.py deleted file mode 100644 index 7c3f91ec70..0000000000 --- a/backends/arm/quantizer/quantization_annotation/linear_annotator.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) -from torch.fx import Node - - -@register_annotator("linear") -def _annotate_linear( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - annotated_partitions = [] - input_act_qspec = quantization_config.get_input_act_qspec() - output_act_qspec = quantization_config.get_output_act_qspec() - weight_qspec = quantization_config.get_weight_qspec() - bias_qspec = quantization_config.get_bias_qspec() - - for node in gm.graph.nodes: - if node.op != "call_function" or node.target != torch.ops.aten.linear.default: - continue - if filter_fn and not filter_fn(node): - continue - act_node = node.args[0] - weight_node = node.args[1] - bias_node = None - if len(node.args) > 2: - bias_node = node.args[2] - - if arm_quantizer_utils.is_annotated(node) is False: # type: ignore[list-item] - _annotate_input_qspec_map( - node, - act_node, - input_act_qspec, - ) - _annotate_input_qspec_map( - node, - weight_node, - weight_qspec, - ) - nodes_to_mark_annotated = [node, weight_node] - if bias_node: - _annotate_input_qspec_map( - node, - bias_node, - bias_qspec, - ) - nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, output_act_qspec) - arm_quantizer_utils.mark_nodes_as_annotated(nodes_to_mark_annotated) - annotated_partitions.append(nodes_to_mark_annotated) - - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/max_pool2d_annotator.py b/backends/arm/quantizer/quantization_annotation/max_pool2d_annotator.py deleted file mode 100644 index 0ef2ee39fe..0000000000 --- a/backends/arm/quantizer/quantization_annotation/max_pool2d_annotator.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -import itertools -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import ( - QuantizationAnnotation, - SharedQuantizationSpec, -) -from torch.fx import Node -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -@register_annotator("max_pool2d") -def _annotate_max_pool2d( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - module_partitions = get_source_partitions( - gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d], filter_fn - ) - maxpool_partitions = list(itertools.chain.from_iterable(module_partitions.values())) - annotated_partitions = [] - for maxpool_partition in maxpool_partitions: - annotated_partitions.append(maxpool_partition.nodes) - output_node = maxpool_partition.output_nodes[0] - maxpool_node = None - for n in maxpool_partition.nodes: - if n.target == torch.ops.aten.max_pool2d.default: - maxpool_node = n - assert ( - maxpool_node is not None - ), "ArmQuantizer only works with torch.ops.aten.max_pool2d.default, " - "please make sure you are exporting the model correctly" - if arm_quantizer_utils.are_annotated([output_node, maxpool_node]): # type: ignore[list-item] - continue - - input_act = maxpool_node.args[0] # type: ignore[union-attr] - assert isinstance(input_act, Node) - - # only annotate maxpool when the output of the input node is annotated - if ( - "quantization_annotation" not in input_act.meta - or not input_act.meta["quantization_annotation"]._annotated - or input_act.meta["quantization_annotation"].output_qspec is None - ): - continue - # input and output of maxpool will share quantization parameter with input of maxpool - act_qspec = SharedQuantizationSpec(input_act) - # act_qspec = get_act_qspec(quantization_config) - maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation( # type: ignore[union-attr] - input_qspec_map={ - input_act: act_qspec, - }, - _annotated=True, - ) - output_node.meta["quantization_annotation"] = QuantizationAnnotation( - output_qspec=act_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/min_max_annotator.py b/backends/arm/quantizer/quantization_annotation/min_max_annotator.py deleted file mode 100644 index 43c4d20c13..0000000000 --- a/backends/arm/quantizer/quantization_annotation/min_max_annotator.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.fx import GraphModule, Node - - -@register_annotator("min_max") -def _annotate_min_max( - gm: GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - annotated_partitions = [] - for node in gm.graph.nodes: - if node.target not in ( - torch.ops.aten.minimum.default, - torch.ops.aten.maximum.default, - ): - continue - annotated_partitions.append(node) - min_max_node = node - if arm_quantizer_utils.is_annotated(min_max_node): - continue - - input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec( - min_max_node, gm, quantization_config - ) - if input_qspec_map is not None: - min_max_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=output_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/mm_annotator.py b/backends/arm/quantizer/quantization_annotation/mm_annotator.py deleted file mode 100644 index 60d9adb1c3..0000000000 --- a/backends/arm/quantizer/quantization_annotation/mm_annotator.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -import itertools -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.fx import Node -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -@register_annotator("mm") -def _annotate_mm( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - mm_partitions = get_source_partitions( - gm.graph, [torch.mm, torch.bmm, torch.matmul], filter_fn - ) - mm_partitions = list(itertools.chain.from_iterable(mm_partitions.values())) - annotated_partitions = [] - for mm_partition in mm_partitions: - annotated_partitions.append(mm_partition.nodes) - mm_node = mm_partition.output_nodes[0] - - if arm_quantizer_utils.is_annotated(mm_node): - continue - - input_act_qspec = quantization_config.get_input_act_qspec() - output_act_qspec = quantization_config.get_output_act_qspec() - - input_qspec_map = {} - input_act0 = mm_node.args[0] - if isinstance(input_act0, Node): - if not arm_quantizer_utils.is_input_ok_for_quantization(input_act0, gm): - continue - input_qspec_map[input_act0] = input_act_qspec - - input_act1 = mm_node.args[1] - if isinstance(input_act1, Node): - if not arm_quantizer_utils.is_input_ok_for_quantization(input_act1, gm): - continue - input_qspec_map[input_act1] = input_act_qspec - - mm_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=output_act_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/mul_annotator.py b/backends/arm/quantizer/quantization_annotation/mul_annotator.py deleted file mode 100644 index 3a206c3aba..0000000000 --- a/backends/arm/quantizer/quantization_annotation/mul_annotator.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -import torch.fx -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.fx import Node - - -@register_annotator("mul") -def _annotate_mul( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - - annotated_partitions = [] - for node in gm.graph.nodes: - if node.target not in (torch.ops.aten.mul.Tensor, torch.ops.aten.mul_.Tensor): - continue - mul_node = node - annotated_partitions.append([mul_node]) - if arm_quantizer_utils.is_annotated(mul_node): - continue - - input_act_qspec = quantization_config.get_input_act_qspec() - output_act_qspec = quantization_config.get_output_act_qspec() - - input_qspec_map = {} - input_act0 = mul_node.args[0] - if isinstance(input_act0, Node): - if arm_quantizer_utils.is_input_large_scalar(input_act0, gm): - continue - if arm_quantizer_utils.is_input_non_float_tensor(input_act0): - continue - input_qspec_map[input_act0] = input_act_qspec - - input_act1 = mul_node.args[1] - if isinstance(input_act1, Node): - if arm_quantizer_utils.is_input_large_scalar(input_act1, gm): - continue - if arm_quantizer_utils.is_input_non_float_tensor(input_act1): - continue - input_qspec_map[input_act1] = input_act_qspec - - mul_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=output_act_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py b/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py deleted file mode 100644 index ec008826f0..0000000000 --- a/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -import torch.fx -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) -from torch.fx import Node - - -@register_annotator("one_to_one") -def _annotate_one_to_one( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - """ - This annotator adds the input and output qspec from the quantization config to - ops in 'one_to_one_ops' that have the following properties: - - Have a single input and single output. - - Can handle different qspecs on the input and output. - - Typical ops are ops implemented with a lookup table. 
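As a concrete case of the "different qspecs on the input and output" property named above, sigmoid compresses any input range into (0, 1), so its output observer needs its own scale rather than the input's. A small illustration in plain PyTorch (not part of the diff):

```python
# Why LUT-style one-to-one ops keep independent input/output qspecs: the output
# range is fixed by the function, not by the input's quantization parameters.
import torch

x = torch.linspace(-8.0, 8.0, steps=1001)   # input spans [-8, 8]
y = torch.sigmoid(x)                        # output spans roughly (0, 1)
print(x.min().item(), x.max().item())       # -8.0 8.0
print(y.min().item(), y.max().item())       # ~0.0003 ~0.9997
# An 8-bit scheme that reused the input's scale (~16/255) for the output would
# map the whole (0, 1) output range onto only ~16 of the 256 available codes.
```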
- """ - annotated_partitions = [] - one_to_one_ops = ( - torch.ops.aten.exp.default, - torch.ops.aten.log.default, - torch.ops.aten.reciprocal.default, - torch.ops.aten.rsqrt.default, - torch.ops.aten.sigmoid.default, - torch.ops.aten.tanh.default, - torch.ops.aten.sum.dim_IntList, - ) - for node in gm.graph.nodes: - if node.op != "call_function" or node.target not in one_to_one_ops: - continue - if filter_fn and not filter_fn(node): - continue - input_node = node.args[0] - - if not arm_quantizer_utils.is_annotated(node): - _annotate_input_qspec_map( - node, - input_node, - quantization_config.get_input_act_qspec(), - ) - _annotate_output_qspec(node, quantization_config.get_output_act_qspec()) - - arm_quantizer_utils.mark_nodes_as_annotated([node]) - annotated_partitions.append([node]) - - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/sub_annotator.py b/backends/arm/quantizer/quantization_annotation/sub_annotator.py deleted file mode 100644 index 437f3e22e7..0000000000 --- a/backends/arm/quantizer/quantization_annotation/sub_annotator.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.fx import GraphModule, Node - - -@register_annotator("sub") -def _annotate_sub( - gm: GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - annotated_partitions = [] - for node in gm.graph.nodes: - if node.target not in (torch.ops.aten.sub.Tensor, torch.ops.aten.sub_.Tensor): - continue - annotated_partitions.append(node) - sub_node = node - if arm_quantizer_utils.is_annotated(sub_node): - continue - - input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec( - sub_node, gm, quantization_config - ) - if input_qspec_map is not None: - sub_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=output_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/upsample_nearest2d_annotator.py b/backends/arm/quantizer/quantization_annotation/upsample_nearest2d_annotator.py deleted file mode 100644 index eb4daf3a5c..0000000000 --- a/backends/arm/quantizer/quantization_annotation/upsample_nearest2d_annotator.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -import itertools -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import ( - QuantizationAnnotation, - SharedQuantizationSpec, -) -from torch.fx import Node -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -def _filter_upsample_nearest2d(filter_fn: Optional[Callable[[Node], bool]] = None): - def filter(node: Node): - is_upsample = node.target == torch.ops.aten.upsample_nearest2d.vec - if filter_fn is None: - return is_upsample - else: - return is_upsample and filter_fn(node) - - return filter - - -@register_annotator("upsample_nearest2d") -def _annotate_upsample_nearest2d( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - module_partitions = get_source_partitions( - gm.graph, - [ - torch.nn.UpsamplingNearest2d, - torch.nn.Upsample, - torch.nn.functional.interpolate, - ], - _filter_upsample_nearest2d(filter_fn), - ) - upsample_partitions = list( - itertools.chain.from_iterable(module_partitions.values()) - ) - annotated_partitions = [] - - for upsample_partition in upsample_partitions: - annotated_partitions.append(upsample_partition.nodes) - - assert len(upsample_partition.nodes) == 1 - upsample_node = upsample_partition.nodes[0] - - input_act = upsample_node.args[0] - assert isinstance(input_act, Node) - - input_act_qspec = quantization_config.get_input_act_qspec() - output_act_qspec = SharedQuantizationSpec((input_act, upsample_node)) - - upsample_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map={ - input_act: input_act_qspec, - }, - output_qspec=output_act_qspec, - _annotated=True, - ) - - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py new file mode 100644 index 0000000000..9ddeb61c30 --- /dev/null +++ b/backends/arm/quantizer/quantization_annotator.py @@ -0,0 +1,362 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import operator +from dataclasses import dataclass +from typing import Callable, List, Optional + +import torch +import torch.fx +from executorch.backends.arm.quantizer import arm_quantizer_utils +from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec +from torch.ao.quantization.quantizer.utils import ( + _annotate_input_qspec_map, + _annotate_output_qspec, +) +from torch.fx import Node + + +@dataclass(frozen=True) +class _QuantProperty: + """Specify how the input/output at 'index' must be quantized.""" + + index: int + qspec: type[QuantizationSpecBase] | List[type[QuantizationSpecBase]] + optional: bool = False + mark_annotated: bool = False + + +class _OpQuantProperties: + def __init__(self): + self.quant_inputs: List[_QuantProperty] = [] + self.quant_output: Optional[_QuantProperty] = None + + +def _as_list(x): + if isinstance(x, list): + return x + else: + return [ + x, + ] + + +def _is_ok_for_quantization( + node: Node, quant_property: _QuantProperty, gm: torch.fx.GraphModule +) -> bool: + if quant_property.optional and ( + quant_property.index >= len(node.args) + or node.args[quant_property.index] is None + ): + return True + + for n_arg in _as_list(node.args[quant_property.index]): + assert isinstance(n_arg, Node) + if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm): + return False + + return True + + +def _annotate_input(node: Node, quant_property: _QuantProperty): + assert not arm_quantizer_utils.is_annotated(node) + if quant_property.optional and ( + quant_property.index >= len(node.args) + or node.args[quant_property.index] is None + ): + return + + for n_arg, qspec in zip( + _as_list(node.args[quant_property.index]), + _as_list(quant_property.qspec), + strict=True, + ): + assert isinstance(n_arg, Node) + _annotate_input_qspec_map(node, n_arg, qspec) + if quant_property.mark_annotated: + arm_quantizer_utils.mark_node_as_annotated(n_arg) + + +def _annotate_output(node: Node, quant_property: _QuantProperty): + assert not arm_quantizer_utils.is_annotated(node) + assert not quant_property.mark_annotated + assert not quant_property.optional + assert quant_property.index == 0, "Only one output annotation supported currently" + + _annotate_output_qspec(node, quant_property.qspec) + + +def _match_pattern( + node: Node, pattern: List[List], filter_fn: Optional[Callable[[Node], bool]] = None +) -> bool: + """ + Check if there's a chain of node.ancestors? -> node -> node.descendant? that matches the + chain provided in 'pattern'. If 'filter_fn' is provided, check that all the nodes in the + chain pass the filtering. + + Each 'pattern' element is composed of a list of disjunctive nodes types. 
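In other words, the pattern argument is a two-element list where each element enumerates the aten targets accepted at that position in the chain. A hedged usage sketch, mirroring the conv/linear + relu/hardtanh fusion that get_quant_properties drives later in this file (the wrapper function below is illustrative only):

```python
# Usage illustration for _match_pattern; not additional code in the diff.
import torch

conv_or_linear = [
    torch.ops.aten.conv1d.default,
    torch.ops.aten.conv2d.default,
    torch.ops.aten.linear.default,
]
relu_or_hardtanh = [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default]

def is_fused_conv_activation(node):
    # True when 'node' is either end of a conv/linear -> relu/hardtanh chain
    # whose intermediate result has no other users.
    return _match_pattern(node, [conv_or_linear, relu_or_hardtanh])
```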
+ """ + assert len(pattern) == 2, "Only two-nodes patterns supported currently" + + if node.target in pattern[0]: + assert len(node.users) != 0 + parent = node + child = next(iter(node.users)) + elif node.target in pattern[1]: + assert len(node.args) != 0 + parent = node.args[0] + child = node + else: + return False + + if len(parent.users) != 1: + return False + + if parent.target not in pattern[0] or child.target not in pattern[1]: + return False + + if filter_fn is not None: + return filter_fn(parent) and filter_fn(child) + + return True + + +_one_to_one = [ + torch.ops.aten.exp.default, + torch.ops.aten.log.default, + torch.ops.aten.reciprocal.default, + torch.ops.aten.rsqrt.default, + torch.ops.aten.sigmoid.default, + torch.ops.aten.tanh.default, + torch.ops.aten.sum.dim_IntList, +] + +_one_to_one_shared_input_qspec = [ + torch.ops.aten.squeeze.default, + torch.ops.aten.squeeze_copy.default, + torch.ops.aten.squeeze_copy.dim, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, + torch.ops.aten.unsqueeze.default, + torch.ops.aten.unsqueeze_copy.default, + torch.ops.aten.reshape.default, + torch.ops.aten.repeat.default, + torch.ops.aten.expand_copy.default, + torch.ops.aten.expand.default, + # Disabling these as there seems to be an issue with support for complex + # datatypes in torch: + # torch.ops.aten.view_as_complex.default, + # torch.ops.aten.view_as_complex_copy.default, + # torch.ops.aten.view_as_real.default, + # torch.ops.aten.view_as_real_copy.default, + torch.ops.aten.view.default, + torch.ops.aten.view_as.default, + torch.ops.aten.view_copy.default, + torch.ops.aten.select.int, + torch.ops.aten.select_copy.int, + torch.ops.aten.slice.Tensor, + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes.default, + torch.ops.aten.transpose.Dimname, + torch.ops.aten.transpose.int, + torch.ops.aten.transpose_copy.int, + torch.ops.aten.tile.default, + torch.ops.aten.flip.default, + torch.ops.aten.chunk.default, + torch.ops.aten.contiguous.default, + torch.ops.aten.upsample_nearest2d.vec, +] + +# Operators that can inherit the quantization specs from its parent node +# as SharedQuantizationSpec. 
+_parent_shared_qspec = [ + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.relu.default, + torch.ops.aten.mean.default, + torch.ops.aten.mean.dim, + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.avg_pool2d.default, + torch.ops.aten.max_pool2d.default, + torch.ops.aten.full.default, + torch.ops.aten.flatten.using_ints, + torch.ops.aten.dropout.default, + operator.getitem, +] + + +def get_quant_properties( # noqa: C901 + node: Node, gm: torch.fx.GraphModule, quantization_config +) -> _OpQuantProperties: + input_act_qspec = quantization_config.get_input_act_qspec() + weight_qspec = quantization_config.get_weight_qspec() + output_act_qspec = quantization_config.get_output_act_qspec() + bias_qspec = quantization_config.get_bias_qspec() + + quant_properties = _OpQuantProperties() + + def any_or_hardtanh_min_zero(n: Node): + # Check that if the node is a hardtanh, its min_val is zero + return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0 + + if _match_pattern( + node, + [ + [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, + ], + [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default], + ], + any_or_hardtanh_min_zero, + ): + if node.target in ( + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, + ): + quant_properties.quant_inputs = [ + _QuantProperty(0, input_act_qspec), + _QuantProperty(1, weight_qspec, mark_annotated=True), + _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True), + ] + else: + quant_properties.quant_output = _QuantProperty(0, output_act_qspec) + elif node.target in ( + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, + ): + quant_properties.quant_inputs = [ + _QuantProperty(0, input_act_qspec), + _QuantProperty(1, weight_qspec, mark_annotated=True), + _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True), + ] + quant_properties.quant_output = _QuantProperty(0, output_act_qspec) + elif node.target in ( + torch.ops.aten.matmul.default, + torch.ops.aten.mm.default, + torch.ops.aten.bmm.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.mul_.Tensor, + ): + quant_properties.quant_inputs = [ + _QuantProperty(0, input_act_qspec), + _QuantProperty(1, input_act_qspec), + ] + quant_properties.quant_output = _QuantProperty(0, output_act_qspec) + elif node.target in ( + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + torch.ops.aten.sub.Tensor, + torch.ops.aten.sub_.Tensor, + torch.ops.aten.minimum.default, + torch.ops.aten.maximum.default, + ): + shared_qspec = SharedQuantizationSpec((node.args[0], node)) + quant_properties.quant_inputs = [ + _QuantProperty(0, input_act_qspec), + _QuantProperty( + 1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec + ), + ] + quant_properties.quant_output = _QuantProperty(0, shared_qspec) + elif node.target == torch.ops.aten.adaptive_avg_pool2d.default: + input_qspec = ( + SharedQuantizationSpec(node.args[0]) + if arm_quantizer_utils.is_output_annotated(node.args[0]) + else input_act_qspec + ) + quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)] + quant_properties.quant_output = _QuantProperty( + 0, SharedQuantizationSpec((node.args[0], node)) + ) + elif node.target in ( + torch.ops.aten.cat.default, + torch.ops.aten.concatenate.default, + torch.ops.aten.stack.default, + ): + assert isinstance(node.args[0], list) + assert len(node.args[0]) 
!= 0 + + shared_qspec = SharedQuantizationSpec((node.args[0][0], node)) + quant_properties.quant_inputs = [ + _QuantProperty( + 0, + [ + input_act_qspec if n == node.args[0][0] else shared_qspec + for n in node.args[0] + ], + ) + ] + quant_properties.quant_output = _QuantProperty(0, shared_qspec) + elif node.target in _one_to_one: + quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] + quant_properties.quant_output = _QuantProperty(0, output_act_qspec) + elif node.target in _one_to_one_shared_input_qspec: + quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] + quant_properties.quant_output = _QuantProperty( + 0, SharedQuantizationSpec((node.args[0], node)) + ) + elif node.target in _parent_shared_qspec: + if not isinstance(node.args[0], Node): + return None + + if not arm_quantizer_utils.is_output_annotated(node.args[0]): + return None + + shared_qspec = SharedQuantizationSpec(node.args[0]) + quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] + quant_properties.quant_output = _QuantProperty(0, shared_qspec) + else: + return None + + # Don't check if operator.getitem is ok for quantization, it's always ok + if node.target == operator.getitem: + return quant_properties + + # Check that each inputs/outputs can be quantized properly with the + # provided QuantProperties + for quant_property in quant_properties.quant_inputs: + if not _is_ok_for_quantization(node, quant_property, gm): + return None + + if quant_properties.quant_output is not None: + if not _is_ok_for_quantization(node, quant_properties.quant_output, gm): + return None + + return quant_properties + + +def annotate_graph( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + for node in gm.graph.nodes: + if node.op != "call_function": + continue + + if arm_quantizer_utils.is_annotated(node): + continue + + if filter_fn is not None and not filter_fn(node): + continue + + quant_properties = get_quant_properties(node, gm, quantization_config) + if quant_properties is None: + continue + + for quant_property in quant_properties.quant_inputs: + _annotate_input(node, quant_property) + + if quant_properties.quant_output is not None: + _annotate_output(node, quant_properties.quant_output) + + arm_quantizer_utils.mark_node_as_annotated(node) diff --git a/backends/arm/test/.coveragerc b/backends/arm/test/.coveragerc new file mode 100644 index 0000000000..4172a3d8fc --- /dev/null +++ b/backends/arm/test/.coveragerc @@ -0,0 +1,8 @@ +[run] +omit = + *__init__.py* + +[report] +skip_covered = true +exclude_also = + raise NotImplementedError diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index 17083b129f..4a5615f97c 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -137,10 +137,11 @@ class ComboConvRelu6(torch.nn.Module): ] test_data = [ - (20 * torch.randn(1, 3, 256, 256),), - (5 * torch.randn(1, 3, 256, 256),), + (2 * torch.randn(1, 3, 256, 256),), + (0.5 * torch.randn(1, 3, 256, 256),), (torch.randn(1, 3, 256, 256),), - (-5 * torch.randn(1, 3, 256, 256),), + (-0.5 * torch.randn(1, 3, 256, 256),), + (-2 * torch.randn(1, 3, 256, 256),), ] def __init__(self): diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py index 084eec138f..915b1fe7e0 100644 --- a/backends/arm/test/ops/test_expand.py +++ b/backends/arm/test/ops/test_expand.py @@ -34,12 +34,13 @@ 
class TestSimpleExpand(unittest.TestCase): class Expand(torch.nn.Module): # (input tensor, multiples) test_parameters = [ - (torch.ones(1), (2,)), - (torch.ones(1, 4), (1, -1)), - (torch.ones(1, 1, 2, 2), (4, 3, -1, 2)), - (torch.ones(1), (2, 2, 4)), - (torch.ones(3, 2, 4, 1), (-1, -1, -1, 3)), - (torch.ones(1, 1, 192), (1, -1, -1)), + (torch.rand(1), (2,)), + (torch.randn(1, 4), (1, -1)), + (torch.rand(1, 1, 2, 2), (4, 3, -1, 2)), + (torch.randn(1), (2, 2, 4)), + (torch.rand(3, 2, 4, 1), (-1, -1, -1, 3)), + (torch.randn(1, 1, 192), (1, -1, -1)), + (torch.randn(10, 1, 1, 97), (-1, 4, -1, -1)), ] def forward(self, x: torch.Tensor, multiples: Sequence): diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 9ae1a27cf7..b9dc4e394d 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -178,7 +178,7 @@ def __init__( self.output_name: str = None self.qp_input: list[QuantizationParams] = None self.qp_output: QuantizationParams = None - self.timeout = 120 + self.timeout = 480 self.target_board: str = None self._has_init_run = False @@ -316,7 +316,7 @@ def run_corstone( result = _run_cmd(command_args[self.target_board], check=False) if result.returncode != 0: raise RuntimeError( - f"Failed to run {command_args[self.target_board]}\nError: {result.stderr.decode()}" + f"Failed to run {command_args[self.target_board]}\nOutput:\n{result.stdout.decode()}\nError: {result.stderr.decode()}" ) result_stdout = result.stdout.decode() diff --git a/backends/arm/test/setup_testing.sh b/backends/arm/test/setup_testing.sh index 0562604dd9..3681917865 100755 --- a/backends/arm/test/setup_testing.sh +++ b/backends/arm/test/setup_testing.sh @@ -14,7 +14,6 @@ ethos_u_root_dir=${et_root_dir}/examples/arm/ethos-u-scratch/ethos-u toolchain_cmake=${et_root_dir}/examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake et_build_dir=${et_root_dir}/cmake-out build_root_test_dir=${et_build_dir}/arm_semihosting_executor_runner -fvp_model=FVP_Corstone_SSE-300_Ethos-U55 # Build Arm Baremetal executor_runner in semihosting mode. # Put in backends/arm/test/res to be used by unit tests. diff --git a/backends/arm/test/test_arm_baremetal.sh b/backends/arm/test/test_arm_baremetal.sh new file mode 100755 index 0000000000..d42b65e6c6 --- /dev/null +++ b/backends/arm/test/test_arm_baremetal.sh @@ -0,0 +1,106 @@ +#!/bin/bash +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) + +# Executorch root +et_root_dir=$(cd ${script_dir}/../../.. && pwd) +cd "${et_root_dir}" +pwd + + +TEST_SUITE=$1 + +help() { + echo "Usage:" + echo " $0 " + echo " where can be any of:" + # This will list all lines in this file that is starting with test_ remove () { and print it as a list. + # e,g, "test_pytest() { # Test ops and other things" -> test_pytest # Test ops and other things + echo "all # run all tests" + grep "^test_" $0 | sed 's/([^)]*)[[:space:]]*{*//g' + exit +} + +if [[ -z "${TEST_SUITE:-}" ]]; then + echo "Missing test suite name, exiting..." + help +else + echo "Run Arm baremetal test suite ${TEST_SUITE}" +fi + +TEST_SUITE_NAME="$(basename "$0") ${TEST_SUITE}" + +all() { # Run all tests + # This will list all lines in this file that is starting with test_ remove () { and add this script name in + # front of it and execute it in a sub shell + # e.g. 
from this file: + # + # test_pytest() { # Test ops and other things + # bla bla bla + # } + # test_pytest_ethosu_fvp() { # Same as test_pytest but ... + # bla bla bla + # } + #... + # become a small script: + # ---- + # backends/arm/test/test_arm_baremetal.sh test_pytest # Test ops and other things + # backends/arm/test/test_arm_baremetal.sh test_pytest_ethosu_fvp # Same as test_pytest but ... + # ... + # ---- + # That is executed + echo "${TEST_SUITE_NAME}: Run all tests" + grep "^test_" backends/arm/test/test_arm_baremetal.sh | sed 's/([^)]*)[[:space:]]*{*//g' | sed "s|^|$0 |" | sh +} + +test_pytest() { # Test ops and other things + echo "${TEST_SUITE_NAME}: Run pytest" + cd "${et_root_dir}" + source examples/arm/ethos-u-scratch/setup_path.sh + + # Run arm baremetal pytest tests without FVP + pytest --verbose --color=yes --numprocesses=auto backends/arm/test/ +} + +test_pytest_ethosu_fvp() { # Same as test_pytest but also sometime verify using Corstone FVP + echo "${TEST_SUITE_NAME}: Run pytest with fvp" + + source examples/arm/ethos-u-scratch/setup_path.sh + + # Prepare Corstone-3x0 FVP for pytest + examples/arm/run.sh --model_name=add --build_only + backends/arm/test/setup_testing.sh + + # Run arm baremetal pytest tests with FVP + pytest --verbose --color=yes --numprocesses=auto backends/arm/test/ --arm_quantize_io --arm_run_corstoneFVP +} + +test_run_ethosu_fvp() { # End to End model tests + echo "${TEST_SUITE_NAME}: Test ethos-u delegate examples with run.sh" + + source examples/arm/ethos-u-scratch/setup_path.sh + + # TOSA quantized + echo "${TEST_SUITE_NAME}: Test ethos-u target TOSA" + examples/arm/run.sh --target=TOSA --model_name=mv2 + examples/arm/run.sh --target=TOSA --model_name=lstm + examples/arm/run.sh --target=TOSA --model_name=edsr + + # Ethos-U55 + echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U55" + examples/arm/run.sh --target=ethos-u55-128 --model_name=mv2 + examples/arm/run.sh --target=ethos-u55-128 --model_name=lstm --reorder_inputs=1,0,2 + + # Ethos-U85 + echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U85" + examples/arm/run.sh --target=ethos-u85-128 --model_name=mv2 + examples/arm/run.sh --target=ethos-u85-128 --model_name=lstm --reorder_inputs=1,0,2 + } + +${TEST_SUITE} \ No newline at end of file diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index f663b16a9d..b3f5b4f05b 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -251,7 +251,7 @@ def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] | None = Non return super().to_executorch(to_executorch_stage) def serialize( - self, serialize_stage: Optional[Serialize] = None, timeout: int = 120 + self, serialize_stage: Optional[Serialize] = None, timeout: int = 480 ): if serialize_stage is None: serialize_stage = Serialize(self.runner_util, timeout=timeout) diff --git a/backends/cadence/CMakeLists.txt b/backends/cadence/CMakeLists.txt index 6c71909c47..e2ac3de5ca 100644 --- a/backends/cadence/CMakeLists.txt +++ b/backends/cadence/CMakeLists.txt @@ -60,9 +60,6 @@ if(EXECUTORCH_CADENCE_CPU_RUNNER) ${_common_include_directories} ) - set(TARGET_DIR reference) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/kernels) - target_link_libraries( cadence_runner executorch @@ -78,12 +75,12 @@ endif() if(EXECUTORCH_NNLIB_OPT) set(TARGET_DIR hifi) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/third-party/nnlib) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/kernels) -endif() 
- -if(EXECUTORCH_FUSION_G3_OPT) +elseif(EXECUTORCH_FUSION_G3_OPT) set(TARGET_DIR fusion_g3) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/third-party/nnlib) +else() + set(TARGET_DIR reference) endif() +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/kernels) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/operators) diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index c19a4296f6..bd0a45227e 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -17,7 +17,10 @@ print_memory_planning_info, ) from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion -from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer +from executorch.backends.cadence.aot.quantizer.quantizer import ( + CadenceDefaultQuantizer, + CadenceQuantizer, +) from executorch.backends.cadence.aot.utils import ( get_default_memory_config, MemoryConfig, @@ -136,7 +139,7 @@ def quantize_pt2( # Instantiate the quantizer to CadenceQuantizer if not supplied if not quantizer: - quantizer = CadenceQuantizer() + quantizer = CadenceDefaultQuantizer() # Get converted graph module converted_gm = convert_pt2(model, inputs, quantizer) diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py index 4ba5bffc96..28a1a60a2a 100644 --- a/backends/cadence/aot/export_example.py +++ b/backends/cadence/aot/export_example.py @@ -20,7 +20,7 @@ fuse_pt2, ) -from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer +from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer from executorch.backends.cadence.runtime import runtime from executorch.backends.cadence.runtime.executor import BundledProgramManager from executorch.exir import ExecutorchProgramManager @@ -74,7 +74,7 @@ def export_model( ) # Instantiate the quantizer - quantizer = CadenceQuantizer(qconfig) + quantizer = CadenceDefaultQuantizer(qconfig) # Convert the model converted_model = convert_pt2(model, example_inputs, quantizer) diff --git a/backends/cadence/aot/functions_fusion_g3.yaml b/backends/cadence/aot/functions_fusion_g3.yaml index f1f934b970..5ca0554480 100644 --- a/backends/cadence/aot/functions_fusion_g3.yaml +++ b/backends/cadence/aot/functions_fusion_g3.yaml @@ -50,12 +50,12 @@ - op: div.out kernels: - arg_meta: null - kernel_name: torch::executor::div_out + kernel_name: cadence::impl::G3::div_out - op: div.out_mode kernels: - arg_meta: null - kernel_name: torch::executor::div_out_mode + kernel_name: cadence::impl::G3::div_out_mode - op: embedding.out kernels: @@ -71,7 +71,6 @@ kernels: - arg_meta: null kernel_name: cadence::impl::G3::mul_out - - op: mul.Scalar_out kernels: - arg_meta: null @@ -80,7 +79,7 @@ - op: permute_copy.out kernels: - arg_meta: null - kernel_name: torch::executor::permute_copy_out + kernel_name: cadence::impl::G3::permute_copy_out - op: sigmoid.out kernels: @@ -90,7 +89,7 @@ - op: slice_copy.Tensor_out kernels: - arg_meta: null - kernel_name: torch::executor::slice_copy_Tensor_out + kernel_name: cadence::impl::G3::slice_copy_Tensor_out - op: split_with_sizes_copy.out kernels: @@ -100,7 +99,12 @@ - op: sub.out kernels: - arg_meta: null - kernel_name: torch::executor::sub_out + kernel_name: cadence::impl::G3::sub_out + +- op: sub.Scalar_out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::sub_scalar_out - op: view_copy.out kernels: @@ -117,6 +121,16 @@ - arg_meta: null kernel_name: 
cadence::impl::G3::native_layer_norm_out +- op: mean.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::mean_dim_out + +- op: exp.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::exp_out + # custom ops - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) variants: function diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 73ca40c9aa..65979919ed 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -60,6 +60,13 @@ bias_qspec: Optional[QuantizationSpec] = None +_default_qconfig = QuantizationConfig( + act_qspec, + act_qspec, + wgt_qspec, + None, +) + class CadenceAtenQuantizer(Quantizer): def __init__( @@ -140,31 +147,39 @@ def get_supported_operators(cls) -> List[OperatorConfig]: return [] +def get_cadence_default_quantizer_list_with_config( + quantization_config: QuantizationConfig, +) -> List[Quantizer]: + return [ + CadenceAtenQuantizer(AddmmPattern(), quantization_config), + CadenceAtenQuantizer(BmmPattern(), quantization_config), + CadenceAtenQuantizer(Conv1dPattern(), quantization_config), + CadenceAtenQuantizer(Conv2dPattern(), quantization_config), + CadenceAtenQuantizer(LayerNormPattern(), quantization_config), + CadenceAtenQuantizer(LinearPattern(), quantization_config), + CadenceAtenQuantizer(MatmulPattern(), quantization_config), + CadenceAtenQuantizer(ReluPattern0(), quantization_config), + CadenceAtenQuantizer(ReluPattern1(), quantization_config), + ] + + class CadenceQuantizer(ComposableQuantizer): - def __init__( - self, quantization_config: Optional[QuantizationConfig] = None - ) -> None: - static_qconfig = ( - QuantizationConfig( - act_qspec, - act_qspec, - wgt_qspec, - None, - ) - if not quantization_config - else quantization_config - ) + """ + Generic CadenceQuantizer. Although it can be used directly, it is typically a base + class for explicitly defined quantizers (like CadenceDefaultQuantizer). + """ - super().__init__( - [ - CadenceAtenQuantizer(AddmmPattern(), static_qconfig), - CadenceAtenQuantizer(BmmPattern(), static_qconfig), - CadenceAtenQuantizer(Conv1dPattern(), static_qconfig), - CadenceAtenQuantizer(Conv2dPattern(), static_qconfig), - CadenceAtenQuantizer(LayerNormPattern(), static_qconfig), - CadenceAtenQuantizer(LinearPattern(), static_qconfig), - CadenceAtenQuantizer(MatmulPattern(), static_qconfig), - CadenceAtenQuantizer(ReluPattern0(), static_qconfig), - CadenceAtenQuantizer(ReluPattern1(), static_qconfig), - ] - ) + def __init__(self, quantizers: List[Quantizer]) -> None: + super().__init__(quantizers) + + +class CadenceDefaultQuantizer(CadenceQuantizer): + """ + Default quantizer for Cadence backend. 
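A minimal sketch of the new composition model introduced here. It assumes a QuantizationConfig named qconfig is already built (as export_example.py does); everything else is imported from this diff's quantizer module:

```python
# Sketch only; import path follows the module being modified in this diff.
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceDefaultQuantizer,
    CadenceQuantizer,
    get_cadence_default_quantizer_list_with_config,
)

# Previous callers of CadenceQuantizer() now use the explicit default:
quantizer = CadenceDefaultQuantizer()

# A project-specific quantizer can be composed from any subset of the default
# per-pattern quantizers instead of subclassing:
def make_custom_quantizer(qconfig):
    per_pattern = get_cadence_default_quantizer_list_with_config(qconfig)
    # Keep, drop, or reorder entries here as needed for the target workload.
    return CadenceQuantizer(per_pattern)
```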
+ """ + + def __init__(self, qconfig: Optional[QuantizationConfig] = None) -> None: + if qconfig is None: + qconfig = _default_qconfig + quantizers = get_cadence_default_quantizer_list_with_config(qconfig) + super().__init__(quantizers) diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index 25a32a5f07..231096c3ab 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -12,7 +12,7 @@ from executorch.backends.cadence.aot.compiler import export_to_edge from executorch.backends.cadence.aot.pass_utils import count_node -from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer +from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer from executorch.backends.cadence.aot.remove_ops import ( RemoveAliasCopyOpPass, RemoveCloneOpPass, @@ -465,7 +465,7 @@ def forward(self, x): # Run the standard quant/convert steps, but without fusing # this leaves two redundant quant/dequant pairs to test with - quantizer = CadenceQuantizer() + quantizer = CadenceDefaultQuantizer() model_exp = export_for_training(M(), (inp,)).module() prepared_model = prepare_pt2e(model_exp, quantizer) prepared_model(inp) diff --git a/backends/cadence/fusion_g3/operators/CMakeLists.txt b/backends/cadence/fusion_g3/operators/CMakeLists.txt index 704b4aa741..cac16bddc5 100644 --- a/backends/cadence/fusion_g3/operators/CMakeLists.txt +++ b/backends/cadence/fusion_g3/operators/CMakeLists.txt @@ -36,6 +36,12 @@ set(_aten_ops__srcs "${CMAKE_CURRENT_SOURCE_DIR}/op_native_layer_norm.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/op_quantize.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/op_dequantize.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_sub.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_div.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_mean.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_slice_copy.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_permute_copy.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_exp.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp" @@ -51,6 +57,7 @@ set(_aten_ops__srcs "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp" ) add_library(aten_ops_cadence ${_aten_ops__srcs}) target_link_libraries(aten_ops_cadence PUBLIC executorch) diff --git a/backends/cadence/fusion_g3/operators/op_add.cpp b/backends/cadence/fusion_g3/operators/op_add.cpp index f40fcc973b..d51fee5338 100644 --- a/backends/cadence/fusion_g3/operators/op_add.cpp +++ b/backends/cadence/fusion_g3/operators/op_add.cpp @@ -39,6 +39,7 @@ Tensor& add_out( ScalarType common_type = executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); +#ifdef OP_ARG_CHECK // Check Common Dtype ET_KERNEL_CHECK( ctx, @@ -62,12 +63,12 @@ Tensor& add_out( torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok, InvalidArgument, out); +#endif // Compute Dtype ScalarType compute_type = torch::executor::native::utils::get_compute_type(common_type); - // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "add.out"; int kTensorDimensionLimit = 5; @@ -253,6 +254,7 @@ Tensor& add_scalar_out( torch::executor::native::utils::promote_type_with_scalar( 
a.scalar_type(), b); +#ifdef OP_ARG_CHECK // Check Common Dtype ET_KERNEL_CHECK( ctx, @@ -276,7 +278,7 @@ Tensor& add_scalar_out( executorch::runtime::resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); - +#endif // Compute Dtype ScalarType compute_type = torch::executor::native::utils::get_compute_type(common_type); diff --git a/backends/cadence/fusion_g3/operators/op_cat.cpp b/backends/cadence/fusion_g3/operators/op_cat.cpp index f0f327c024..74fd96a212 100644 --- a/backends/cadence/fusion_g3/operators/op_cat.cpp +++ b/backends/cadence/fusion_g3/operators/op_cat.cpp @@ -6,13 +6,18 @@ * LICENSE file in the root directory of this source tree. */ +#include +#include + #include #include +#include #include #include +using ::executorch::aten::ArrayRef; using ::executorch::aten::ScalarType; using ::executorch::aten::Tensor; using ::executorch::runtime::Error; @@ -23,7 +28,6 @@ using ::executorch::runtime::KernelRuntimeContext; * updated to have support for below data types, these can be removed and * operator need to be updated accordingly */ -enum datatype { Ushort = 20, Uint = 23 }; namespace cadence { namespace impl { @@ -32,20 +36,22 @@ namespace native { Tensor& cat_out( KernelRuntimeContext& ctx, - exec_aten::ArrayRef tensors, + ArrayRef tensors, int64_t dim, Tensor& out) { if (dim < 0) { dim += out.dim(); } + int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit; + +#ifdef OP_ARG_CHECK ET_KERNEL_CHECK( ctx, torch::executor::check_cat_args(tensors, dim, out), InvalidArgument, out); - int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit; Tensor::SizesType expected_out_size[kTensorDimensionLimit]; size_t expected_out_dim = 0; torch::executor::get_cat_out_target_size( @@ -57,6 +63,20 @@ Tensor& cat_out( out, {expected_out_size, expected_out_dim}) == Error::Ok, InvalidArgument, out); +#endif + // Special handling when all inputs are 1D-empty tensors for aten + // consistency In that case, just return an 1D-empty tensor without checking + // dim + bool all_1d_empty = true; + for (size_t i = 0; i < tensors.size(); ++i) { + if (tensors[i].numel() != 0 || tensors[i].dim() != 1) { + all_1d_empty = false; + break; + } + } + if (all_1d_empty) { + return out; + } const signed char* inp_tensors[tensors.size()]; const int* inp_tensors_shapes[tensors.size()]; @@ -64,7 +84,7 @@ Tensor& cat_out( int inp_shapes_size[tensors.size()]; int temp_sizes[tensors.size()][kTensorDimensionLimit]; - exec_aten::ArrayRef temp_size; + ArrayRef temp_size; for (int i = 0; i < tensors.size(); i++) { inp_tensors[i] = tensors[i].const_data_ptr(); @@ -79,55 +99,23 @@ Tensor& cat_out( signed char* out_data = out.mutable_data_ptr(); - const exec_aten::ArrayRef out_size = out.sizes(); + const ArrayRef out_size = out.sizes(); int out_shapes[kTensorDimensionLimit]; for (int i = 0; i < out_size.size(); i++) // output shapes { out_shapes[i] = out_size[i]; } - if (out.scalar_type() == ScalarType::Int) { - xa_nn_cat( - out_data, - out_shapes, - inp_tensors, - inp_tensors_shapes, - inp_shapes_size[0], - tensors.size(), - (int)dim, - sizeof(int)); - } else if (out.scalar_type() == ScalarType::Short) { - xa_nn_cat( - out_data, - out_shapes, - inp_tensors, - inp_tensors_shapes, - inp_shapes_size[0], - tensors.size(), - (int)dim, - sizeof(short)); - } else if (out.scalar_type() == ScalarType::Char) { - xa_nn_cat( - out_data, - out_shapes, - inp_tensors, - inp_tensors_shapes, - inp_shapes_size[0], - tensors.size(), - (int)dim, - sizeof(char)); - } else if (out.scalar_type() == 
(ScalarType)Uint) { - xa_nn_cat( - out_data, - out_shapes, - inp_tensors, - inp_tensors_shapes, - inp_shapes_size[0], - tensors.size(), - (int)dim, - sizeof(int)); - } else if (out.scalar_type() == (ScalarType)Ushort) { - xa_nn_cat( + if ((out.scalar_type() == ScalarType::Int) || + (out.scalar_type() == ScalarType::Short) || + (out.scalar_type() == ScalarType::Char) || + (out.scalar_type() == ScalarType::UInt32) || + (out.scalar_type() == ScalarType::UInt16) || + (out.scalar_type() == ScalarType::Byte)) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_cat, out_data, out_shapes, inp_tensors, @@ -135,32 +123,8 @@ Tensor& cat_out( inp_shapes_size[0], tensors.size(), (int)dim, - sizeof(short)); - } else if (out.scalar_type() == ScalarType::Byte) { - xa_nn_cat( - out_data, - out_shapes, - inp_tensors, - inp_tensors_shapes, - inp_shapes_size[0], - tensors.size(), - (int)dim, - sizeof(char)); - + get_element_size(out.scalar_type())); } else { - // Special handling when all inputs are 1D-empty tensors for aten - // consistency In that case, just return an 1D-empty tensor without checking - // dim - bool all_1d_empty = true; - for (size_t i = 0; i < tensors.size(); ++i) { - if (tensors[i].numel() != 0 || tensors[i].dim() != 1) { - all_1d_empty = false; - break; - } - } - if (all_1d_empty) { - return out; - } const size_t outer = executorch::runtime::getLeadingDims(out, dim); const size_t dim_stride = executorch::runtime::getTrailingDims(out, dim); const size_t ninputs = tensors.size(); diff --git a/backends/cadence/fusion_g3/operators/op_dequantize.cpp b/backends/cadence/fusion_g3/operators/op_dequantize.cpp index ed5b3125ac..cff50f2a90 100644 --- a/backends/cadence/fusion_g3/operators/op_dequantize.cpp +++ b/backends/cadence/fusion_g3/operators/op_dequantize.cpp @@ -6,30 +6,32 @@ * LICENSE file in the root directory of this source tree. */ +#include + #include #include #include #include +#include #include #include -using ::executorch::aten::Scalar; using ::executorch::aten::ScalarType; using ::executorch::aten::Tensor; using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; template -using optional = exec_aten::optional; +using optional = ::executorch::aten::optional; /* ScalarType in Executorch do not have support for below data types. * So, creating a placeholder for these data types. Once, ScalarTypes is * updated to have support for below data types, these can be removed and * operator need to be updated accordingly */ -enum datatype { Ushort = 20, Bits4u = 21, Bits4 = 22 }; +enum datatype { Bits4u = 21, Bits4 = 22 }; /** * For an input tensor, use the scale and zero_point arguments to quantize it. 
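Before the kernel-dispatch changes below, it may help to keep the reference math in mind: every xa_nn_elm_dequantize_* call in this file computes an affine dequantization, and the asymmetric/symmetric split only decides whether a zero-point is subtracted. A plain PyTorch sketch of the per-tensor case (illustrative only, not code from the diff):

```python
# Reference per-tensor dequantization that the G3 kernels specialize by dtype;
# the function name and the use of torch here are illustrative.
import torch

def dequantize_per_tensor_ref(q: torch.Tensor, scale: float, zero_point: int = 0) -> torch.Tensor:
    # Promote to float before subtracting so narrow integer inputs do not wrap;
    # with zero_point == 0 this is the symmetric path.
    return (q.to(torch.float32) - float(zero_point)) * float(scale)
```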
@@ -49,14 +51,13 @@ void check_dequantize_per_tensor_args( int64_t quant_min, int64_t quant_max, ScalarType dtype, - exec_aten::optional& out_dtype, + ::executorch::aten::optional& out_dtype, Tensor& out) { ET_CHECK_MSG( input.scalar_type() == ScalarType::Byte || input.scalar_type() == ScalarType::Char || input.scalar_type() == ScalarType::UInt16 || input.scalar_type() == ScalarType::Short || - input.scalar_type() == (ScalarType)Ushort || input.scalar_type() == (ScalarType)Bits4 || input.scalar_type() == (ScalarType)Bits4u || input.scalar_type() == ScalarType::Int, @@ -85,14 +86,16 @@ void check_dequantize_per_tensor_args( } // namespace /* Local function which calls the kernels based on the input datatype */ -void dequantize_impl( +Tensor& dequantize_impl( + KernelRuntimeContext& ctx, Tensor& out, const Tensor& input, float* scale_data, int* zero_point_data, int* axis, - exec_aten::optional out_dtype) { - const exec_aten::ArrayRef input_size = input.sizes(); + ::executorch::aten::optional out_dtype) { + const ::executorch::aten::ArrayRef input_size = + input.sizes(); int kTensorDimensionLimit = 5; @@ -125,7 +128,10 @@ void dequantize_impl( if (is_asym_dequant) { if (input.scalar_type() == ScalarType::Byte) { const uint8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym8u_f32( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_dequantize_asym8u_f32, out_data, input_data, inp_shape, @@ -135,7 +141,10 @@ void dequantize_impl( scale_data); } else if (input.scalar_type() == ScalarType::Char) { const int8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym8_f32( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_dequantize_asym8_f32, out_data, input_data, inp_shape, @@ -143,9 +152,12 @@ void dequantize_impl( axis, zero_point_data, scale_data); - } else if (input.scalar_type() == (ScalarType)Ushort) { + } else if (input.scalar_type() == ScalarType::UInt16) { const uint16_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym16u_f32( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_dequantize_asym16u_f32, out_data, input_data, inp_shape, @@ -155,7 +167,10 @@ void dequantize_impl( scale_data); } else if (input.scalar_type() == ScalarType::Short) { const int16_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym16_f32( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_dequantize_asym16_f32, out_data, input_data, inp_shape, @@ -165,7 +180,10 @@ void dequantize_impl( scale_data); } else if (input.scalar_type() == (ScalarType)Bits4u) { const uint8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym4u_f32( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_dequantize_asym4u_f32, out_data, input_data, inp_shape, @@ -175,7 +193,10 @@ void dequantize_impl( scale_data); } else if (input.scalar_type() == (ScalarType)Bits4) { const int8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym4_f32( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_dequantize_asym4_f32, out_data, input_data, inp_shape, @@ -233,8 +254,9 @@ void dequantize_impl( } } - exec_aten::optional> optional_dim_list{ - exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + ::executorch::aten::optional<::executorch::aten::ArrayRef> + optional_dim_list{::executorch::aten::ArrayRef{ + dims, size_t(input.dim() - 1)}}; // Actual dequantization logic // input, out are the input and output tensors @@ -318,28 +340,76 @@ void dequantize_impl( } else { if (input.scalar_type() == ScalarType::Byte) { const uint8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym8u_f32( - 
out_data, input_data, inp_shape, input.dim(), axis, scale_data); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_dequantize_sym8u_f32, + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data); } else if (input.scalar_type() == ScalarType::Char) { const int8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym8_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); - } else if (input.scalar_type() == (ScalarType)Ushort) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_dequantize_sym8_f32, + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data); + } else if (input.scalar_type() == ScalarType::UInt16) { const uint16_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym16u_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_dequantize_sym16u_f32, + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data); } else if (input.scalar_type() == ScalarType::Short) { const int16_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym16_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_dequantize_sym16_f32, + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data); } else if (input.scalar_type() == (ScalarType)Bits4u) { const uint8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym4u_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_dequantize_sym4u_f32, + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data); } else if (input.scalar_type() == (ScalarType)Bits4) { const int8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym4_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_dequantize_sym4_f32, + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data); } else { if (axis == NULL) { // calculate the dequantized output, cast scale to float to match fbgemm @@ -390,8 +460,9 @@ void dequantize_impl( } } - exec_aten::optional> optional_dim_list{ - exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + ::executorch::aten::optional<::executorch::aten::ArrayRef> + optional_dim_list{::executorch::aten::ArrayRef{ + dims, size_t(input.dim() - 1)}}; // Actual dequantization logic // input, out are the input and output tensors @@ -473,6 +544,7 @@ void dequantize_impl( } } } + return out; } /** @@ -485,14 +557,16 @@ void dequantize_impl( * info. 
*/ Tensor& dequantize_per_tensor_out( + KernelRuntimeContext& context, const Tensor& input, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max, ScalarType dtype, - exec_aten::optional out_dtype, + ::executorch::aten::optional out_dtype, Tensor& out) { +#ifdef OP_ARG_CHECK torch::executor::Error err = resize_tensor(out, input.sizes()); ET_CHECK_MSG( err == torch::executor::Error::Ok, @@ -500,24 +574,28 @@ Tensor& dequantize_per_tensor_out( check_dequantize_per_tensor_args( input, quant_min, quant_max, dtype, out_dtype, out); +#endif float scale_data = (float)scale; int zero_point_data = (int)zero_point; - dequantize_impl(out, input, &scale_data, &zero_point_data, NULL, out_dtype); + dequantize_impl( + context, out, input, &scale_data, &zero_point_data, NULL, out_dtype); return out; } Tensor& dequantize_per_tensor_tensor_args_out( + KernelRuntimeContext& context, const Tensor& input, const Tensor& scale, const Tensor& zero_point, int64_t quant_min, int64_t quant_max, ScalarType dtype, - exec_aten::optional out_dtype, + ::executorch::aten::optional out_dtype, Tensor& out) { +#ifdef OP_ARG_CHECK ET_CHECK_MSG( scale.scalar_type() == ScalarType::Double, "Expected scale to be Double tensor received: %" PRId8, @@ -534,8 +612,10 @@ Tensor& dequantize_per_tensor_tensor_args_out( zero_point.numel() == 1, "Exepcted zero_point to only have one element received: %zd", ssize_t(zero_point.numel())); +#endif dequantize_per_tensor_out( + context, input, scale.const_data_ptr()[0], zero_point.const_data_ptr()[0], @@ -549,15 +629,24 @@ Tensor& dequantize_per_tensor_tensor_args_out( } Tensor& dequantize_per_channel_out( + KernelRuntimeContext& context, const Tensor& input, const Tensor& scale, - const exec_aten::optional& opt_zero_points, + const ::executorch::aten::optional& opt_zero_points, int64_t axis, int64_t quant_min, int64_t quant_max, ScalarType dtype, - exec_aten::optional out_dtype, + ::executorch::aten::optional out_dtype, Tensor& out) { + if (axis < 0) { + axis += executorch::runtime::nonzero_dim(input); + } + /* if the arguments are passed properly to the operator disable the Macro - + * "OP_ARG_CHECK" if not the case, enable the Macro - "OP_ARG_CHECK", to have + * the checks only in operator level(As there are no checks in kernel). 
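+ * In short: define OP_ARG_CHECK to compile in the operator-level argument
+ * validation (tensor resize and dtype checks); leave it undefined when the
+ * caller guarantees valid arguments, since the kernel itself performs none.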
+ */ +#ifdef OP_ARG_CHECK torch::executor::Error err = resize_tensor(out, input.sizes()); // normalize axis @@ -567,10 +656,6 @@ Tensor& dequantize_per_channel_out( ssize_t(axis), ssize_t(input.dim())); - if (axis < 0) { - axis += executorch::runtime::nonzero_dim(input); - } - ET_CHECK_MSG( err == torch::executor::Error::Ok, "Failed to resize out Tensor in dequantize_per_channel_out"); @@ -599,9 +684,9 @@ Tensor& dequantize_per_channel_out( ssize_t(zero_point.numel()), ssize_t(input.size(axis))); } - check_dequantize_per_tensor_args( input, quant_min, quant_max, dtype, out_dtype, out); +#endif int* axis_ptr = (int*)&axis; @@ -622,80 +707,14 @@ Tensor& dequantize_per_channel_out( for (int i = 0; i < scale.numel(); i++) { scale_data[i] = (float)scale_dt[i]; } - dequantize_impl(out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype); + dequantize_impl( + context, out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype); return out; } -Tensor& dequantize_per_channel_out( - KernelRuntimeContext& context, - const Tensor& input, - const Tensor& scale, - const exec_aten::optional& opt_zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - exec_aten::optional out_dtype, - Tensor& out) { - (void)context; - torch::executor::Error err = resize_tensor(out, input.sizes()); - ET_CHECK_MSG( - err == torch::executor::Error::Ok, - "Failed to resize out Tensor in dequantize_per_channel_out"); - return dequantize_per_channel_out( - input, - scale, - opt_zero_points, - axis, - quant_min, - quant_max, - dtype, - out_dtype, - out); -} - -Tensor& dequantize_per_tensor_out( - KernelRuntimeContext& context, - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) { - // TODO(larryliu): Add a context arg to the real op function and remove this - // wrapper - (void)context; - return dequantize_per_tensor_out( - input, - scale, - zero_point, - quant_min, - quant_max, - dtype, - out.scalar_type(), - out); -} - -Tensor& dequantize_per_tensor_tensor_args_out( - KernelRuntimeContext& context, - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - exec_aten::optional out_dtype, - Tensor& out) { - // TODO(larryliu): Add a context arg to the real op function and remove this - // wrapper - (void)context; - return dequantize_per_tensor_tensor_args_out( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); -} - Tensor& dequantize_per_token_out( + KernelRuntimeContext& context, const Tensor& input, const Tensor& scale, const Tensor& zero_points, @@ -711,18 +730,18 @@ Tensor& dequantize_per_token_out( } // This unfortunate change is needed because we compile op_quantize for aten // mode as well - std::array input_sizes; - input_sizes[0] = static_cast(num_channels); + std::array<::executorch::aten::SizesType, 2> input_sizes; + input_sizes[0] = static_cast<::executorch::aten::SizesType>(num_channels); input_sizes[1] = - static_cast(input.size(input.dim() - 1)); + static_cast<::executorch::aten::SizesType>(input.size(input.dim() - 1)); #ifdef USE_ATEN_LIB Tensor reshaped_input = at::from_blob( input.mutable_data_ptr(), input_sizes, at::TensorOptions(input.scalar_type())); #else - std::array input_dim_order{0, 1}; - std::array input_strides; + std::array<::executorch::aten::DimOrderType, 2> input_dim_order{0, 1}; + std::array<::executorch::aten::StridesType, 2> input_strides; 
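  /* A minimal standalone sketch (hypothetical helper, not part of this patch)
   * of what dim_order_to_stride_nocheck computes for the contiguous {0, 1}
   * dim order used here, i.e. row-major strides over the 2-D per-token view:
   *
   *   void row_major_strides(const int* sizes, int* strides, int ndim) {
   *     int acc = 1;
   *     for (int i = ndim - 1; i >= 0; --i) {
   *       strides[i] = acc;   // innermost dimension gets stride 1
   *       acc *= sizes[i];
   *     }
   *   }
   *
   * For input_sizes = {num_channels, input.size(input.dim() - 1)} this yields
   * strides {last_dim, 1}, matching the reshaped 2-D input built above.
   */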
executorch::runtime::dim_order_to_stride_nocheck( input_sizes.data(), input_dim_order.data(), 2, input_strides.data()); void* input_data = input.mutable_data_ptr(); @@ -743,6 +762,7 @@ Tensor& dequantize_per_token_out( #endif return dequantize_per_channel_out( + context, reshaped_input, scale, zero_points, @@ -754,21 +774,6 @@ Tensor& dequantize_per_token_out( out); } -Tensor& dequantize_per_token_out( - KernelRuntimeContext& context, - const Tensor& input, - const Tensor& scale, - const Tensor& zero_points, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - ScalarType out_dtype, - Tensor& out) { - (void)context; - return dequantize_per_token_out( - input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); -} - } // namespace native } // namespace G3 } // namespace impl diff --git a/backends/cadence/fusion_g3/operators/op_div.cpp b/backends/cadence/fusion_g3/operators/op_div.cpp new file mode 100644 index 0000000000..1461f643a8 --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_div.cpp @@ -0,0 +1,674 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include + +using ::executorch::aten::optional; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::string_view; +using ::executorch::aten::Tensor; +using ::executorch::runtime::canCast; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +namespace { + +ScalarType get_common_type(ScalarType a_type, ScalarType b_type) { + if (executorch::runtime::isFloatingType(a_type) && + executorch::runtime::isFloatingType(b_type)) { + return executorch::runtime::promoteTypes(a_type, b_type); + } else if (executorch::runtime::isFloatingType(a_type)) { + return a_type; + } else if (executorch::runtime::isFloatingType(b_type)) { + return b_type; + } + return ScalarType::Float; +} + +} // namespace + +Tensor& div_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + Tensor& out) { + // Common Dtype + ScalarType common_type = + executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); + +#ifdef OP_ARG_CHECK + // Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, b, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, + torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); +#endif + + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "div.out"; + + int kTensorDimensionLimit = 5; + + int inp1_shape[kTensorDimensionLimit]; + int inp2_shape[kTensorDimensionLimit]; + int out_shape[kTensorDimensionLimit]; + + bool broadcast = 0; + + int max_dim = a.dim() > b.dim() ? a.dim() : b.dim(); + max_dim = out.dim() > max_dim ? 
out.dim() : max_dim; + + bool optimized = 1; + + for (int i = 0; i < max_dim; i++) { + out_shape[i] = 1; + inp1_shape[i] = 1; + inp2_shape[i] = 1; + } + + int offset_out = max_dim - out.dim(); + int offset_inp1 = max_dim - a.dim(); + int offset_inp2 = max_dim - b.dim(); + + for (int i = 0; i < out.dim(); i++) { + out_shape[i + offset_out] = out.size(i); + } + for (int i = 0; i < a.dim(); i++) { + inp1_shape[i + offset_inp1] = a.size(i); + } + for (int i = 0; i < b.dim(); i++) { + inp2_shape[i + offset_inp2] = b.size(i); + } + + /*find broadcast*/ + for (int i = 0; i < out.dim(); i++) { + if (((inp1_shape[i]) != (out_shape[i])) || + ((inp2_shape[i]) != (out_shape[i]))) { + broadcast = 1; + } + } + + if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) { + optimized = 0; + } + + if ((compute_type == ScalarType::Int) && (optimized)) { + const int* const inp1_data = a.const_data_ptr(); + const int* const inp2_data = b.const_data_ptr(); + float* const out_data = out.mutable_data_ptr(); + + if (b.numel() == 1) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_scalar_32x32_f32, + out_data, + inp1_data, + inp2_data[0], + out.numel()); + } else if (broadcast) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_broadcast_5D_32x32_f32, + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + max_dim); + } else { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_32x32_f32, + out_data, + inp1_data, + inp2_data, + out.numel()); + } + } else if ((compute_type == ScalarType::Float) && (optimized)) { + const float* const inp1_data = a.const_data_ptr(); + const float* const inp2_data = b.const_data_ptr(); + float* const out_data = out.mutable_data_ptr(); + + int mode = 0; + + if (b.numel() == 1) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_scalar_f32xf32_f32, + out_data, + inp1_data, + inp2_data[0], + mode, + out.numel()); + } else if (broadcast) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_broadcast_5D_f32xf32_f32, + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + mode, + max_dim); + } else { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_f32xf32_f32, + out_data, + inp1_data, + inp2_data, + mode, + out.numel()); + } + } else { + ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type()); + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + torch::executor::native::utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name>( + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a / val_b; + }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + b, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes::FLOATHBF16); + }); + } + + return out; +} + +Tensor& div_out_mode( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + optional mode, + Tensor& out) { + if (!mode.has_value()) { + return div_out(ctx, a, b, out); + } + + auto mode_val = mode.value(); + + // Check mode + ET_KERNEL_CHECK( + ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out); + + // Common Dtype + ScalarType common_type = + executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); + +#ifdef OP_ARG_CHECK + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (canCast(common_type, out.scalar_type()) && + common_type != ScalarType::Bool), + InvalidArgument, + out); + + 
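  /* The two accepted modes differ only in how the quotient is rounded:
   * "trunc" rounds toward zero, "floor" toward negative infinity. A worked
   * integer example (standalone sketch, not part of this patch):
   *
   *   int trunc_div(int a, int b) { return a / b; }   // -7 / 2  ->  -3
   *   int floor_div(int a, int b) {
   *     int q = a / b;
   *     // adjust when there is a remainder and the operands differ in sign
   *     return ((a % b != 0) && ((a < 0) != (b < 0))) ? q - 1 : q;  // -> -4
   *   }
   *
   * The operator passes this choice to the NNLib kernels below as mode_value
   * (1 for "trunc", 2 for "floor"); the portable fallback uses std::trunc and
   * floor_divide instead.
   */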
// Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, b, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, + torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); +#endif + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "div.out_mode"; + + const bool mode_is_trunc = mode_val == "trunc"; + bool div_by_zero_error = false; + + int kTensorDimensionLimit = 5; + + int inp1_shape[kTensorDimensionLimit]; + int inp2_shape[kTensorDimensionLimit]; + int out_shape[kTensorDimensionLimit]; + + bool broadcast = 0; + + int max_dim = a.dim() > b.dim() ? a.dim() : b.dim(); + max_dim = out.dim() > max_dim ? out.dim() : max_dim; + + bool optimized = 1; + + for (int i = 0; i < max_dim; i++) { + out_shape[i] = 1; + inp1_shape[i] = 1; + inp2_shape[i] = 1; + } + + int offset_out = max_dim - out.dim(); + int offset_inp1 = max_dim - a.dim(); + int offset_inp2 = max_dim - b.dim(); + + for (int i = 0; i < out.dim(); i++) { + out_shape[i + offset_out] = out.size(i); + } + for (int i = 0; i < a.dim(); i++) { + inp1_shape[i + offset_inp1] = a.size(i); + } + for (int i = 0; i < b.dim(); i++) { + inp2_shape[i + offset_inp2] = b.size(i); + } + + /*find broadcast*/ + for (int i = 0; i < out.dim(); i++) { + if (((inp1_shape[i]) != (out_shape[i])) || + ((inp2_shape[i]) != (out_shape[i]))) { + broadcast = 1; + } + } + + if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) { + optimized = 0; + } + + int mode_value = (mode_val == "trunc") ? 1 : 2; + + if ((compute_type == ScalarType::Int) && (optimized)) { + const int* const inp1_data = a.const_data_ptr(); + const int* const inp2_data = b.const_data_ptr(); + int* const out_data = out.mutable_data_ptr(); + + if (b.numel() == 1) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_scalar_32x32_32, + out_data, + inp1_data, + inp2_data[0], + mode_value, + out.numel()); + } else if (broadcast) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_broadcast_5D_32x32_32, + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + mode_value, + max_dim); + } else { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_32x32_32, + out_data, + inp1_data, + inp2_data, + mode_value, + out.numel()); + } + } else if ((compute_type == ScalarType::Float) && (optimized)) { + const float* const inp1_data = a.const_data_ptr(); + const float* const inp2_data = b.const_data_ptr(); + float* const out_data = out.mutable_data_ptr(); + + if (b.numel() == 1) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_scalar_f32xf32_f32, + out_data, + inp1_data, + inp2_data[0], + mode_value, + out.numel()); + } else if (broadcast) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_broadcast_5D_f32xf32_f32, + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + mode_value, + max_dim); + } else { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_f32xf32_f32, + out_data, + inp1_data, + inp2_data, + mode_value, + out.numel()); + } + } else { + ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + torch::executor::native::utils:: + apply_bitensor_elementwise_fn( + [mode_is_trunc, &div_by_zero_error]( + const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + if (executorch::runtime::is_integral_type< + CTYPE_COMPUTE, + /*includeBool=*/true>::value) { + if (val_b == 0) { + div_by_zero_error = 
true; + return static_cast(0); + } + } + CTYPE_COMPUTE value = val_a / val_b; + if (mode_is_trunc) { + value = std::trunc(value); + } else { + // We established above that the mode is either trunc or + // floor, so it must be floor. + value = torch::executor::native::utils::floor_divide( + val_a, val_b); + } + return value; + }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + b, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes::REALHBF16); + }); + } + + ET_KERNEL_CHECK_MSG( + ctx, + !div_by_zero_error, + InvalidArgument, + out, + "Div mode operation encountered integer division by zero"); + + return out; +} + +Tensor& div_scalar_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Scalar& b, + Tensor& out) { + // Common Dtype + ScalarType common_type = + torch::executor::native::utils::promote_type_with_scalar( + a.scalar_type(), b); + +#ifdef OP_ARG_CHECK + // Check Common Dtype + ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, + executorch::runtime::resize_tensor(out, a.sizes()) == Error::Ok, + InvalidArgument, + out); +#endif + + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "div.Scalar_out"; + + if (compute_type == ScalarType::Int) { + const int* const inp1_data = a.const_data_ptr(); + int inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + + float* const out_data = out.mutable_data_ptr(); + + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_scalar_32x32_f32, + out_data, + inp1_data, + inp2_val, + out.numel()); + } else if (compute_type == ScalarType::Float) { + const float* const inp1_data = a.const_data_ptr(); + float inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + + float* const out_data = out.mutable_data_ptr(); + + int mode = 0; + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_scalar_f32xf32_f32, + out_data, + inp1_data, + inp2_val, + mode, + out.numel()); + } else { + ScalarType common_type = + executorch::runtime::isFloatingType(a.scalar_type()) + ? 
a.scalar_type() + : ScalarType::Float; + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = + torch::executor::native::utils::scalar_to(b); + torch::executor::native::utils:: + apply_unitensor_elementwise_fn( + [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes:: + SAME_AS_COMMON); + }); + } + + return out; +} + +Tensor& div_scalar_mode_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Scalar& b, + optional mode, + Tensor& out) { + if (!mode.has_value()) { + return div_scalar_out(ctx, a, b, out); + } + + auto mode_val = mode.value(); + + // Check mode + ET_KERNEL_CHECK( + ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out); + + // Common Dtype + ScalarType common_type = + torch::executor::native::utils::promote_type_with_scalar( + a.scalar_type(), b); + +#ifdef OP_ARG_CHECK + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (canCast(common_type, out.scalar_type()) && + common_type != ScalarType::Bool), + InvalidArgument, + out); + + // Check for intergral division by zero + ET_KERNEL_CHECK_MSG( + ctx, + !(executorch::runtime::isIntegralType(common_type, true) && + torch::executor::native::utils::scalar_to(b) == 0), + InvalidArgument, + out, + "Div mode operation encountered integer division by zero"); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, + executorch::runtime::resize_tensor(out, a.sizes()) == Error::Ok, + InvalidArgument, + out); +#endif + + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + const bool mode_is_trunc = mode_val == "trunc"; + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "div.Scalar_mode_out"; + + int mode_value = (mode_val == "trunc") ? 
1 : 2; + + if (compute_type == ScalarType::Int) { + const int* const inp1_data = a.const_data_ptr(); + int inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + + int* const out_data = out.mutable_data_ptr(); + + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_scalar_32x32_32, + out_data, + inp1_data, + inp2_val, + mode_value, + out.numel()); + } else if (compute_type == ScalarType::Float) { + const float* const inp1_data = a.const_data_ptr(); + float inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + + float* const out_data = out.mutable_data_ptr(); + + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_div_scalar_f32xf32_f32, + out_data, + inp1_data, + inp2_val, + mode_value, + out.numel()); + } else { + ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = + torch::executor::native::utils::scalar_to(b); + torch::executor::native::utils:: + apply_unitensor_elementwise_fn( + [val_b, mode_is_trunc](const CTYPE_COMPUTE val_a) { + CTYPE_COMPUTE value = val_a / val_b; + if (mode_is_trunc) { + value = std::trunc(value); + } else { + value = torch::executor::native::utils::floor_divide( + val_a, val_b); + } + return value; + }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes::REALHBF16); + }); + } + + return out; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_exp.cpp b/backends/cadence/fusion_g3/operators/op_exp.cpp new file mode 100644 index 0000000000..3021a0d4e8 --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_exp.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#include + +#include +#include +#include + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { +#ifdef OP_ARG_CHECK + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_floating_type(out), + InvalidArgument, + out); + + // Resize for dynamic shape + ET_KERNEL_CHECK_MSG( + ctx, + executorch::runtime::resize_tensor(out, in.sizes()) == Error::Ok, + InvalidArgument, + out, + "Failed to resize output tensor."); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(in, out), + InvalidArgument, + out); +#endif + + if (out.scalar_type() == ScalarType::Float) { + float* const out_data = out.mutable_data_ptr(); + const float* const in_data = in.const_data_ptr(); + + XT_KERNEL_CHECK( + ctx, out, xa_nn_elm_exp_f32_f32, out_data, in_data, out.numel()); + + return out; + } else { + return torch::executor::native::internal:: + unary_ufunc_realhbbf16_to_floathbf16(std::exp, ctx, in, out); + } +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_mean.cpp b/backends/cadence/fusion_g3/operators/op_mean.cpp new file mode 100644 index 0000000000..289baceb12 --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_mean.cpp @@ -0,0 +1,202 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#include +#include +#include +#include +#include + +using ::executorch::aten::ArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +int prepare_data( + const Tensor& in, + Tensor& out, + optional> dim_list, + int* inp_shape, + int* out_shape, + int* p_axis, + int num_inp_dims, + int num_out_dims) { + for (int i = 0; i < num_inp_dims; i++) { + inp_shape[i] = in.size(i); + } + + for (int i = 0; i < num_out_dims; i++) { + out_shape[i] = out.size(i); + } + + int num_axis_dims = 0; + for (const auto& d : dim_list.value()) { + if (d < 0) { + p_axis[num_axis_dims] = num_inp_dims + d; + num_axis_dims++; + } else { + p_axis[num_axis_dims] = d; + num_axis_dims++; + } + } + + return num_axis_dims; +} + +Tensor& mean_out( + KernelRuntimeContext& ctx, + const Tensor& in, + optional> dim_list, + bool keepdim, + optional dtype, + Tensor& out) { + (void)ctx; + +#ifdef OP_ARG_CHECK + ET_KERNEL_CHECK( + ctx, + torch::executor::check_mean_dim_args(in, dim_list, keepdim, dtype, out), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(in, out), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_default_dim_order(in), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, + torch::executor::resize_reduction_out(in, dim_list, keepdim, out) == + Error::Ok, + InvalidArgument, + out); +#endif + + constexpr int kNnlibMaxDim = 5; + + bool optimized = 1; + + if (out.scalar_type() != ScalarType::Float) + optimized = 0; + + if (in.dim() > kNnlibMaxDim) + optimized = 0; + + if (optimized) { + float* __restrict__ p_out = out.mutable_data_ptr(); + const float* __restrict__ p_inp = + (const float* __restrict__)in.const_data_ptr(); + + int num_elm = in.numel(); + + int num_inp_dims = in.dim(); + int num_out_dims = out.dim(); + + int inp_shape[kNnlibMaxDim]; + int out_shape[kNnlibMaxDim]; + int p_axis[kNnlibMaxDim]; + + for (int i = 0; i < kNnlibMaxDim; i++) { + out_shape[i] = 1; + inp_shape[i] = 1; + p_axis[i] = 1; + } + + int num_axis_dims = prepare_data( + in, + out, + dim_list, + inp_shape, + out_shape, + p_axis, + num_inp_dims, + num_out_dims); + + if (num_axis_dims == num_inp_dims) { + num_out_dims = 1; + out_shape[0] = 1; + } + + int inp_shape_max = inp_shape[p_axis[0]]; + for (int i = 1; i < num_axis_dims; i++) { + if (inp_shape[p_axis[i]] > inp_shape_max) { + inp_shape_max = inp_shape[p_axis[i]]; + } + } + + int scratch_size = in.numel() / inp_shape_max; + + executorch::runtime::Result temp_mem = + ctx.allocate_temp(scratch_size * sizeof(float)); + + void* __restrict__ p_scratch_in = (void* __restrict__)(temp_mem.get()); + + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_mean_f32_f32, + p_out, + out_shape, + num_out_dims, + p_inp, + inp_shape, + num_inp_dims, + p_axis, + num_axis_dims, + p_scratch_in); + } else { + ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] { + ET_SWITCH_FLOATH_TYPES( + out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] { + CTYPE_OUT* out_data = out.mutable_data_ptr(); + const size_t num = + torch::executor::get_reduced_dim_product(in, dim_list); + for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) { + CTYPE_OUT sum = 0; + if (in.numel() > 0) { + sum = torch::executor:: + map_reduce_over_dim_list( + [](CTYPE_IN v) { 
return static_cast(v); }, + [](CTYPE_OUT outv, CTYPE_OUT acc) { + return acc + outv; + }, + in, + dim_list, + out_ix); + } + out_data[out_ix] = sum / static_cast(num); + } + }); + }); + } + + return out; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_mul.cpp b/backends/cadence/fusion_g3/operators/op_mul.cpp index 840cb16c7c..93b4c5a992 100644 --- a/backends/cadence/fusion_g3/operators/op_mul.cpp +++ b/backends/cadence/fusion_g3/operators/op_mul.cpp @@ -6,8 +6,11 @@ * LICENSE file in the root directory of this source tree. */ +#include + #include +#include #include #include #include @@ -34,6 +37,7 @@ Tensor& mul_out( ScalarType common_type = executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); +#ifdef OP_ARG_CHECK // Check Common Dtype ET_KERNEL_CHECK( ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); @@ -51,6 +55,7 @@ Tensor& mul_out( torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok, InvalidArgument, out); +#endif // Compute Dtype ScalarType compute_type = @@ -58,7 +63,6 @@ Tensor& mul_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "mul.out"; - int kTensorDimensionLimit = 5; int inp1_shape[kTensorDimensionLimit]; @@ -111,13 +115,28 @@ Tensor& mul_out( int* const out_data = out.mutable_data_ptr(); if (a.numel() == 1) { - xa_nn_elm_mul_scalar_32x32_32( - out_data, inp2_data, inp1_data[0], out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_mul_scalar_32x32_32, + out_data, + inp2_data, + inp1_data[0], + out.numel()); } else if (b.numel() == 1) { - xa_nn_elm_mul_scalar_32x32_32( - out_data, inp1_data, inp2_data[0], out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_mul_scalar_32x32_32, + out_data, + inp1_data, + inp2_data[0], + out.numel()); } else if (broadcast) { - xa_nn_elm_mul_broadcast_5D_32x32_32( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_mul_broadcast_5D_32x32_32, out_data, out_shape, inp1_data, @@ -126,7 +145,14 @@ Tensor& mul_out( inp2_shape, max_dim); } else { - xa_nn_elm_mul_32x32_32(out_data, inp1_data, inp2_data, out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_mul_32x32_32, + out_data, + inp1_data, + inp2_data, + out.numel()); } } else if ((compute_type == ScalarType::Float) && (optimized)) { const float* const inp1_data = a.const_data_ptr(); @@ -134,13 +160,28 @@ Tensor& mul_out( float* const out_data = out.mutable_data_ptr(); if (a.numel() == 1) { - xa_nn_elm_mul_scalar_f32xf32_f32( - out_data, inp2_data, inp1_data[0], out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_mul_scalar_f32xf32_f32, + out_data, + inp2_data, + inp1_data[0], + out.numel()); } else if (b.numel() == 1) { - xa_nn_elm_mul_scalar_f32xf32_f32( - out_data, inp1_data, inp2_data[0], out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_mul_scalar_f32xf32_f32, + out_data, + inp1_data, + inp2_data[0], + out.numel()); } else if (broadcast) { - xa_nn_elm_mul_broadcast_5D_f32xf32_f32( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_mul_broadcast_5D_f32xf32_f32, out_data, out_shape, inp1_data, @@ -149,7 +190,14 @@ Tensor& mul_out( inp2_shape, max_dim); } else { - xa_nn_elm_mul_f32xf32_f32(out_data, inp1_data, inp2_data, out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_mul_f32xf32_f32, + out_data, + inp1_data, + inp2_data, + out.numel()); } } else { ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { @@ -181,6 +229,7 @@ Tensor& mul_scalar_out( 
torch::executor::native::utils::promote_type_with_scalar( a.scalar_type(), b); +#ifdef OP_ARG_CHECK // Check Common Dtype ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out); @@ -194,29 +243,41 @@ Tensor& mul_scalar_out( // Resize ET_KERNEL_CHECK( ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); - +#endif // Compute Dtype ScalarType compute_type = torch::executor::native::utils::get_compute_type(common_type); // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "mul.Scalar_out"; - if (compute_type == ScalarType::Int) { const int* const inp1_data = a.const_data_ptr(); int inp2_val; torch::executor::native::utils::extract_scalar(b, &inp2_val); int* const out_data = out.mutable_data_ptr(); - xa_nn_elm_mul_scalar_32x32_32(out_data, inp1_data, inp2_val, out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_mul_scalar_32x32_32, + out_data, + inp1_data, + inp2_val, + out.numel()); } else if (compute_type == ScalarType::Float) { const float* const inp1_data = a.const_data_ptr(); float inp2_val; torch::executor::native::utils::extract_scalar(b, &inp2_val); float* const out_data = out.mutable_data_ptr(); - xa_nn_elm_mul_scalar_f32xf32_f32( - out_data, inp1_data, inp2_val, out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_mul_scalar_f32xf32_f32, + out_data, + inp1_data, + inp2_val, + out.numel()); } else { ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = @@ -232,7 +293,6 @@ Tensor& mul_scalar_out( SAME_AS_COMMON); }); } - return out; } diff --git a/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp index a5fbe31eee..9857bbce37 100644 --- a/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp +++ b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp @@ -6,16 +6,20 @@ * LICENSE file in the root directory of this source tree. 
*/ +#include + #include #include #include +#include #include #include #include using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; using ::executorch::aten::ScalarType; using ::executorch::aten::Tensor; using ::executorch::runtime::Error; @@ -32,8 +36,8 @@ template void layer_norm( const Tensor& input, IntArrayRef normalized_shape, - const exec_aten::optional& weight, - const exec_aten::optional& bias, + const optional& weight, + const optional& bias, CTYPE eps, Tensor& out, Tensor& mean, @@ -109,8 +113,8 @@ std::tuple native_layer_norm_out( KernelRuntimeContext& ctx, const Tensor& input, IntArrayRef normalized_shape, - const exec_aten::optional& weight, - const exec_aten::optional& bias, + const optional& weight, + const optional& bias, double eps, Tensor& out, Tensor& mean_out, @@ -118,7 +122,9 @@ std::tuple native_layer_norm_out( (void)ctx; std::tuple ret_val(out, mean_out, rstd_out); + int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit; +#ifdef OP_ARG_CHECK ET_KERNEL_CHECK( ctx, torch::executor::check_layer_norm_args( @@ -156,7 +162,7 @@ std::tuple native_layer_norm_out( InvalidArgument, ret_val); } - int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit; + Tensor::SizesType mean_rstd_sizes[kTensorDimensionLimit]; size_t mean_rstd_ndim = 0; torch::executor::get_layer_norm_out_target_size( @@ -181,6 +187,7 @@ std::tuple native_layer_norm_out( rstd_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok, InvalidArgument, ret_val); +#endif int input_shape[kTensorDimensionLimit]; for (int i = 0; i < input.dim(); i++) { @@ -218,7 +225,10 @@ std::tuple native_layer_norm_out( } } - xa_nn_native_layer_norm_f32_f32( + XT_KERNEL_CHECK( + ctx, + ret_val, + xa_nn_native_layer_norm_f32_f32, out_data, mean_data, rstd_data, diff --git a/backends/cadence/fusion_g3/operators/op_permute_copy.cpp b/backends/cadence/fusion_g3/operators/op_permute_copy.cpp new file mode 100644 index 0000000000..23c2d1e5fb --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_permute_copy.cpp @@ -0,0 +1,158 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include +#include + +using ::executorch::aten::ArrayRef; +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::SizesType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; + +/* ScalarType in Executorch do not have support for below data types. + * So, creating a placeholder for these data types. Once, ScalarTypes is + * updated to have support for below data types, these can be removed and + * operator need to be updated accordingly + */ + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +namespace { + +void increment_coordinate_permuted( + const Tensor& tensor, + size_t* const coordinate, + IntArrayRef dims) { + for (int i = dims.size() - 1; i >= 0; i--) { + size_t d = dims[i] >= 0 ? 
dims[i] : dims[i] + tensor.dim(); + coordinate[d]++; + if (coordinate[d] == tensor.size(d)) { + coordinate[d] = 0; + } else { + return; + } + } +} + +} // namespace + +Tensor& permute_copy_out( + KernelRuntimeContext& ctx, + const Tensor& in, + IntArrayRef dims, + Tensor& out) { + (void)ctx; + int kTensorDimensionLimit = 5; + /* if the arguments are passed properly to the operator disable the Macro - + * "OP_ARG_CHECK" if not the case, enable the Macro - "OP_ARG_CHECK", to have + * the checks only in operator level(As there are no checks in kernel). + */ +#ifdef OP_ARG_CHECK + ET_KERNEL_CHECK( + ctx, + torch::executor::check_permute_copy_args(in, dims, out), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(in, out), + InvalidArgument, + out); + + Tensor::SizesType expected_out_size[kTensorDimensionLimit]; + size_t expected_out_dim = 0; + torch::executor::get_permute_copy_out_target_size( + in, dims, expected_out_size, &expected_out_dim); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::resize_tensor( + out, {expected_out_size, expected_out_dim}) == Error::Ok, + InvalidArgument, + out); +#endif + + const ArrayRef in_size = in.sizes(); + const ArrayRef out_size = out.sizes(); + + int inp_shape[kTensorDimensionLimit]; + int out_shape[kTensorDimensionLimit]; + + /* input shapes and output shapes */ + for (auto i = 0; i < in_size.size(); i++) { + inp_shape[i] = in_size[i]; + } + + for (auto i = 0; i < out_size.size(); i++) { + out_shape[i] = out_size[i]; + } + + int permute_vec[in.dim()]; + for (int i = 0; i < in.dim(); i++) { + permute_vec[i] = (int)dims[i]; + } + signed char* out_data = out.mutable_data_ptr(); + const signed char* const inp_data = in.const_data_ptr(); + + if (((out.scalar_type() == ScalarType::Int) || + (out.scalar_type() == ScalarType::Short) || + (out.scalar_type() == ScalarType::Char) || + (out.scalar_type() == ScalarType::UInt32) || + (out.scalar_type() == ScalarType::UInt16) || + (out.scalar_type() == ScalarType::Byte)) && + (in.dim() <= 5)) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_permute, + out_data, + out_shape, + inp_data, + inp_shape, + permute_vec, + in.dim(), + get_element_size(out.scalar_type())); + } else { + const auto in_type = out.scalar_type(); + size_t in_coord[5] = {0}; + size_t trailing_dims_memo[kTensorDimensionLimit]; + executorch::runtime::memoizeTrailingDims(in, trailing_dims_memo); + // in and out must be the same dtype + ET_SWITCH_ALL_TYPES(in_type, ctx, "permute_copy.out", CTYPE, [&] { + const CTYPE* const in_data = in.const_data_ptr(); + CTYPE* const out_data = out.mutable_data_ptr(); + + for (size_t i = 0; i < out.numel(); ++i) { + out_data[i] = + in_data[executorch::runtime::coordinateToIndexWithTrailingDimsMemo( + in, in_coord, trailing_dims_memo)]; + increment_coordinate_permuted(in, in_coord, dims); + } + }); + } + + return out; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_quantize.cpp b/backends/cadence/fusion_g3/operators/op_quantize.cpp index fc206b67cd..8237c3c266 100644 --- a/backends/cadence/fusion_g3/operators/op_quantize.cpp +++ b/backends/cadence/fusion_g3/operators/op_quantize.cpp @@ -6,17 +6,18 @@ * LICENSE file in the root directory of this source tree. 
*/ +#include + #include #include #include #include +#include #include #include -using ::executorch::aten::ArrayRef; -using ::executorch::aten::optional; using ::executorch::aten::ScalarType; using ::executorch::aten::Tensor; using ::executorch::runtime::Error; @@ -27,7 +28,7 @@ using ::executorch::runtime::KernelRuntimeContext; * updated to have support for below data types, these can be removed and * operator need to be updated accordingly */ -enum datatype { Ushort = 20, Bits4u = 21, Bits4 = 22 }; +enum datatype { Bits4u = 21, Bits4 = 22 }; /** * For an input tensor, use the scale and zero_point arguments to quantize it. @@ -78,9 +79,6 @@ void check_quantize_per_tensor_args( } else if (dtype == ScalarType::Short) { quant_min_lower_bound = std::numeric_limits::min(); quant_max_upper_bound = std::numeric_limits::max(); - } else if (dtype == (ScalarType)Ushort) { - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); } else if (dtype == (ScalarType)Bits4u) { quant_min_lower_bound = std::numeric_limits::min(); quant_max_upper_bound = std::numeric_limits::max(); @@ -137,7 +135,8 @@ T quantize_val( } /* Local function which calls the kernels based on the output datatype */ -void quantize_impl( +Tensor& quantize_impl( + KernelRuntimeContext& ctx, Tensor& out, const Tensor& input, float* scale_data, @@ -145,7 +144,8 @@ void quantize_impl( int* axis, int quant_min, int quant_max) { - const ArrayRef input_size = input.sizes(); + const ::executorch::aten::ArrayRef input_size = + input.sizes(); int kTensorDimensionLimit = 5; @@ -179,7 +179,10 @@ void quantize_impl( if (is_asym_quant) { if (out.scalar_type() == ScalarType::Byte) { uint8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym8u( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_quantize_f32_asym8u, out_data, input_data, inp_shape, @@ -191,7 +194,11 @@ void quantize_impl( quant_max); } else if (out.scalar_type() == ScalarType::Char) { int8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym8( + + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_quantize_f32_asym8, out_data, input_data, inp_shape, @@ -201,9 +208,12 @@ void quantize_impl( zero_point_data, quant_min, quant_max); - } else if (out.scalar_type() == (ScalarType)Ushort) { + } else if (out.scalar_type() == ScalarType::UInt16) { uint16_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym16u( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_quantize_f32_asym16u, out_data, input_data, inp_shape, @@ -215,7 +225,10 @@ void quantize_impl( quant_max); } else if (out.scalar_type() == ScalarType::Short) { int16_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym16( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_quantize_f32_asym16, out_data, input_data, inp_shape, @@ -227,7 +240,10 @@ void quantize_impl( quant_max); } else if (out.scalar_type() == (ScalarType)Bits4u) { uint8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym4u( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_quantize_f32_asym4u, out_data, input_data, inp_shape, @@ -239,7 +255,10 @@ void quantize_impl( quant_max); } else if (out.scalar_type() == (ScalarType)Bits4) { int8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym4( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_quantize_f32_asym4, out_data, input_data, inp_shape, @@ -304,8 +323,9 @@ void quantize_impl( } } - optional> optional_dim_list{ - ArrayRef{dims, size_t(input.dim() - 1)}}; + ::executorch::aten::optional<::executorch::aten::ArrayRef> + 
optional_dim_list{::executorch::aten::ArrayRef{ + dims, size_t(input.dim() - 1)}}; // Actual quantization logic // input, out are the input and output tensors @@ -373,7 +393,10 @@ void quantize_impl( } else { if (out.scalar_type() == ScalarType::Byte) { uint8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym8u( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_quantize_f32_sym8u, out_data, input_data, inp_shape, @@ -384,7 +407,10 @@ void quantize_impl( quant_max); } else if (out.scalar_type() == ScalarType::Char) { int8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym8( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_quantize_f32_sym8, out_data, input_data, inp_shape, @@ -393,9 +419,12 @@ void quantize_impl( scale_data, quant_min, quant_max); - } else if (out.scalar_type() == (ScalarType)Ushort) { + } else if (out.scalar_type() == ScalarType::UInt16) { uint16_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym16u( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_quantize_f32_sym16u, out_data, input_data, inp_shape, @@ -406,7 +435,10 @@ void quantize_impl( quant_max); } else if (out.scalar_type() == ScalarType::Short) { int16_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym16( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_quantize_f32_sym16, out_data, input_data, inp_shape, @@ -417,7 +449,10 @@ void quantize_impl( quant_max); } else if (out.scalar_type() == (ScalarType)Bits4u) { uint8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym4u( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_quantize_f32_sym4u, out_data, input_data, inp_shape, @@ -428,7 +463,10 @@ void quantize_impl( quant_max); } else if (out.scalar_type() == (ScalarType)Bits4) { int8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym4( + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_quantize_f32_sym4, out_data, input_data, inp_shape, @@ -490,8 +528,9 @@ void quantize_impl( } } - optional> optional_dim_list{ - ArrayRef{dims, size_t(input.dim() - 1)}}; + ::executorch::aten::optional<::executorch::aten::ArrayRef> + optional_dim_list{::executorch::aten::ArrayRef{ + dims, size_t(input.dim() - 1)}}; // Actual quantization logic // input, out are the input and output tensors @@ -556,6 +595,7 @@ void quantize_impl( #undef SYM_QUANTIZE_IMPL_CHANNEL } } + return out; } // Quantize the input tensor @@ -568,16 +608,18 @@ Tensor& quantize_per_tensor_out( int64_t quant_max, ScalarType dtype, Tensor& out) { +#ifdef OP_ARG_CHECK Error err = resize_tensor(out, input.sizes()); ET_CHECK_MSG( err == Error::Ok, "Failed to resize out Tensor in quantize_per_tensor_out"); - - // check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); + check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); +#endif float scale_data = (float)scale; int zero_point_data = (int)zero_point; quantize_impl( + context, out, input, &scale_data, @@ -606,6 +648,7 @@ Tensor& quantize_per_tensor_tensor_args_out( context.fail(Error::InvalidArgument); return out; } +#ifdef OP_ARG_CHECK ET_CHECK_MSG( scale.scalar_type() == ScalarType::Double, "Expected scale to be Double tensor received: %" PRId8, @@ -622,6 +665,7 @@ Tensor& quantize_per_tensor_tensor_args_out( zero_point.numel() == 1, "Exepcted zero_point to only have one element received: %zd", ssize_t(zero_point.numel())); +#endif quantize_per_tensor_out( context, @@ -652,6 +696,7 @@ Tensor& quantize_per_tensor_tensor_args_out( } Tensor& quantize_per_channel_out( + KernelRuntimeContext& context, const Tensor& input, 
const Tensor& scale, const Tensor& zero_point, @@ -660,8 +705,12 @@ Tensor& quantize_per_channel_out( int64_t quant_max, ScalarType dtype, Tensor& out) { - Error err = resize_tensor(out, input.sizes()); + if (axis < 0) { + axis += executorch::runtime::nonzero_dim(input); + } +#ifdef OP_ARG_CHECK + Error err = resize_tensor(out, input.sizes()); // normalize axis ET_CHECK_MSG( executorch::runtime::tensor_has_dim(input, axis), @@ -669,10 +718,6 @@ Tensor& quantize_per_channel_out( ssize_t(axis), ssize_t(input.dim())); - if (axis < 0) { - axis += executorch::runtime::nonzero_dim(input); - } - ET_CHECK_MSG( err == Error::Ok, "Failed to resize out Tensor in quantize_per_channel_out"); @@ -699,7 +744,8 @@ Tensor& quantize_per_channel_out( zero_point.numel(), input.size(axis)); - // check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); + check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); +#endif const double* scale_dt = scale.const_data_ptr(); const int64_t* zero_point_dt = zero_point.const_data_ptr(); @@ -715,6 +761,7 @@ Tensor& quantize_per_channel_out( int* axis_ptr = (int*)&axis; quantize_impl( + context, out, input, scale_data, @@ -722,25 +769,12 @@ Tensor& quantize_per_channel_out( axis_ptr, (int)quant_min, (int)quant_max); - return out; -} -Tensor& quantize_per_channel_out( - KernelRuntimeContext& context, - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) { - (void)context; - return quantize_per_channel_out( - input, scale, zero_point, axis, quant_min, quant_max, dtype, out); + return out; } Tensor& quantize_per_token_out( + KernelRuntimeContext& context, const Tensor& input, const Tensor& scale, const Tensor& zero_point, @@ -761,11 +795,11 @@ Tensor& quantize_per_token_out( Tensor reshaped_input = at::from_blob( input.mutable_data_ptr(), sizes, at::TensorOptions(input.scalar_type())); #else - std::array input_dim_order{0, 1}; - std::array input_sizes; + std::array<::executorch::aten::DimOrderType, 2> input_dim_order{0, 1}; + std::array<::executorch::aten::SizesType, 2> input_sizes; input_sizes[0] = num_tokens; input_sizes[1] = input.size(input.dim() - 1); - std::array input_strides; + std::array<::executorch::aten::StridesType, 2> input_strides; executorch::runtime::dim_order_to_stride_nocheck( input_sizes.data(), input_dim_order.data(), 2, input_strides.data()); void* input_data = input.mutable_data_ptr(); @@ -786,21 +820,15 @@ Tensor& quantize_per_token_out( #endif return quantize_per_channel_out( - reshaped_input, scale, zero_point, 0, quant_min, quant_max, dtype, out); -} - -Tensor& quantize_per_token_out( - KernelRuntimeContext& context, - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) { - (void)context; - return quantize_per_token_out( - input, scale, zero_point, quant_min, quant_max, dtype, out); + context, + reshaped_input, + scale, + zero_point, + 0, + quant_min, + quant_max, + dtype, + out); } } // namespace native diff --git a/backends/cadence/fusion_g3/operators/op_slice_copy.cpp b/backends/cadence/fusion_g3/operators/op_slice_copy.cpp new file mode 100644 index 0000000000..c481cf726b --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_slice_copy.cpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include + +#include +#include +#include + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; + +/* ScalarType in Executorch do not have support for below data types. + * So, creating a placeholder for these data types. Once, ScalarTypes is + * updated to have support for below data types, these can be removed and + * operator need to be updated accordingly + */ + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +Tensor& slice_copy_Tensor_out( + KernelRuntimeContext& ctx, + const Tensor& in, + int64_t dim, + ::executorch::aten::optional start_val, + ::executorch::aten::optional end_val, + int64_t step, + Tensor& out) { + (void)ctx; + + if (dim < 0) { + dim += in.dim(); + } + // If user do not set value to end_val, set end to in.size(dim) (largest + // value available) + int64_t end = end_val.has_value() ? end_val.value() : in.size(dim); + // If user do not set value to start_val, set start to 0 (smallest value + // available) + int64_t start = start_val.has_value() ? start_val.value() : 0; + int64_t length = + torch::executor::adjust_slice_indices(in.size(dim), &start, &end, step); + + int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit; + +#ifdef OP_ARG_CHECK + ET_KERNEL_CHECK( + ctx, + torch::executor::check_slice_copy_args(in, dim, step, out), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(in, out), + InvalidArgument, + out); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + Tensor::SizesType target_sizes[kTensorDimensionLimit]; + size_t target_ndim = 0; + torch::executor::get_slice_copy_out_target_size( + in, dim, length, target_sizes, &target_ndim); + ET_KERNEL_CHECK( + ctx, + executorch::runtime::resize_tensor(out, {target_sizes, target_ndim}) == + Error::Ok, + InvalidArgument, + out); +#endif + + const ::executorch::aten::ArrayRef in_size = in.sizes(); + const ::executorch::aten::ArrayRef out_size = out.sizes(); + + int inp_shape[kTensorDimensionLimit]; + int out_shape[kTensorDimensionLimit]; + + /* input shapes and output shapes */ + for (auto i = 0; i < in_size.size(); i++) { + inp_shape[i] = in_size[i]; + } + + for (auto i = 0; i < out_size.size(); i++) { + out_shape[i] = out_size[i]; + } + + signed char* out_data = out.mutable_data_ptr(); + const signed char* const inp_data = in.const_data_ptr(); + + if ((out.scalar_type() == ScalarType::Int) || + (out.scalar_type() == ScalarType::Short) || + (out.scalar_type() == ScalarType::Char) || + (out.scalar_type() == ScalarType::UInt32) || + (out.scalar_type() == ScalarType::UInt16) || + (out.scalar_type() == ScalarType::Byte)) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_slice, + out_data, + out_shape, + inp_data, + inp_shape, + in.dim(), + (int)start, + (int)(end - 1), + (int)step, + (int)dim, + get_element_size(out.scalar_type())); + } else { + torch::executor::compute_slice(in, dim, start, length, step, out); + } + + return out; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_softmax.cpp b/backends/cadence/fusion_g3/operators/op_softmax.cpp index 9f34348150..ee87ebaf5a 100644 --- 
a/backends/cadence/fusion_g3/operators/op_softmax.cpp +++ b/backends/cadence/fusion_g3/operators/op_softmax.cpp @@ -6,10 +6,13 @@ * LICENSE file in the root directory of this source tree. */ +#include + #include #include +#include #include #include #include @@ -34,6 +37,10 @@ Tensor& _softmax_out( Tensor& out) { (void)ctx; + // Adjust for negative dim + dim = dim < 0 ? dim + executorch::runtime::nonzero_dim(in) : dim; + +#ifdef OP_ARG_CHECK ET_KERNEL_CHECK( ctx, torch::executor::check_softmax_args(in, dim, half_to_float, out), @@ -48,9 +55,7 @@ Tensor& _softmax_out( executorch::runtime::tensors_have_same_dim_order(in, out), InvalidArgument, out); - - // Adjust for negative dim - dim = dim < 0 ? dim + executorch::runtime::nonzero_dim(in) : dim; +#endif int inp_shapes[in.dim()]; const ArrayRef in_size = in.sizes(); @@ -62,7 +67,15 @@ Tensor& _softmax_out( const float* const inp_data = in.const_data_ptr(); float* const out_data = out.mutable_data_ptr(); int axis = dim; - xa_nn_softmax_f32_f32(out_data, inp_data, inp_shapes, in.dim(), &axis); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_softmax_f32_f32, + out_data, + inp_data, + inp_shapes, + in.dim(), + &axis); } else { ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() { const CTYPE* const in_data = in.const_data_ptr(); diff --git a/backends/cadence/fusion_g3/operators/op_sub.cpp b/backends/cadence/fusion_g3/operators/op_sub.cpp new file mode 100644 index 0000000000..91782d2dff --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_sub.cpp @@ -0,0 +1,339 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include +#include +#include + +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::canCast; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +Tensor& sub_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + const Scalar& alpha, + Tensor& out) { + // Common Dtype + ScalarType common_type = + executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); +#ifdef OP_ARG_CHECK + ScalarType alpha_type = + torch::executor::native::utils::get_scalar_dtype(alpha); + + // Check alpha type + ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (canCast(common_type, out.scalar_type()) && + canCast(alpha_type, common_type)), + InvalidArgument, + out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, b, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, + torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); +#endif + + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "sub.out"; + + int kTensorDimensionLimit = 5; + + int inp1_shape[kTensorDimensionLimit]; + int inp2_shape[kTensorDimensionLimit]; + int out_shape[kTensorDimensionLimit]; + + bool broadcast = 0; + + int max_dim = a.dim() > b.dim() ? 
a.dim() : b.dim(); + max_dim = out.dim() > max_dim ? out.dim() : max_dim; + + bool optimized = 1; + + for (int i = 0; i < max_dim; i++) { + out_shape[i] = 1; + inp1_shape[i] = 1; + inp2_shape[i] = 1; + } + + int offset_out = max_dim - out.dim(); + int offset_inp1 = max_dim - a.dim(); + int offset_inp2 = max_dim - b.dim(); + + for (int i = 0; i < out.dim(); i++) { + out_shape[i + offset_out] = out.size(i); + } + for (int i = 0; i < a.dim(); i++) { + inp1_shape[i + offset_inp1] = a.size(i); + } + for (int i = 0; i < b.dim(); i++) { + inp2_shape[i + offset_inp2] = b.size(i); + } + + /*find broadcast*/ + for (int i = 0; i < out.dim(); i++) { + if (((inp1_shape[i]) != (out_shape[i])) || + ((inp2_shape[i]) != (out_shape[i]))) { + broadcast = 1; + } + } + + if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) { + optimized = 0; + } + + if ((compute_type == ScalarType::Int) && (optimized)) { + const int* const inp1_data = a.const_data_ptr(); + const int* const inp2_data = b.const_data_ptr(); + int* const out_data = out.mutable_data_ptr(); + + int alpha_val; + torch::executor::native::utils::extract_scalar(alpha, &alpha_val); + + if (b.numel() == 1) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_sub_scalar_32x32_32, + out_data, + inp1_data, + inp2_data[0], + alpha_val, + out.numel()); + } else if (broadcast) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_sub_broadcast_5D_32x32_32, + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + max_dim, + alpha_val); + } else { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_sub_32x32_32, + out_data, + inp1_data, + inp2_data, + alpha_val, + out.numel()); + } + } else if ((compute_type == ScalarType::Float) && (optimized)) { + const float* const inp1_data = a.const_data_ptr(); + const float* const inp2_data = b.const_data_ptr(); + float* const out_data = out.mutable_data_ptr(); + + float alpha_val; + torch::executor::native::utils::extract_scalar(alpha, &alpha_val); + + if (b.numel() == 1) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_sub_scalar_f32xf32_f32, + out_data, + inp1_data, + inp2_data[0], + alpha_val, + out.numel()); + } else if (broadcast) { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_sub_broadcast_5D_f32xf32_f32, + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + max_dim, + alpha_val); + } else { + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_sub_f32xf32_f32, + out_data, + inp1_data, + inp2_data, + alpha_val, + out.numel()); + } + } else { + ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_alpha = + torch::executor::native::utils::scalar_to(alpha); + torch::executor::native::utils:: + apply_bitensor_elementwise_fn( + [val_alpha]( + const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a - val_alpha * val_b; + }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBF16, + b, + torch::executor::native::utils::SupportedTensorDtypes::REALHBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes::REALHBF16); + }); + } + + return out; +} + +Tensor& sub_scalar_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Scalar& b, + const Scalar& alpha, + Tensor& out) { + // Common Dtype + ScalarType common_type = + torch::executor::native::utils::promote_type_with_scalar( + a.scalar_type(), b); +#ifdef OP_ARG_CHECK + ScalarType alpha_type = + torch::executor::native::utils::get_scalar_dtype(alpha); + + // Check alpha type + ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out); 
+ + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (common_type == out.scalar_type() && canCast(alpha_type, common_type)), + InvalidArgument, + out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, + executorch::runtime::resize_tensor(out, a.sizes()) == Error::Ok, + InvalidArgument, + out); +#endif + + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "sub.Scalar_out"; + + if (compute_type == ScalarType::Int) { + const int* const inp1_data = a.const_data_ptr(); + int inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + + int alpha_val; + torch::executor::native::utils::extract_scalar(alpha, &alpha_val); + + int* const out_data = out.mutable_data_ptr(); + + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_sub_scalar_32x32_32, + out_data, + inp1_data, + inp2_val, + alpha_val, + out.numel()); + } else if (compute_type == ScalarType::Float) { + const float* const inp1_data = a.const_data_ptr(); + float inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + + float alpha_val; + torch::executor::native::utils::extract_scalar(alpha, &alpha_val); + + float* const out_data = out.mutable_data_ptr(); + + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_sub_scalar_f32xf32_f32, + out_data, + inp1_data, + inp2_val, + alpha_val, + out.numel()); + } else { + ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = + torch::executor::native::utils::scalar_to(b); + const CTYPE_COMPUTE val_alpha = + torch::executor::native::utils::scalar_to(alpha); + torch::executor::native::utils:: + apply_unitensor_elementwise_fn( + [val_b, val_alpha](const CTYPE_COMPUTE val_a) { + return val_a - val_alpha * val_b; + }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes:: + SAME_AS_COMMON); + }); + } + + return out; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/operators.h b/backends/cadence/fusion_g3/operators/operators.h index 9d7f7b9c30..e1c0d08f44 100644 --- a/backends/cadence/fusion_g3/operators/operators.h +++ b/backends/cadence/fusion_g3/operators/operators.h @@ -16,6 +16,13 @@ namespace impl { namespace G3 { namespace native { +::executorch::aten::Tensor& _softmax_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& in, + int64_t dim, + bool half_to_float, + ::executorch::aten::Tensor& out); + ::executorch::aten::Tensor& add_out( ::executorch::runtime::KernelRuntimeContext& ctx, const ::executorch::aten::Tensor& a, @@ -30,6 +37,153 @@ ::executorch::aten::Tensor& add_scalar_out( const ::executorch::aten::Scalar& alpha, ::executorch::aten::Tensor& out); +::executorch::aten::Tensor& cat_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + ::executorch::aten::ArrayRef<::executorch::aten::Tensor> tensors, + int64_t dim, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& dequantize_per_channel_out( + ::executorch::runtime::KernelRuntimeContext& context, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& scale, + const 
::executorch::aten::optional<::executorch::aten::Tensor>& + opt_zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ::executorch::aten::ScalarType dtype, + ::executorch::aten::optional<::executorch::aten::ScalarType> out_dtype, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& dequantize_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& context, + const ::executorch::aten::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ::executorch::aten::ScalarType dtype, + ::executorch::aten::optional<::executorch::aten::ScalarType> out_dtype, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& div_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& a, + const ::executorch::aten::Tensor& b, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& div_out_mode( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& a, + const ::executorch::aten::Tensor& b, + ::executorch::aten::optional<::executorch::aten::string_view> mode, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& div_scalar_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& a, + const ::executorch::aten::Scalar& b, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& div_scalar_mode_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& a, + const ::executorch::aten::Scalar& b, + ::executorch::aten::optional<::executorch::aten::string_view> mode, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& exp_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& in, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& mean_dim_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& in, + ::executorch::aten::optional<::executorch::aten::ArrayRef> + dim_list, + bool keepdim, + ::executorch::aten::optional<::executorch::aten::ScalarType> dtype, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& mul_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& a, + const ::executorch::aten::Tensor& b, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& mul_scalar_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& a, + const ::executorch::aten::Scalar& b, + ::executorch::aten::Tensor& out); + +std::tuple< + ::executorch::aten::Tensor&, + ::executorch::aten::Tensor&, + ::executorch::aten::Tensor&> +native_layer_norm_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + ::executorch::aten::IntArrayRef normalized_shape, + const ::executorch::aten::optional<::executorch::aten::Tensor>& weight, + const ::executorch::aten::optional<::executorch::aten::Tensor>& bias, + double eps, + ::executorch::aten::Tensor& out, + ::executorch::aten::Tensor& mean_out, + ::executorch::aten::Tensor& rstd_out); + +::executorch::aten::Tensor& permute_copy_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& in, + ::executorch::aten::IntArrayRef dims, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantize_per_channel_out( + ::executorch::runtime::KernelRuntimeContext& context, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& scale, + 
const ::executorch::aten::Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ::executorch::aten::ScalarType dtype, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantize_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& context, + const ::executorch::aten::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ::executorch::aten::ScalarType dtype, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& slice_copy_Tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& in, + int64_t dim, + ::executorch::aten::optional start_val, + ::executorch::aten::optional end_val, + int64_t step, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& sub_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& a, + const ::executorch::aten::Tensor& b, + const ::executorch::aten::Scalar& alpha, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& sub_scalar_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& a, + const ::executorch::aten::Scalar& b, + const ::executorch::aten::Scalar& alpha, + ::executorch::aten::Tensor& out); + } // namespace native } // namespace G3 } // namespace impl diff --git a/backends/cadence/fusion_g3/operators/targets.bzl b/backends/cadence/fusion_g3/operators/targets.bzl index 3e5900e363..e1e7c9a849 100644 --- a/backends/cadence/fusion_g3/operators/targets.bzl +++ b/backends/cadence/fusion_g3/operators/targets.bzl @@ -28,6 +28,7 @@ def define_operator(name: str, deps: list[str] | None = None) -> None: exported_deps = [ ":operators_header", ":xt_macros", + ":xt_utils", ], ) @@ -39,6 +40,12 @@ OPERATORS = [ "native_layer_norm", "quantize", "softmax", + "sub", + "div", + "exp", + "mean", + "slice_copy", + "permute_copy" ] def define_common_targets(): @@ -74,5 +81,17 @@ def define_common_targets(): ], ) + runtime.cxx_library( + name = "xt_utils", + exported_headers = ["xt_utils.h"], + visibility = [ + "//executorch/backends/cadence/...", + ], + exported_deps = [ + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + ) + for op in OPERATORS: define_operator(op) diff --git a/backends/cadence/fusion_g3/operators/xt_utils.h b/backends/cadence/fusion_g3/operators/xt_utils.h new file mode 100644 index 0000000000..443d68d060 --- /dev/null +++ b/backends/cadence/fusion_g3/operators/xt_utils.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +using ::executorch::aten::ScalarType; + +inline int get_element_size(ScalarType dtype) { + if ((dtype == ScalarType::Int) || (dtype == ScalarType::UInt32)) { + return sizeof(int); + } else if ((dtype == ScalarType::Short) || (dtype == ScalarType::UInt16)) { + return sizeof(short); + } else if ((dtype == ScalarType::Char) || (dtype == ScalarType::Byte)) { + return sizeof(char); + } + return 0; +} diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py index 632e67569f..b9efcd7aa6 100644 --- a/backends/qualcomm/_passes/annotate_quant_attrs.py +++ b/backends/qualcomm/_passes/annotate_quant_attrs.py @@ -9,10 +9,16 @@ import torch from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter from executorch.backends.qualcomm.utils.constants import ( + QCOM_AXIS, + QCOM_DTYPE, QCOM_ENCODING, QCOM_QUANT_ATTRS, + QCOM_QUANT_MAX, + QCOM_QUANT_MIN, QCOM_REQUANTIZE, + QCOM_SCALE, QCOM_SCALES, + QCOM_ZERO_POINT, QCOM_ZERO_POINTS, ) from executorch.exir.dialects._ops import ops as exir_ops @@ -52,45 +58,59 @@ def _expand(self, tensor, dim, axis) -> torch.Tensor: order[axis], order[0] = order[0], order[axis] return tensor.permute(order) - # Find the the last dq node between regular op nodes + # Find the the last dq nodes between regular op nodes # Return dq2 in example below when q1 is given as node parameter: # ... -> n1 -> q1 -> dq1 -> q2 -> dq2 -> n2 -> ... - def _find_last_dq_node(self, node: torch.fx.node.Node) -> torch.fx.node.Node: - if list(node.users)[0].target in q_ops.union(dq_ops): - return self._find_last_dq_node(list(node.users)[0]) - return node + def _find_last_dq_nodes(self, node: torch.fx.node.Node) -> torch.fx.node.Node: + if node is None: + return [] + + # If the node is last dq between regular op node, return it in a list + if node.target in dq_ops: + if all(user.target not in q_ops for user in node.users): + return [node] + + last_dq_nodes = [] + for user in list(node.users): + last_dq_nodes.extend(self._find_last_dq_nodes(user)) + + return last_dq_nodes def _annotate_requant(self, n): # Record requant attributes: - # node1 -> q_ui8 -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> .... - # We store quant info for dq_ui8 and q_int32 in node1.meta + # node1 -> q_ui8 (n) -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> .... + # We store {node2: quant_attr in dq_int32} in node1.meta if n.target in q_ops and n.args[0].target not in dq_ops: - dq_node = self._find_last_dq_node(n) + dq_nodes = self._find_last_dq_nodes(n) q_attrs = get_quant_attrs(self.edge_program, n) - dq_attrs = get_quant_attrs(self.edge_program, dq_node) - - # TODO: Store multiple pairs of requantize attributes when we have an op builder - # that has multiple outputs that requires quant attributes. - if self.skip_advanced_requant: - if q_attrs["dtype"] != dq_attrs["dtype"]: - dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING] - n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs - else: - # When dtype is the same but other specs such as scale and offset are different, - # insert requant to improve accuracy. - # Users can turn this feature off if any inference speed drop is observed. 
- if any( - q_attrs[attr] != dq_attrs[attr] - for attr in [ - "scale", - "zero_point", - "quant_min", - "quant_max", - "dtype", - ] - ): - dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING] - n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs + for dq_node in dq_nodes: + dq_attrs = get_quant_attrs(self.edge_program, dq_node) + # TODO: Store multiple pairs of requantize attributes when we have an op builder + # that has multiple outputs that requires quant attributes. + if self.skip_advanced_requant: + if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]: + dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING] + user_node = list(dq_node.users)[0] + n.args[0].meta.setdefault(QCOM_REQUANTIZE, {}) + n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs + else: + # When dtype is the same but other specs such as scale and offset are different, + # insert requant to improve accuracy. + # Users can turn this feature off if any inference speed drop is observed. + if any( + q_attrs[attr] != dq_attrs[attr] + for attr in [ + QCOM_SCALE, + QCOM_ZERO_POINT, + QCOM_QUANT_MIN, + QCOM_QUANT_MAX, + QCOM_DTYPE, + ] + ): + dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING] + user_node = list(dq_node.users)[0] + n.args[0].meta.setdefault(QCOM_REQUANTIZE, {}) + n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs # Dequant all the fold_quant parameters back to fp32. # If an operation is not supported by QNN and got fallback, it will expect a fp32 param. @@ -98,14 +118,14 @@ def _dequant_fold_params(self, n, quant_attrs, param): if quant_attrs[QCOM_ENCODING] in [ exir_ops.edge.quantized_decomposed.dequantize_per_channel.default ]: - dim, axis = param.dim(), quant_attrs["axis"] + dim, axis = param.dim(), quant_attrs[QCOM_AXIS] scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis) offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis) param = param.sub(offsets).mul(scales).to(torch.float32).contiguous() set_parameter(param, n.args[0], self.edge_program) else: - scale = quant_attrs["scale"] - offset = quant_attrs["zero_point"] + scale = quant_attrs[QCOM_SCALE] + offset = quant_attrs[QCOM_ZERO_POINT] param = param.sub(offset).mul(scale).to(torch.float32).contiguous() set_parameter(param, n.args[0], self.edge_program) diff --git a/backends/qualcomm/_passes/insert_requantize.py b/backends/qualcomm/_passes/insert_requantize.py index 5291edeb9f..11aad02a0c 100644 --- a/backends/qualcomm/_passes/insert_requantize.py +++ b/backends/qualcomm/_passes/insert_requantize.py @@ -4,6 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from collections import defaultdict +from typing import Dict, List + import torch from executorch.backends.qualcomm.utils.constants import ( @@ -38,6 +41,42 @@ def __init__( super(InsertRequantize, self).__init__() self.edge_program = edge_program + def _make_hashable(self, value): + if isinstance(value, dict): + return tuple(sorted(value.items())) + return value + + def _invert_dict(self, requantize_dict): + inverted_dict = defaultdict(list) + for user_node_name, quant_attr in requantize_dict.items(): + hashable_quant_attr = self._make_hashable(quant_attr) + inverted_dict[hashable_quant_attr].append(user_node_name) + return inverted_dict + + def _insert_to_copy( + self, + graph_module: torch.fx.GraphModule, + node: torch.fx.node, + quant_attr: Dict, + user_nodes: List[str], + ): + with graph_module.graph.inserting_after(node): + users = list(node.users.keys()) + inserted_n = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten._to_copy.default, + (node,), + ) + inserted_n.meta["val"] = node.meta["val"] + inserted_n.meta[QCOM_QUANT_ATTRS] = quant_attr + + # create node and replace input + if node.meta.get(QCOM_QUANTIZED_IO): + inserted_n.meta[QCOM_QUANTIZED_IO] = node.meta[QCOM_QUANTIZED_IO] + + for user in filter(lambda u: u.name in user_nodes, users): + user.replace_input_with(node, inserted_n) + # TODO: Implement this function when we have an op with # multiple outputs that requires quant attributes. def _multi_output_annotation(self) -> None: @@ -46,21 +85,20 @@ def _multi_output_annotation(self) -> None: def _single_output_annotation( self, gm: torch.fx.GraphModule, n: torch.fx.node ) -> None: - with gm.graph.inserting_after(n): - users = list(n.users.keys()) - inserted_n = gm.graph.create_node( - "call_function", - exir_ops.edge.aten._to_copy.default, - (n,), - ) - - inserted_n.meta["val"] = n.meta["val"] - inserted_n.meta[QCOM_QUANT_ATTRS] = n.meta.pop(QCOM_REQUANTIZE) - if n.meta.get(QCOM_QUANTIZED_IO): - inserted_n.meta[QCOM_QUANTIZED_IO] = n.meta[QCOM_QUANTIZED_IO] + # {user_node_name: quant_attr} + requantize_dict = n.meta.pop(QCOM_REQUANTIZE) + # {quant_attr: user_node_name_list} + group_quant_attr_dict = self._invert_dict(requantize_dict) + # TODO: If users of the node contain output node, + # we replace the node with to_copy op. 
However, it would + # be problem when the node has multiple to_copy ops + add_output = len(group_quant_attr_dict) == 1 - for user in users: - user.replace_input_with(n, inserted_n) + for hashable_quant_attr, user_nodes in group_quant_attr_dict.items(): + user_nodes_copy = user_nodes.copy() + if add_output: + user_nodes_copy.append("output") + self._insert_to_copy(gm, n, dict(hashable_quant_attr), user_nodes_copy) def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: for n in graph_module.graph.nodes: diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 851b547eb6..66bada86dc 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -14,7 +14,6 @@ QCOM_INSERTED_PERMUTE, QCOM_LAYOUT_CHANGE, QCOM_QUANT_ATTRS, - QCOM_REQUANTIZE, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -133,8 +132,6 @@ def is_layout_agnostic(self, node: torch.fx.Node) -> bool: # if dimemsion is not kept, we'll have no clue how to do layout transform if len(node.args) < 3 or not node.args[2]: return False - if node.target in self.qdq_opset: - return QCOM_REQUANTIZE in node.meta return node.target in self.layout_agnostic_ops def is_edge_condition(self, node): diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index a81df0d6de..3a97e8d6d6 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -206,21 +206,21 @@ Now, we can start to fill in function body step by step: input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) ``` Through the information in [Check Operator Spec](#check-operator-spec) section, we could easily extract the desired nodes.
The `get_tensor` method is responsible for retrieving the torch tensor in the correct axis order if the `layout_transform` pass has been applied.
The `define_tensor` method generates the tensor object for the QNN API, and the result is memoized in the aforementioned `nodes_to_wrappers`.
And yet, there are arguments worth for addressing more: - - **node**: current graph node + - **tensor_source_node**: current graph source node of the tensor + - **target_build_node**: current node to build, which is important for fixed point mixed-precision to work properly - **tensor**: torch tensor emitted by node - **tensor_type**: type compatible with QNN SDK, oftenly use `QNN_TENSOR_TYPE_NATIVE` for intermediate outputs and `QNN_TENSOR_TYPE_STATIC` for constant parameters - **nodes_to_wrappers**: dictionary of graph node and its output tensor (note: the tensor here is not a torch tensor but a wrapped object for QNN) - - **is_input_tensor**: flag to tell if current tensor is input activation or parameter, which is important for fixed point mixed-precision to work properly - **node_name**: (optional) tensor name for user to specify - **wrapper_idx**: (optional) defaults to zero if node is not a tuple, otherwise it acts as an indexer to output tensors. e.g. when slicing input tensor into multiple outputs, `wrapper_idx` is necessary for getting correct wrapped tensor object @@ -230,23 +230,24 @@ Now, we can start to fill in function body step by step: weight_tensor = get_parameter(weight_node, self.edge_program) weight_tensor_wrapper = self.define_tensor( weight_node, + node, weight_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) bias_node = node.args[3] bias_tensor = get_parameter(bias_node, self.edge_program) bias_tensor_wrapper = self.define_tensor( bias_node, + node, bias_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) ``` - The logic should be similar and straightforward. Please carefully set arguments `tensor_type`, `is_input_tensor` according to tensors' property. + The logic should be similar and straightforward. Please carefully set arguments `tensor_type` + according to tensors' property. 3. Define parameters: ```python @@ -266,11 +267,11 @@ Now, we can start to fill in function body step by step: ```python output_tensor = self.get_tensor(node, node, 0) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) ``` Althought the input / output activations might map to the graph IOs (a.k.a. user inputs / outputs) with corresponding type `QNN_TENSOR_TYPE_APP_READ` / `QNN_TENSOR_TYPE_APP_WRITE`. Users are still expected to have `QNN_TENSOR_TYPE_NATIVE` for all nodes' IOs and leave the detection logic handled inside `define_tensor` method. 
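For reference, the steps described above can be pulled together into one builder under the updated `define_tensor` signature. The following is a minimal, hypothetical sketch rather than code from this patch: the op-package string and op-type name are placeholders (real builders take them from `qnn_constants`), and class registration/imports should follow the existing builders in this directory. Only the `get_tensor`, `define_tensor`, `Qnn_TensorType_t`, and `PyQnnOpWrapper` usages mirror the calls shown in the examples above.

```python
# Hypothetical single-input op builder sketch. The source node of each tensor
# is passed first, followed by the node currently being built, so that
# define_tensor can select the correct (re)quantization encoding per consumer.
def define_node(self, node, nodes_to_wrappers):
    input_node = node.args[0]
    input_tensor = self.get_tensor(input_node, node)
    input_tensor_wrapper = self.define_tensor(
        input_node,  # tensor_source_node: the node the tensor originates from
        node,        # target_build_node: the op currently being built
        input_tensor,
        PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
        nodes_to_wrappers,
    )

    output_tensor = self.get_tensor(node, node)
    output_tensor_wrapper = self.define_tensor(
        node,        # the output tensor originates from this node itself
        node,
        output_tensor,
        PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
        nodes_to_wrappers,
    )

    # "ExampleOp" and the package name are illustrative placeholders; actual
    # builders use the constants defined in qnn_constants.
    example_op = PyQnnWrapper.PyQnnOpWrapper(node.name, "qti.aisw", "ExampleOp")
    example_op.AddInputTensors([input_tensor_wrapper])
    example_op.AddOutputTensors([output_tensor_wrapper])
    return example_op
```

Note that both wrappers are built against the same `target_build_node`, which is what lets `define_tensor` look up a per-consumer requantize entry when one exists.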
diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 95b147aba5..672cf8b623 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -173,16 +173,19 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict): ) def get_quant_encoding_conf( - self, node: torch.fx.Node, is_input_tensor: bool = False + self, node: torch.fx.Node, target_node: torch.fx.Node ) -> Tuple[Any, Dict]: if not node.meta.get(QCOM_QUANT_ATTRS, None): return ( PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, {}, ) + is_input_tensor = node != target_node quant_attrs = ( - node.meta[QCOM_REQUANTIZE] - if QCOM_REQUANTIZE in node.meta and is_input_tensor + node.meta[QCOM_REQUANTIZE][target_node.name] + if QCOM_REQUANTIZE in node.meta + and is_input_tensor + and target_node.name in node.meta[QCOM_REQUANTIZE] else node.meta[QCOM_QUANT_ATTRS] ) if quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING: @@ -282,11 +285,11 @@ def define_custom_tensor_wrapper( def define_tensor( self, - node: torch.fx.Node, + tensor_source_node: torch.fx.Node, + target_build_node: torch.fx.Node, tensor: torch.Tensor, tensor_type: PyQnnWrapper.Qnn_TensorType_t, nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]], - is_input_tensor: bool, node_name: str = None, wrapper_idx: int = 0, ) -> PyQnnWrapper.TensorWrapper: @@ -294,28 +297,32 @@ def define_tensor( Covert torch.Tensor to TensorWrapper Args: - node: EdgeIR Node + tensor_source_node: EdgeIR Node + target_build_node: Current node to build tensor: EdgeIR Tensor tensor_type: QNN tensor type nodes_to_wrappers: Set contains edge_graph values(node targets) - is_input_tensor: Whether tensor is a fake input tensor relatively to - the op builder that is calling this function """ if node_name is None: - node_name = node.name + node_name = tensor_source_node.name if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None): return cached - tensor_name = f"{node.name}_{wrapper_idx}" - if is_graph_input(node, self.edge_program): - tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name - if is_graph_output(node): + tensor_name = f"{tensor_source_node.name}_{wrapper_idx}" + if is_graph_input(tensor_source_node, self.edge_program): + tensor_name = ( + "input_" + + str(self.external_ids[tensor_source_node]) + + "_" + + tensor_name + ) + if is_graph_output(tensor_source_node): tensor_name = "output_" + tensor_name dims = [1] if len(tensor.size()) == 0 else tensor.size() - tensor_type = self.get_tensor_type(node, tensor_type) + tensor_type = self.get_tensor_type(tensor_source_node, tensor_type) quant_encoding, quant_configs = self.get_quant_encoding_conf( - node, is_input_tensor + tensor_source_node, target_build_node ) dtype = self.get_data_type(tensor, quant_configs) if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): @@ -334,7 +341,7 @@ def define_tensor( if quant_configs: tensor = self.get_quant_tensor_value( tensor, - node.meta[QCOM_QUANT_ATTRS], + tensor_source_node.meta[QCOM_QUANT_ATTRS], quant_configs, ) tensor_wrapper = PyQnnWrapper.TensorWrapper( diff --git a/backends/qualcomm/builders/op_add.py b/backends/qualcomm/builders/op_add.py index 1cc5ae7fe6..b5edfd7bb5 100644 --- a/backends/qualcomm/builders/op_add.py +++ b/backends/qualcomm/builders/op_add.py @@ -27,11 +27,11 @@ def define_node( ) -> PyQnnWrapper.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, 
out_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) add_output_tensors = [output_tensor_wrapper] @@ -43,10 +43,10 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, tensor_type, nodes_to_wrappers, - is_input_tensor=True, ) add_input_tensors.append(input_tensor_wrapper) diff --git a/backends/qualcomm/builders/op_avg_pool2d.py b/backends/qualcomm/builders/op_avg_pool2d.py index 5ad3fc36c9..394d400858 100644 --- a/backends/qualcomm/builders/op_avg_pool2d.py +++ b/backends/qualcomm/builders/op_avg_pool2d.py @@ -32,19 +32,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) # kernel info filter_size = cast(List[int], node.args[1]) diff --git a/backends/qualcomm/builders/op_batch_norm.py b/backends/qualcomm/builders/op_batch_norm.py index 9ca299e743..aa14df0cb7 100644 --- a/backends/qualcomm/builders/op_batch_norm.py +++ b/backends/qualcomm/builders/op_batch_norm.py @@ -49,10 +49,10 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) bias_node = node.args[2] @@ -65,20 +65,20 @@ def define_node( self.update_encoding(bias_node, bias_tensor, eps) bias_tensor_wrapper = self.define_tensor( bias_node, + node, bias_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) filter_tensor = filter_tensor / torch.sqrt(var_tensor + eps) self.update_encoding(filter_node, filter_tensor, eps) filter_tensor_wrapper = self.define_tensor( filter_node, + node, filter_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) batch_norm_input_tensors = [ @@ -89,11 +89,11 @@ def define_node( output_tensor = self.get_tensor(node, node, 0) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) batch_norm_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_bmm.py b/backends/qualcomm/builders/op_bmm.py index 794d66991d..46fbff1cc7 100644 --- a/backends/qualcomm/builders/op_bmm.py +++ b/backends/qualcomm/builders/op_bmm.py @@ -32,20 +32,20 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) bmm_input_tensors.append(input_tensor_wrapper) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) bmm_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_cat.py b/backends/qualcomm/builders/op_cat.py index cf18690498..7f16085639 100644 --- a/backends/qualcomm/builders/op_cat.py +++ b/backends/qualcomm/builders/op_cat.py @@ -36,10 +36,10 @@ def define_node( list_of_tensor_wrappers.append( 
self.define_tensor( tensor_input, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) ) @@ -52,11 +52,11 @@ def define_node( output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) # node args[1] might not exist diff --git a/backends/qualcomm/builders/op_ceil.py b/backends/qualcomm/builders/op_ceil.py index 883befbccf..19fe14d639 100644 --- a/backends/qualcomm/builders/op_ceil.py +++ b/backends/qualcomm/builders/op_ceil.py @@ -29,19 +29,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) ceil_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_clamp.py b/backends/qualcomm/builders/op_clamp.py index 0c69a8d333..0f9a9ffa19 100644 --- a/backends/qualcomm/builders/op_clamp.py +++ b/backends/qualcomm/builders/op_clamp.py @@ -31,10 +31,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) # default value of output_min and output_max @@ -51,11 +51,11 @@ def define_node( output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) clamp_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py index 30207a0392..9daeab6d4b 100644 --- a/backends/qualcomm/builders/op_conv2d.py +++ b/backends/qualcomm/builders/op_conv2d.py @@ -119,16 +119,16 @@ def _define_conv1d( op_wrapper_list = [] # op_wrapper to return unsqueeze_input_node = node.args[0] input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf( - unsqueeze_input_node, + unsqueeze_input_node, node ) unsqueeze_input_tensor = self.get_tensor(unsqueeze_input_node, node) unsqueeze_input_tensor_wrapper = self.define_tensor( unsqueeze_input_node, + node, unsqueeze_input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) unsqueeze_output_tensor = unsqueeze_input_tensor.unsqueeze(1).contiguous() dtype = self.get_data_type(unsqueeze_output_tensor, input_quant_configs) @@ -165,10 +165,10 @@ def _define_conv1d( filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous() filter_tensor_wrapper = self.define_tensor( filter_node, + node, filter_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) conv_input_tensors = [unsqueeze_output_tensor_wrapper, filter_tensor_wrapper] if node.args[2] is not None: @@ -176,10 +176,10 @@ def _define_conv1d( bias_tensor = get_parameter(bias_node, self.edge_program) bias_tensor_wrapper = self.define_tensor( bias_node, + node, bias_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - 
is_input_tensor=False, ) conv_input_tensors.append(bias_tensor_wrapper) @@ -249,11 +249,11 @@ def _define_conv1d( ) squeeze_output_tensor = self.get_tensor(node, node) squeeze_output_tensor_wrapper = self.define_tensor( + node, node, squeeze_output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, node_name=node.name, ) squeeze_op.AddInputTensors([conv_output_tensor_wrapper]) @@ -274,10 +274,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) filter_node = node.args[1] @@ -288,10 +288,10 @@ def define_node( filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous() filter_tensor_wrapper = self.define_tensor( filter_node, + node, filter_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper] @@ -300,20 +300,20 @@ def define_node( bias_tensor = get_parameter(bias_node, self.edge_program) bias_tensor_wrapper = self.define_tensor( bias_node, + node, bias_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) conv_input_tensors.append(bias_tensor_wrapper) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) conv_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_cos.py b/backends/qualcomm/builders/op_cos.py index 98caed10d1..3858a947d9 100644 --- a/backends/qualcomm/builders/op_cos.py +++ b/backends/qualcomm/builders/op_cos.py @@ -30,19 +30,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) cos_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_depth_to_space.py b/backends/qualcomm/builders/op_depth_to_space.py index e734372098..56c57b4bd5 100644 --- a/backends/qualcomm/builders/op_depth_to_space.py +++ b/backends/qualcomm/builders/op_depth_to_space.py @@ -32,19 +32,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) block_size = [] diff --git a/backends/qualcomm/builders/op_dequantize.py b/backends/qualcomm/builders/op_dequantize.py index f80103b4b8..507ecc4e3e 100644 --- a/backends/qualcomm/builders/op_dequantize.py +++ b/backends/qualcomm/builders/op_dequantize.py @@ -27,20 +27,20 @@ def define_node( input_tensor = self.get_tensor(input_node, node) inp_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, 
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) dequant_input_tensors.append(inp_tensor_wrapper) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) dequant_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_div.py b/backends/qualcomm/builders/op_div.py index d3eabb3356..ce3f96abc7 100644 --- a/backends/qualcomm/builders/op_div.py +++ b/backends/qualcomm/builders/op_div.py @@ -27,11 +27,11 @@ def define_node( ) -> PyQnnWrapper.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, out_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) div_output_tensors = [output_tensor_wrapper] @@ -43,10 +43,10 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, tensor_type, nodes_to_wrappers, - is_input_tensor=True, ) div_input_tensors.append(input_tensor_wrapper) diff --git a/backends/qualcomm/builders/op_embedding.py b/backends/qualcomm/builders/op_embedding.py index 8ae3b64fbf..5b0d160039 100644 --- a/backends/qualcomm/builders/op_embedding.py +++ b/backends/qualcomm/builders/op_embedding.py @@ -32,31 +32,31 @@ def define_node( weight_tensor = get_parameter(weight_node, self.edge_program) weight_tensor_wrapper = self.define_tensor( weight_node, + node, weight_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=True, ) indices_node = node.args[1] indices_tensor = self.get_tensor(indices_node, node) indices_tensor_wrapper = self.define_tensor( indices_node, + node, indices_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) gather_input_tensors = [weight_tensor_wrapper, indices_tensor_wrapper] output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) gather_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_expand.py b/backends/qualcomm/builders/op_expand.py index 3f5c266cdd..c098ed00c9 100644 --- a/backends/qualcomm/builders/op_expand.py +++ b/backends/qualcomm/builders/op_expand.py @@ -31,19 +31,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) sizes = cast(List[int], node.args[1]) diff --git a/backends/qualcomm/builders/op_gelu.py b/backends/qualcomm/builders/op_gelu.py index 9bd050cf09..c178740448 100644 --- a/backends/qualcomm/builders/op_gelu.py +++ b/backends/qualcomm/builders/op_gelu.py @@ -30,19 +30,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) 
output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) gelu_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_group_norm.py b/backends/qualcomm/builders/op_group_norm.py index 44a07e4e58..d498b202d7 100644 --- a/backends/qualcomm/builders/op_group_norm.py +++ b/backends/qualcomm/builders/op_group_norm.py @@ -32,41 +32,41 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) weight_node = node.args[1] weight_tensor = get_parameter(weight_node, self.edge_program) weight_tensor_wrapper = self.define_tensor( weight_node, + node, weight_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) bias_node = node.args[2] bias_tensor = get_parameter(bias_node, self.edge_program) bias_tensor_wrapper = self.define_tensor( bias_node, + node, bias_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) group = node.args[6] epsilon = node.args[7] output_tensor = self.get_tensor(node, node, 0) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) group_norm_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_hardsigmoid.py b/backends/qualcomm/builders/op_hardsigmoid.py index 196777d628..1acc08a387 100644 --- a/backends/qualcomm/builders/op_hardsigmoid.py +++ b/backends/qualcomm/builders/op_hardsigmoid.py @@ -32,19 +32,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) hardsigmoid_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_hardswish.py b/backends/qualcomm/builders/op_hardswish.py index 36eda8b342..ed28ff95f7 100644 --- a/backends/qualcomm/builders/op_hardswish.py +++ b/backends/qualcomm/builders/op_hardswish.py @@ -30,19 +30,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) hardswish_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_hardtanh.py b/backends/qualcomm/builders/op_hardtanh.py index 8d90385277..68bafaaab8 100644 --- a/backends/qualcomm/builders/op_hardtanh.py +++ b/backends/qualcomm/builders/op_hardtanh.py @@ -32,10 +32,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - 
is_input_tensor=True, ) # default value of output_min and output_max @@ -50,11 +50,11 @@ def define_node( output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) hardtanh_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_index.py b/backends/qualcomm/builders/op_index.py index 6f8dc558fe..4ddab23aea 100644 --- a/backends/qualcomm/builders/op_index.py +++ b/backends/qualcomm/builders/op_index.py @@ -31,10 +31,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) if len(node.args[1]) > 1: @@ -47,21 +47,21 @@ def define_node( indices_tensor_wrapper = self.define_tensor( indices_node, + node, indices_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) gather_input_tensors = [input_tensor_wrapper, indices_tensor_wrapper] output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) gather_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py index af5311dfb2..c317cc0a8b 100644 --- a/backends/qualcomm/builders/op_index_put.py +++ b/backends/qualcomm/builders/op_index_put.py @@ -24,10 +24,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) indicies_node = node.args[1] indices_list = [ @@ -45,10 +45,10 @@ def define_node( indices_tensor_wrapper = self.define_tensor( indice_node[0], + node, indices_qnn, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) value_node = node.args[2] @@ -56,18 +56,18 @@ def define_node( value_tensor_wrapper = self.define_tensor( value_node, + node, value_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) index_put_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py index 635e12d2ee..2006c71648 100644 --- a/backends/qualcomm/builders/op_layer_norm.py +++ b/backends/qualcomm/builders/op_layer_norm.py @@ -34,10 +34,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) normalized_shapes = node.args[1] @@ -57,31 +57,31 @@ def define_node( weight_tensor = get_parameter(weight_node, self.edge_program) weight_tensor_wrapper = self.define_tensor( weight_node, + node, weight_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) bias_node = node.args[3] bias_tensor = 
get_parameter(bias_node, self.edge_program) bias_tensor_wrapper = self.define_tensor( bias_node, + node, bias_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) epsilon = node.args[4] output_tensor = self.get_tensor(node, node, 0) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) layer_norm_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py index e4f16d4473..a16bbd28c9 100644 --- a/backends/qualcomm/builders/op_linear.py +++ b/backends/qualcomm/builders/op_linear.py @@ -38,10 +38,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) linear_input_tensors.append(input_tensor_wrapper) @@ -59,10 +59,10 @@ def define_node( weight_tensor = get_parameter(weight_node, self.edge_program) weight_tensor_wrapper = self.define_tensor( weight_node, + node, weight_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) linear_input_tensors.append(weight_tensor_wrapper) @@ -78,20 +78,20 @@ def define_node( bias_tensor = get_parameter(bias_node, self.edge_program) bias_tensor_wrapper = self.define_tensor( bias_node, + node, bias_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) linear_input_tensors.append(bias_tensor_wrapper) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) linear_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_log_softmax.py b/backends/qualcomm/builders/op_log_softmax.py index fdd298f988..d395d5eb66 100644 --- a/backends/qualcomm/builders/op_log_softmax.py +++ b/backends/qualcomm/builders/op_log_softmax.py @@ -32,20 +32,20 @@ def define_node( log_softmax_inp_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) log_softmax_input_tensors = [log_softmax_inp_tensor_wrapper] output_tensor = self.get_tensor(node, node) log_softmax_output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) log_softmax_output_tensors = [log_softmax_output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_matmul.py b/backends/qualcomm/builders/op_matmul.py index 2ea6798e26..c9215d11b4 100644 --- a/backends/qualcomm/builders/op_matmul.py +++ b/backends/qualcomm/builders/op_matmul.py @@ -32,20 +32,20 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) matmul_input_tensors.append(input_tensor_wrapper) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) matmul_output_tensors = [output_tensor_wrapper] diff --git 
a/backends/qualcomm/builders/op_max_pool2d.py b/backends/qualcomm/builders/op_max_pool2d.py index 27f14889bf..8d0087eb2c 100644 --- a/backends/qualcomm/builders/op_max_pool2d.py +++ b/backends/qualcomm/builders/op_max_pool2d.py @@ -32,10 +32,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) users = list(node.users.keys()) @@ -51,11 +51,11 @@ def define_node( output_tensor = self.get_tensor(node, node, 0) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) # kernel info filter_size = cast(List[int], node.args[1]) diff --git a/backends/qualcomm/builders/op_mean_dim.py b/backends/qualcomm/builders/op_mean_dim.py index e60e3e790b..313b24420d 100644 --- a/backends/qualcomm/builders/op_mean_dim.py +++ b/backends/qualcomm/builders/op_mean_dim.py @@ -32,10 +32,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) # mean dims and keep dims @@ -51,11 +51,11 @@ def define_node( output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) reduce_mean_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_mul.py b/backends/qualcomm/builders/op_mul.py index db9808bb10..3138d3b8c9 100644 --- a/backends/qualcomm/builders/op_mul.py +++ b/backends/qualcomm/builders/op_mul.py @@ -27,11 +27,11 @@ def define_node( ) -> PyQnnWrapper.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, out_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) mul_output_tensors = [output_tensor_wrapper] @@ -43,10 +43,10 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, tensor_type, nodes_to_wrappers, - is_input_tensor=True, ) mul_input_tensors.append(input_tensor_wrapper) diff --git a/backends/qualcomm/builders/op_pad.py b/backends/qualcomm/builders/op_pad.py index 9ca385ff85..10948859be 100644 --- a/backends/qualcomm/builders/op_pad.py +++ b/backends/qualcomm/builders/op_pad.py @@ -31,20 +31,20 @@ def define_node( input_tensor = self.get_tensor(input_node, node) pad_inp_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) pad_input_tensors = [pad_inp_tensor_wrapper] output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) pad_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_pow.py b/backends/qualcomm/builders/op_pow.py index b4153d458b..cf5b759569 100644 --- a/backends/qualcomm/builders/op_pow.py +++ b/backends/qualcomm/builders/op_pow.py @@ -30,11 +30,11 @@ def define_node( ) -> PyQnnWrapper.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = 
self.define_tensor( + node, node, out_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) pow_output_tensors = [output_tensor_wrapper] @@ -46,10 +46,10 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, tensor_type, nodes_to_wrappers, - is_input_tensor=True, ) # scalar input @@ -77,10 +77,10 @@ def define_node( scalar_tensor_wrapper = self.define_tensor( scalar_node, + node, scalar_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) pow_input_tensors = [input_tensor_wrapper, scalar_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_prelu.py b/backends/qualcomm/builders/op_prelu.py index 5da017b8b7..4057b3d555 100644 --- a/backends/qualcomm/builders/op_prelu.py +++ b/backends/qualcomm/builders/op_prelu.py @@ -38,10 +38,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) prelu_inp_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) if node.target.__name__ == "aten.leaky_relu.default": @@ -89,20 +89,20 @@ def define_node( scalar_tensor_wrapper = self.define_tensor( scalar_node, + node, coeff_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=True, ) prelu_input_tensors = [prelu_inp_tensor_wrapper, scalar_tensor_wrapper] output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) prelu_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_quantize.py b/backends/qualcomm/builders/op_quantize.py index 9d53d65571..4921f96b46 100644 --- a/backends/qualcomm/builders/op_quantize.py +++ b/backends/qualcomm/builders/op_quantize.py @@ -28,10 +28,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) inp_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) quant_input_tensors.append(inp_tensor_wrapper) @@ -43,11 +43,11 @@ def define_node( output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) quant_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_relu.py b/backends/qualcomm/builders/op_relu.py index 8ddc842e5e..29335797e2 100644 --- a/backends/qualcomm/builders/op_relu.py +++ b/backends/qualcomm/builders/op_relu.py @@ -29,20 +29,20 @@ def define_node( input_tensor = self.get_tensor(input_node, node) relu_inp_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) relu_input_tensors = [relu_inp_tensor_wrapper] output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) relu_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_reshape.py b/backends/qualcomm/builders/op_reshape.py index fde6e7647c..ff4a603fa5 100644 --- 
a/backends/qualcomm/builders/op_reshape.py +++ b/backends/qualcomm/builders/op_reshape.py @@ -29,18 +29,18 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor_wrapper = self.define_tensor( + node, node, node.meta["val"], PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) reshape_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_rms_norm.py b/backends/qualcomm/builders/op_rms_norm.py index 3a5101b12d..d1daa6c1e5 100644 --- a/backends/qualcomm/builders/op_rms_norm.py +++ b/backends/qualcomm/builders/op_rms_norm.py @@ -36,10 +36,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) # should be a immutable list @@ -60,10 +60,10 @@ def define_node( weight_tensor = get_parameter(weight_node, self.edge_program) weight_tensor_wrapper = self.define_tensor( weight_node, + node, weight_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) # Fake node, nn moudle seems to be inconsistant with document @@ -80,10 +80,10 @@ def define_node( bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs bias_tensor_wrapper = self.define_tensor( bias_node, + node, bias_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, ) epsilon = node.args[3] @@ -97,11 +97,11 @@ def define_node( output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) rms_nrom_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_rsqrt.py b/backends/qualcomm/builders/op_rsqrt.py index 34086dda48..162b485e9e 100644 --- a/backends/qualcomm/builders/op_rsqrt.py +++ b/backends/qualcomm/builders/op_rsqrt.py @@ -29,20 +29,20 @@ def define_node( input_tensor = self.get_tensor(input_node, node) rsqrt_inp_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) rsqrt_input_tensors = [rsqrt_inp_tensor_wrapper] output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) rsqrt_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_select_copy.py b/backends/qualcomm/builders/op_select_copy.py index fdeec3845e..148888f149 100644 --- a/backends/qualcomm/builders/op_select_copy.py +++ b/backends/qualcomm/builders/op_select_copy.py @@ -32,19 +32,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) dim = cast(int, 
node.args[1]) diff --git a/backends/qualcomm/builders/op_sigmoid.py b/backends/qualcomm/builders/op_sigmoid.py index 92ba447d43..ae6e6709c0 100644 --- a/backends/qualcomm/builders/op_sigmoid.py +++ b/backends/qualcomm/builders/op_sigmoid.py @@ -29,20 +29,20 @@ def define_node( input_tensor = self.get_tensor(input_node, node) sigmoid_inp_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) sigmoid_input_tensors = [sigmoid_inp_tensor_wrapper] output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) sigmoid_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_sin.py b/backends/qualcomm/builders/op_sin.py index 40e466f59e..89fce6bee9 100644 --- a/backends/qualcomm/builders/op_sin.py +++ b/backends/qualcomm/builders/op_sin.py @@ -30,19 +30,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) sin_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_slice_copy.py b/backends/qualcomm/builders/op_slice_copy.py index 3a294e3548..8d12e03c0b 100644 --- a/backends/qualcomm/builders/op_slice_copy.py +++ b/backends/qualcomm/builders/op_slice_copy.py @@ -32,19 +32,19 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, tensor_type, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) dim = cast(int, node.args[1]) diff --git a/backends/qualcomm/builders/op_softmax.py b/backends/qualcomm/builders/op_softmax.py index cda40aed45..f6f826e2a4 100644 --- a/backends/qualcomm/builders/op_softmax.py +++ b/backends/qualcomm/builders/op_softmax.py @@ -31,20 +31,20 @@ def define_node( input_tensor = self.get_tensor(input_node, node) softmax_inp_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) softmax_input_tensors = [softmax_inp_tensor_wrapper] output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) softmax_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_space_to_depth.py b/backends/qualcomm/builders/op_space_to_depth.py index a9b61c520e..0282cf3f15 100644 --- a/backends/qualcomm/builders/op_space_to_depth.py +++ b/backends/qualcomm/builders/op_space_to_depth.py @@ -32,19 +32,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) 
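# Every builder touched above and below follows the same migration: define_tensor
# no longer takes an is_input_tensor flag and instead receives the op node being
# lowered as a second positional argument. A minimal sketch of the resulting call
# pattern inside a builder's define_node (assumption: the extra node argument is
# what the tensor's quantization encodings are resolved against, as the op_to.py
# change further down suggests; the helper name below is illustrative only):
def _define_io_tensors(self, node, nodes_to_wrappers):
    input_node = node.args[0]
    input_tensor = self.get_tensor(input_node, node)
    # Source node first, then the op node consuming it.
    input_tensor_wrapper = self.define_tensor(
        input_node,
        node,
        input_tensor,
        PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
        nodes_to_wrappers,
    )
    # The output tensor is defined against the op node itself.
    output_tensor = self.get_tensor(node, node)
    output_tensor_wrapper = self.define_tensor(
        node,
        node,
        output_tensor,
        PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
        nodes_to_wrappers,
    )
    return input_tensor_wrapper, output_tensor_wrapper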
output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) block_size = [] diff --git a/backends/qualcomm/builders/op_split_with_sizes.py b/backends/qualcomm/builders/op_split_with_sizes.py index 58503ff3f8..8e75fd3c10 100644 --- a/backends/qualcomm/builders/op_split_with_sizes.py +++ b/backends/qualcomm/builders/op_split_with_sizes.py @@ -33,10 +33,10 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) input_tensor_wrappers = [input_tensor_wrapper] @@ -45,11 +45,11 @@ def define_node( for index in range(len(node.meta["val"])): output_tensor = self.get_tensor(node, node, index) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, wrapper_idx=index, ) output_tensor_wrappers.append(output_tensor_wrapper) diff --git a/backends/qualcomm/builders/op_sqrt.py b/backends/qualcomm/builders/op_sqrt.py index 7847d00e8b..dc6691460c 100644 --- a/backends/qualcomm/builders/op_sqrt.py +++ b/backends/qualcomm/builders/op_sqrt.py @@ -31,20 +31,20 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) sqrt_input_tensors = [input_tensor_wrapper] out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, out_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) sqrt_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_squeeze.py b/backends/qualcomm/builders/op_squeeze.py index 00cbda54cf..b828bb7b0b 100644 --- a/backends/qualcomm/builders/op_squeeze.py +++ b/backends/qualcomm/builders/op_squeeze.py @@ -30,19 +30,19 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) squeeze_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_sub.py b/backends/qualcomm/builders/op_sub.py index 66c028253c..954ca9d391 100644 --- a/backends/qualcomm/builders/op_sub.py +++ b/backends/qualcomm/builders/op_sub.py @@ -27,11 +27,11 @@ def define_node( ) -> PyQnnWrapper.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, out_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) sub_output_tensors = [output_tensor_wrapper] @@ -43,10 +43,10 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, tensor_type, nodes_to_wrappers, - is_input_tensor=True, ) sub_input_tensors.append(input_tensor_wrapper) diff --git a/backends/qualcomm/builders/op_sum_int_list.py b/backends/qualcomm/builders/op_sum_int_list.py index abe35c2244..74181f46cb 100644 --- a/backends/qualcomm/builders/op_sum_int_list.py +++ 
b/backends/qualcomm/builders/op_sum_int_list.py @@ -32,10 +32,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) sum_input_tensors = [input_tensor_wrapper] @@ -50,11 +50,11 @@ def define_node( output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) sum_output_tensors = [output_tensor_wrapper] sum_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_tanh.py b/backends/qualcomm/builders/op_tanh.py index f82ff92f44..ddc9fd2a2a 100644 --- a/backends/qualcomm/builders/op_tanh.py +++ b/backends/qualcomm/builders/op_tanh.py @@ -30,19 +30,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) tanh_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_to.py b/backends/qualcomm/builders/op_to.py index e17ee2790b..f5cfd4ecf6 100644 --- a/backends/qualcomm/builders/op_to.py +++ b/backends/qualcomm/builders/op_to.py @@ -45,11 +45,12 @@ def is_cast_node(self, node): return True input_tensor = self.get_tensor(input_node, node) - _, inp_qconfs = self.get_quant_encoding_conf(input_node, False) + # Get real quant conf of input node + _, inp_qconfs = self.get_quant_encoding_conf(input_node, input_node) inp_dtype = self.get_data_type(input_tensor, inp_qconfs) output_tensor = self.get_tensor(node, node) - _, out_qconfs = self.get_quant_encoding_conf(node, False) + _, out_qconfs = self.get_quant_encoding_conf(node, node) out_dtype = self.get_data_type(output_tensor, out_qconfs) is_qparam_castable = ( lambda o1, o2, s1, s2, diff: abs(s1 - s2) < self.epsilon @@ -84,20 +85,20 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) qnn_op = OpCast if self.is_cast_node(node) else OpConvert diff --git a/backends/qualcomm/builders/op_topk.py b/backends/qualcomm/builders/op_topk.py index 84c29925f2..1bbf19c84b 100644 --- a/backends/qualcomm/builders/op_topk.py +++ b/backends/qualcomm/builders/op_topk.py @@ -33,10 +33,10 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=True, ) k = cast(int, node.args[1]) @@ -62,21 +62,21 @@ def define_node( # QNN constraint, topk output_0 requires having the same quant config as input node.meta["quant_attrs"] = input_node.meta.get("quant_attrs") output_val_tensor_wrapper = self.define_tensor( + node, node, output_val_tensor, 
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) # topk output_1 is index, do not quantize it. node.meta.pop("quant_attrs", None) output_index_tensor_wrapper = self.define_tensor( + node, node, output_idx_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, wrapper_idx=1, ) topk_output_tensors = [output_val_tensor_wrapper, output_index_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_transpose.py b/backends/qualcomm/builders/op_transpose.py index 20e30da335..d29fc73084 100644 --- a/backends/qualcomm/builders/op_transpose.py +++ b/backends/qualcomm/builders/op_transpose.py @@ -33,10 +33,10 @@ def define_node( input_tensor = self.get_tensor(input_node, permute_node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) # permutation @@ -45,11 +45,11 @@ def define_node( output_tensor = input_tensor.permute(permute_order) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) transpose_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_unsqueeze.py b/backends/qualcomm/builders/op_unsqueeze.py index 48c9207398..5579012946 100644 --- a/backends/qualcomm/builders/op_unsqueeze.py +++ b/backends/qualcomm/builders/op_unsqueeze.py @@ -30,19 +30,19 @@ def define_node( input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) unsqueeze_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_upsample_bilinear2d.py b/backends/qualcomm/builders/op_upsample_bilinear2d.py index 53291786a8..160f15494d 100644 --- a/backends/qualcomm/builders/op_upsample_bilinear2d.py +++ b/backends/qualcomm/builders/op_upsample_bilinear2d.py @@ -30,19 +30,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) reisze_bilinear_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_upsample_nearest2d.py b/backends/qualcomm/builders/op_upsample_nearest2d.py index 75e61d77e5..6b7949716c 100644 --- a/backends/qualcomm/builders/op_upsample_nearest2d.py +++ b/backends/qualcomm/builders/op_upsample_nearest2d.py @@ -30,19 +30,19 @@ def define_node( input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, + node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=True, ) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( + node, node, output_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - is_input_tensor=False, ) 
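# Parameter tensors (weights and biases) follow the same two-node convention,
# with QNN_TENSOR_TYPE_STATIC and values read from the exported program, as in
# the linear/layer_norm/rms_norm hunks above. A minimal sketch (the helper name
# is illustrative; get_parameter and self.edge_program are taken from those hunks):
def _define_param_tensor(self, param_node, node, nodes_to_wrappers):
    param_tensor = get_parameter(param_node, self.edge_program)
    return self.define_tensor(
        param_node,
        node,
        param_tensor,
        PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
        nodes_to_wrappers,
    )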
reisze_nearest_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index c58d0844b4..6167df64b9 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -27,6 +27,12 @@ def annotate_matmul_16a8w( # noqa: C901 ) -> None: """ This function is specific for matmul op 16a8w. + For k, we tag as shown below, and + for v, we tag 8a until conv op. + q (16 bits) ──┬─> matmul op (16 bits) + past k (8 bits) ┬─> cat op (8 bits) ─┘ + rotary add (16 bits) ─┬> cat op (new k) (8 bits) ┘ + rotary sub (16 bits) ─┘ """ def annotate_matmul(node: Node, quantization_config: QuantizationConfig): @@ -65,6 +71,21 @@ def annotate_cat(node: Node, quantization_config: QuantizationConfig): _annotated=True, ) + def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: + input_qspec_map = {} + input_act = node.args[0] + input_spec = quantization_config.input_activation + input_qspec_map[input_act] = input_spec + + weight = node.args[1] + input_qspec_map[weight] = quantization_config.weight + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + def annotate_single_in_single_out( node: Node, quantization_config: QuantizationConfig ) -> None: @@ -83,17 +104,37 @@ def annotate_matmul_input1(node: Node): quantization_config_8a8w = get_8a8w_qnn_ptq_config( act_symmetric=True, act_observer=MinMaxObserver ) + quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config( + act_dtype=torch.uint8, + weight_dtype="int4", + act_observer=MinMaxObserver, + act_symmetric=True, + ) while isinstance(node, Node) and node.op == "call_function": if node.target in [ torch.ops.aten.permute.default, + torch.ops.aten.squeeze.dim, torch.ops.aten.transpose.int, + torch.ops.aten.view.default, + torch.ops.aten.reshape.default, ]: annotate_single_in_single_out(node, quantization_config_8a8w) node = node.args[0] elif node.target == torch.ops.aten.cat.default: annotate_cat(node, quantization_config_8a8w) - node = node.args[0][0] + # For v, we tag 8a until conv op. + # For k, we tag 8a until add or sub op (rotary embedding).
+ # The arguments of cat op: (the past kv cache, the new kv cache) + node = node.args[0][1] + elif node.target == torch.ops.aten.conv2d.default: + annotate_conv2d( + node, quantization_config=quantization_config_8a4w_per_channel + ) + break + elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]: + break else: + print(f"The node ({node}) is not expected in the input1 of the matmul") node = node.args[0] quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver) diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index abe51066ba..a6c551e241 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -249,6 +249,7 @@ def get_ptq_per_channel_quant_config( act_quantization_spec = QuantizationSpec( dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, qscheme=torch.per_tensor_symmetric, + ch_axis=0, observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), ) else: diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 3faa1dfbe9..96aab87826 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -596,6 +596,18 @@ def forward(self, input_pos, k_val): return k_out +class LargeTensorLinear(torch.nn.Module): + def __init__(self): + super().__init__() + hidden_dim = 4096 + self.linear1 = torch.nn.Linear(512, hidden_dim) + self.linear2 = torch.nn.Linear(hidden_dim, 512) + + def forward(self, x): + x1 = self.linear1(x) + self.linear1(x) + return self.linear2(x1) + + class LayerNorm(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 4b489ea515..73ca1820f3 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -1581,6 +1581,24 @@ def test_qnn_backend_skip_node_op(self): skip_node_op_set={"aten.add.Tensor"}, ) + def test_qnn_backend_spill_fill_buffer_size(self): + module = LargeTensorLinear() # noqa: F405 + sample_input = (torch.randn(1, 256, 512),) + edge_prog = capture_program(module, sample_input) + + backend_options = generate_htp_compiler_spec( + use_fp16=True, + use_multi_contexts=True, + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=self.chipset_table[TestQNN.model], + backend_options=backend_options, + ) + partitioner = QnnPartitioner(compiler_specs) + edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner) + max_sf_size = update_spill_fill_size(edge_prog.exported_program) + self.assertNotEqual(0, max_sf_size) + def test_qnn_backend_multi_contexts(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) @@ -1655,8 +1673,12 @@ def test_qnn_backend_multi_graphs(self): to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i])) for i, edge_prog in enumerate(edge_progs) ] - prog_mgr = generate_multi_graph_program( - compiler_specs=compiler_specs[0], exported_programs=exported_programs + prog_mgr, _ = generate_multi_graph_program( + compiler_specs=compiler_specs[0], + processed_bytes=[ + prog.graph_module.lowered_module_0.processed_bytes + for prog in exported_programs + ], ) for index, module in enumerate(modules): self.verify_output( @@ -2007,6 +2029,25 @@ def calibrator(gm): ).to_executorch() self.verify_output(module, sample_input, exec_prog) + def test_qnn_backend_spill_fill_buffer_size(self): + module 
= LargeTensorLinear() # noqa: F405 + sample_input = (torch.randn(1, 256, 512),) + module = self.get_qdq_module(module, sample_input) + edge_prog = capture_program(module, sample_input) + + backend_options = generate_htp_compiler_spec( + use_fp16=False, + use_multi_contexts=True, + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=self.chipset_table[TestQNN.model], + backend_options=backend_options, + ) + partitioner = QnnPartitioner(compiler_specs) + edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner) + max_sf_size = update_spill_fill_size(edge_prog.exported_program) + self.assertNotEqual(0, max_sf_size) + def test_qnn_backend_graph_level_mixed_precision(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) @@ -2120,9 +2161,12 @@ def test_qnn_backend_multi_graphs(self): to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i])) for i, edge_prog in enumerate(edge_progs) ] - prog_mgr = generate_multi_graph_program( + prog_mgr, _ = generate_multi_graph_program( compiler_specs=compiler_specs[0], - exported_programs=exported_programs, + processed_bytes=[ + prog.graph_module.lowered_module_0.processed_bytes + for prog in exported_programs + ], ) for index, module in enumerate(modules): self.verify_output( diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 2e0ee4f7c6..e13705b3a8 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -6,6 +6,7 @@ import operator import re +import time import warnings from collections import OrderedDict from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple @@ -268,15 +269,17 @@ def set_spec(module, options): options.backend_options.htp_options.max_sf_buf_size = max_sf_buf_size set_spec(module, options) + max_sf_size, modules_map = 0, {} if isinstance(exported_program, list): - max_sf_size, modules_map = 0, {} for prog in exported_program: max_sf_buf_size, module_map = get_program_info(prog) max_sf_size = max(max_sf_size, max_sf_buf_size) modules_map.update(module_map) - update_program(max_sf_size, modules_map) else: - update_program(*get_program_info(exported_program)) + max_sf_size, module_map = get_program_info(exported_program) + update_program(max_sf_size, module_map) + + return max_sf_size def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]: @@ -740,7 +743,7 @@ def preprocess_binary(ctx_bin, compiler_specs): for k, v in type_map.items(): dtype_map.setdefault(v, k) - qnn_in_order, executorch_in_order, executorch_out_order = [], [], [] + qnn_in_order, executorch_in_order, executorch_out_order = None, None, None if custom_info is not None: # since some context binaries might fail to open on host # if they are compiled with special flags: @@ -748,9 +751,9 @@ def preprocess_binary(ctx_bin, compiler_specs): # use custom information here instead inputs = build_tensor(custom_info["graph_inputs"], dtype_map) outputs = build_tensor(custom_info["graph_outputs"], dtype_map) - qnn_in_order = custom_info["qnn_in_order"] - executorch_in_order = custom_info["executorch_in_order"] - executorch_out_order = custom_info["executorch_out_order"] + qnn_in_order = custom_info.get("qnn_in_order", None) + executorch_in_order = custom_info.get("executorch_in_order", None) + executorch_out_order = custom_info.get("executorch_out_order", None) graph_name = custom_info["graph_name"] else: # get context-binary io tensor info through qnn manager @@ -800,7 +803,9 @@ def 
draw_graph(title, path, graph_module: torch.fx.GraphModule): def generate_multi_graph_program( compiler_specs: List[CompileSpec], - exported_programs: List[ExportedProgram] = None, + processed_bytes: List[bytes], + input_nodes_dict: List[torch.fx.Node] = None, + output_nodes_dict: List[torch.fx.Node] = None, backend_config: ExecutorchBackendConfig = None, constant_methods: Optional[Dict[str, Any]] = None, ) -> ExecutorchProgramManager: @@ -813,10 +818,6 @@ def generate_multi_graph_program( executorch_in_order, executorch_out_order, ) = ({}, {}, {}, {}, {}) - - processed_bytes = [ - prog.graph_module.lowered_module_0.processed_bytes for prog in exported_programs - ] qnn_mgr = PyQnnManagerAdaptor.QnnManager( generate_qnn_executorch_option(compiler_specs), processed_bytes ) @@ -829,38 +830,36 @@ def generate_multi_graph_program( graph_outputs[graph_name] = qnn_mgr.GetGraphOutputs(graph_name) # We need to obtain the order of the IOs to correctly map QNN with nn.module - for i, graph_name in enumerate(graph_names): - # input - input_names = [ - node.name - for node in exported_programs[i].graph_module.graph.nodes - if node.op == "placeholder" - ] - qnn_input_names = [wrapper.GetName() for wrapper in graph_inputs[graph_name]] - input_order_list = [] - for input_name in input_names: - # e.g., input_0_tokens_0 - pattern = rf"^input_(\d+)_({input_name})_(\d+)$" - for j in range(len(qnn_input_names)): - if re.match(pattern, qnn_input_names[j]): - input_order_list.append(j) - break - assert ( - len(input_order_list) == len(input_names) == len(qnn_input_names) - ), "Order list length is different from names" - executorch_in_order[graph_name] = input_order_list - qnn_in_order[graph_name] = sorted( - range(len(input_order_list)), key=lambda k: input_order_list[k] - ) - - # output - get_item_list = [ - node - for node in exported_programs[i].graph_module.graph.nodes - if node.op == "output" - ][0].args[0] - output_order_list = [item.args[1] for item in get_item_list] - executorch_out_order[graph_name] = output_order_list + for graph_name in graph_names: + if input_nodes_dict: + # input + input_names = [node.name for node in input_nodes_dict[graph_name]] + qnn_input_names = [ + wrapper.GetName() for wrapper in graph_inputs[graph_name] + ] + # The input of intermideate module including call_function node + # could not be reorder by node name + if len(input_names) == len(qnn_input_names): + input_order_list = [] + for input_name in input_names: + # e.g., input_0_tokens_0 + pattern = rf"^input_(\d+)_({input_name})_(\d+)$" + for j in range(len(qnn_input_names)): + if re.match(pattern, qnn_input_names[j]): + input_order_list.append(j) + break + assert len(input_order_list) == len( + input_names + ), "Order list length is different from names" + executorch_in_order[graph_name] = input_order_list + qnn_in_order[graph_name] = sorted( + range(len(input_order_list)), key=lambda k: input_order_list[k] + ) + if output_nodes_dict: + # output + get_item_list = output_nodes_dict[graph_name][0].args[0] + output_order_list = [item.args[1] for item in get_item_list] + executorch_out_order[graph_name] = output_order_list qnn_mgr.Destroy() @@ -869,15 +868,15 @@ def generate_multi_graph_program( bundle_progs = [ from_context_binary( ctx_path=binary_info, - op_name=f"loader_{graph_name}", + op_name=f"loader_{graph_name}_{int(time.time())}", soc_model=compiler_options.soc_info.soc_model, custom_info={ "graph_inputs": graph_inputs[graph_name], "graph_outputs": graph_outputs[graph_name], "graph_name": graph_name, - 
"qnn_in_order": qnn_in_order[graph_name], - "executorch_in_order": executorch_in_order[graph_name], - "executorch_out_order": executorch_out_order[graph_name], + "qnn_in_order": qnn_in_order.get(graph_name, None), + "executorch_in_order": executorch_in_order.get(graph_name, None), + "executorch_out_order": executorch_out_order.get(graph_name, None), }, ) for graph_name in graph_names @@ -900,9 +899,101 @@ def generate_multi_graph_program( break edge_prog_mgr = edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs)) - return edge_prog_mgr.to_executorch( + exec_prog = edge_prog_mgr.to_executorch( + config=backend_config or ExecutorchBackendConfig() + ) + return exec_prog, bundle_progs + + +def generate_composite_llama_program( + graph_names: List[str], + sample_inputs_list: List[Tuple[Any]], + lower_module_dict: Dict[str, List[LoweredBackendModule]], + call_delegate_node_name_dict: Dict[str, List[str]], + call_delegate_inputs_dict: Dict[str, List[Tuple[str, int | None]]], + outputs_dict: Dict[str, List[Tuple[str, int]]], + backend_config: ExecutorchBackendConfig = None, + constant_methods: Optional[Dict[str, Any]] = None, +) -> ExecutorchProgramManager: + class CompositeLlamaModule(torch.nn.Module): + def __init__( + self, + lower_module_list, + call_delegate_node_name_list, + call_delegate_inputs_list, + outputs_list, + ) -> None: + super().__init__() + self.lower_module_list = lower_module_list + self.call_delegate_node_name_list = call_delegate_node_name_list + self.call_delegate_inputs_list = call_delegate_inputs_list + self.outputs_list = outputs_list + + def reorder( + self, + call_delegate_inputs: List[Tuple[str, int | None]], + module_inputs: dict[str, torch.Tensor], + all_ret: dict[str, torch.Tensor], + ) -> Tuple[torch.Tensor]: + ret = [] + for name, index in call_delegate_inputs: + if index is not None: + # Get tensor from previous results + ret.append(all_ret[name][index]) + else: + # Get tensor from the inputs of module + ret.append(module_inputs[name]) + return tuple(ret) + + def forward( + self, + tokens: torch.Tensor, + atten_mask: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + *args, + ) -> Tuple[torch.Tensor]: + all_ret = {} + module_input_dict = { + "tokens": tokens, + "atten_mask": atten_mask, + "input_pos": input_pos, + } + for num, arg in enumerate(args): + module_input_dict[f"args_{num}"] = arg + for lower_module, call_delegate_node_name, call_delegate_inputs in zip( + self.lower_module_list, + self.call_delegate_node_name_list, + self.call_delegate_inputs_list, + ): + inp = self.reorder(call_delegate_inputs, module_input_dict, all_ret) + ret = lower_module(*inp) + all_ret[call_delegate_node_name] = ret + llama_outputs = [] + for output_src_name, index in self.outputs_list: + llama_outputs.append(all_ret[output_src_name][index]) + return tuple(llama_outputs) + + progs_dict = {} + for graph_name, sample_inputs in zip(graph_names, sample_inputs_list): + composite_llama_module = CompositeLlamaModule( + lower_module_dict[graph_name], + call_delegate_node_name_dict[graph_name], + call_delegate_inputs_dict[graph_name], + outputs_dict[graph_name], + ) + prog = torch.export.export(composite_llama_module, sample_inputs) + progs_dict[graph_name] = prog + # leverage ExecutorchProgramManager for generating pte with multi-methods + edge_prog_mgr = to_edge( + progs_dict, + constant_methods=constant_methods, + # do not alter name for custom op + compile_config=EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False), + ) + exec_prog = 
edge_prog_mgr.to_executorch( config=backend_config or ExecutorchBackendConfig() ) + return exec_prog def generate_htp_compiler_spec( diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp index 9517941f36..f425859935 100644 --- a/backends/vulkan/runtime/api/Context.cpp +++ b/backends/vulkan/runtime/api/Context.cpp @@ -87,6 +87,27 @@ void Context::report_shader_dispatch_end() { } } +void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) { + if (shader.requires_shader_int16) { + if (!adapter_p_->supports_int16_shader_types()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::SHADER_INT16); + } + } + if (shader.requires_16bit_storage) { + if (!adapter_p_->supports_16bit_storage_buffers()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::INT16_STORAGE); + } + } + if (shader.requires_8bit_storage) { + if (!adapter_p_->supports_8bit_storage_buffers()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::INT8_STORAGE); + } + } +} + vkapi::DescriptorSet Context::get_descriptor_set( const vkapi::ShaderInfo& shader_descriptor, const utils::uvec3& local_workgroup_size, diff --git a/backends/vulkan/runtime/api/Context.h b/backends/vulkan/runtime/api/Context.h index 300fd3995d..0c199c24cc 100644 --- a/backends/vulkan/runtime/api/Context.h +++ b/backends/vulkan/runtime/api/Context.h @@ -185,6 +185,8 @@ class Context final { } } + void check_device_capabilities(const vkapi::ShaderInfo& shader); + vkapi::DescriptorSet get_descriptor_set( const vkapi::ShaderInfo&, const utils::uvec3&, diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 7d004547a8..7d3d2d5295 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -720,6 +720,10 @@ def maybe_replace_u16vecn(self, input_text: str) -> str: if "codegen-nosub" in input_text: return input_text + # Remove extension requirement so that generated ShaderInfo does not mark it + input_text = input_text.replace( + "#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require", "" + ) input_text = input_text.replace("u16vec", "ivec") input_text = input_text.replace("uint16_t", "int") return input_text @@ -791,6 +795,9 @@ class ShaderInfo: weight_storage_type: str = "" bias_storage_type: str = "" register_for: Optional[Tuple[str, List[str]]] = None + requires_shader_int16_ext: bool = False + requires_16bit_storage_ext: bool = False + requires_8bit_storage_ext: bool = False def getName(filePath: str) -> str: @@ -858,6 +865,11 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]: return (matches_list[0], matches_list[1:]) +def isExtensionRequireLine(lineStr: str) -> bool: + extension_require_id = r"^#extension ([A-Za-z0-9_]+)\s*:\s*require" + return re.search(extension_require_id, lineStr) is not None + + typeIdMapping = { r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE", r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER", @@ -889,6 +901,13 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo: shader_info.bias_storage_type = getBiasStorageType(line) if isRegisterForLine(line): shader_info.register_for = findRegisterFor(line) + if isExtensionRequireLine(line): + if "GL_EXT_shader_explicit_arithmetic_types_int16" in line: + shader_info.requires_shader_int16_ext = True + if "GL_EXT_shader_16bit_storage" in line: + shader_info.requires_16bit_storage_ext = True + if 
"GL_EXT_shader_8bit_storage" in line: + shader_info.requires_8bit_storage_ext = True return shader_info @@ -952,12 +971,18 @@ def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) -> shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts)) + def to_cpp_str(val: bool): + return "true" if val else "false" + shader_info_args = [ f'"{name}"', f"{name}_bin", str(sizeBytes), shader_info_layouts, tile_size, + to_cpp_str(shader_info.requires_shader_int16_ext), + to_cpp_str(shader_info.requires_16bit_storage_ext), + to_cpp_str(shader_info.requires_8bit_storage_ext), ] shader_info_str = textwrap.indent( diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp index a163a0d7ae..63b8798f2c 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp @@ -58,6 +58,8 @@ void DispatchNode::encode(ComputeGraph* graph) { api::Context* const context = graph->context(); vkapi::PipelineBarrier pipeline_barrier{}; + context->check_device_capabilities(shader_); + std::unique_lock cmd_lock = context->dispatch_lock(); std::array push_constants_data; diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl index 43a4f7c8dc..23cbb1b652 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl @@ -14,7 +14,7 @@ #define op(X, A, B) ${OPERATOR} -#include "indexing_utils.h" +#include "indexing_utils_u16.h" layout(std430) buffer; @@ -35,10 +35,11 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; * output at a single output location. */ void main() { + const uint div_by_x = gl_GlobalInvocationID.x / out_limits.x; const ivec3 pos = ivec3( gl_GlobalInvocationID.x % out_limits.x, - (gl_GlobalInvocationID.x / out_limits.x) % out_limits.y, - gl_GlobalInvocationID.x / (out_limits.x * out_limits.y)); + div_by_x % out_limits.y, + div_by_x / out_limits.y); if (any(greaterThanEqual(pos, out_limits))) { return; diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl index b2ae4953a7..ad4ff245a1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl @@ -14,11 +14,13 @@ #define TILE_SIZE ${TILE_SIZE} +#define BATCH_SIZE_X ${BATCH_SIZE_X} + #define BATCH_SIZE_Y ${BATCH_SIZE_Y} #define op(X, A, B) ${OPERATOR} -#include "indexing_utils.h" +#include "indexing_utils_u16.h" layout(std430) buffer; @@ -34,80 +36,88 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require - /* * Computes a depthwise convolution. Each shader invocation calculates the * output at a single output location. 
*/ void main() { - // y divided up by batch size is used to determine 3d position + // x and y are divided by batch size to determine 3d position // since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z - const uint out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1) / BATCH_SIZE_Y; + const ivec2 out_limits_xy_scaled = (out_limits.xy + ivec2(BATCH_SIZE_X, BATCH_SIZE_Y) - 1) / ivec2(BATCH_SIZE_X, BATCH_SIZE_Y); - u16vec3 pos = u16vec3( - gl_GlobalInvocationID.x % out_limits.x, - ((gl_GlobalInvocationID.x / out_limits.x) % out_limits_y_scaled), - gl_GlobalInvocationID.x / (out_limits.x * out_limits_y_scaled)); + const uint div_by_x = gl_GlobalInvocationID.x / out_limits_xy_scaled.x; + ivec3 pos = ivec3( + gl_GlobalInvocationID.x % out_limits_xy_scaled.x, + div_by_x % out_limits_xy_scaled.y, + div_by_x / out_limits_xy_scaled.y); - // scale pos.y by batch size, because that's the top pixel to be processed - pos.y *= uint16_t(BATCH_SIZE_Y); + // scale pos.xy by batch sizes, because that's the top pixel to be processed + pos.x *= BATCH_SIZE_X; + pos.y *= BATCH_SIZE_Y; // do not process if top pixel does not fit within the output range - if (any(greaterThanEqual(u16vec3(pos.x, pos.y, pos.z), out_limits))) { + if (any(greaterThanEqual(pos, out_limits))) { return; } // Compute the index of the top-left element of the overlay region. Negative // indices indicate that the top-left element is in a region added by padding. - const u16vec2 ipos = pos.xy * u16vec2(stride) - u16vec2(padding); + const ivec2 ipos = pos.xy * stride - padding; // Compute the start and end of the input indices to load. Padding is assumed // to be constant 0 padding, so any reads from the padding region is skipped. - const u16vec2 start = ipos; - const u16vec2 end = ipos + u16vec2(overlay_region.xy); + const ivec2 start = ipos; + const ivec2 end = ipos + overlay_region.xy; // sum outputs - VEC4_T sum[BATCH_SIZE_Y]; + VEC4_T sum[BATCH_SIZE_Y][BATCH_SIZE_X]; - sum[0] = texelFetch(t_bias, u16vec2(pos.z, 0), 0); - for (int i = 1; i < BATCH_SIZE_Y; i++) { - sum[i] = sum[0]; + sum[0][0] = texelFetch(t_bias, ivec2(pos.z, 0), 0); + for (int y = 0; y < BATCH_SIZE_Y; y++) { + for (int x = 0; x < BATCH_SIZE_X; x++) { + sum[y][x] = sum[0][0]; + } } // array to store input texels - VEC4_T in_texels[TILE_SIZE]; + VEC4_T in_texels[TILE_SIZE + BATCH_SIZE_X - 1]; // array to store kernel data of previous y VEC4_T prev_kernel_line[TILE_SIZE]; - uint16_t kx = uint16_t(0); - for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE + BATCH_SIZE_Y - 1); y += uint16_t(dilation.y), i++) { - for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++) { - in_texels[int(j)] = texelFetch(t_in, u16vec3(x, y, pos.z), 0); + int kx = 0; + for (int y = start.y, i = 0; i < TILE_SIZE + BATCH_SIZE_Y - 1; y += dilation.y, i++) { + for (int x = start.x, j = 0; j < TILE_SIZE + BATCH_SIZE_X - 1; x += dilation.x, j++) { + in_texels[j] = texelFetch(t_in, ivec3(x, y, pos.z), 0); } // from 2nd iteration onwards accumulate dot product in 2nd sum // based on kernel line data fetched in previous iteration and input texel from this iteration - if (i > uint16_t(0)) { - for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++) { - sum[1] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[1]); + if (i > 0) { + for (int j = 0; j < TILE_SIZE; j++) { + for (int s = 0; s < BATCH_SIZE_X; s++) { + sum[1][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[1][s]); + } } } // accumulate dot product in 1st sum only 
until tile size - if (i < uint16_t(TILE_SIZE)) { - for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++, kx++) { - prev_kernel_line[int(j)] = texelFetch(t_kernel, u16vec2(kx, pos.z), 0); - sum[0] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[0]); + if (i < TILE_SIZE) { + for (int j = 0; j < TILE_SIZE; j++, kx++) { + prev_kernel_line[j] = texelFetch(t_kernel, ivec2(kx, pos.z), 0); + for (int s = 0; s < BATCH_SIZE_X; s++) { + sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]); + } } } } - for (int i = 0; i < BATCH_SIZE_Y; i++) { - if (any(greaterThanEqual(u16vec3(pos.x, pos.y + i, pos.z), out_limits))) { - continue; + for (int y = 0; y < BATCH_SIZE_Y; y++) { + for (int x = 0; x < BATCH_SIZE_X; x++) { + if (any(greaterThanEqual(ivec3(pos.x + x, pos.y + y, pos.z), out_limits))) { + continue; + } + imageStore(t_out, ivec3(pos.x + x, pos.y + y, pos.z), op(sum[y][x], out_min, out_max)); } - imageStore(t_out, u16vec3(pos.x, pos.y + i, pos.z), op(sum[i], out_min, out_max)); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml index bb197c2c18..9cf6c22c6c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml @@ -10,6 +10,7 @@ conv2d_dw_output_tile: NDIM: 3 DTYPE: float TILE_SIZE: 3 + BATCH_SIZE_X: 4 BATCH_SIZE_Y: 2 generate_variant_forall: DTYPE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl index 23ad912c11..b50a892cad 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl @@ -16,7 +16,7 @@ #define op(X, A, B) ${OPERATOR} -#include "indexing_utils.h" +#include "indexing_utils_u16.h" layout(std430) buffer; @@ -32,10 +32,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require - // shared memory to hold calculated positions, this would reduce register usage thus improving performance. -shared u16vec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE]; +shared ivec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE]; /* * Computes a 2D pointwise convolution of an NxN output tile. Calculating an @@ -43,13 +41,14 @@ shared u16vec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroup * size is only 1x1, making it easier to re-use loaded texels from t_kernel. 
*/ void main() { - const uvec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE; + const ivec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE; const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z; - const u16vec3 gpos = u16vec3( + const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x; + const ivec3 gpos = ivec3( gl_GlobalInvocationID.x % out_limits_scaled.x, - (gl_GlobalInvocationID.x / out_limits_scaled.x) % out_limits_scaled.y, - gl_GlobalInvocationID.x / (out_limits_scaled.x * out_limits_scaled.y)); + div_by_x % out_limits_scaled.y, + div_by_x / out_limits_scaled.y); // Output position for TILE_SIZE = 2 // +--------+--------+ @@ -57,10 +56,10 @@ void main() { // +--------+--------+ // | pos[2] | pos[3] | // +--------+--------+ - u16vec2 pos[TILE_SIZE * TILE_SIZE]; + ivec2 pos[TILE_SIZE * TILE_SIZE]; for (int y = 0, i = 0; y < TILE_SIZE; ++y) { for (int x = 0; x < TILE_SIZE; ++x) { - pos[i] = u16vec2( + pos[i] = ivec2( gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y); pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i]; i++; @@ -69,39 +68,38 @@ void main() { // If the top left position is out of bounds, then this invocation will have // no work to do. - if (any(greaterThanEqual(u16vec3(pos[0], gpos.z), out_limits))) { + if (any(greaterThanEqual(ivec3(pos[0], gpos.z), out_limits))) { return; } // Compute the index of the input texture that needs to be loaded for each // output position. Note that negative indices can be produced indicating that // the top-left element is in a region added by padding. - u16vec2 ipos[TILE_SIZE * TILE_SIZE]; + ivec2 ipos[TILE_SIZE * TILE_SIZE]; for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) { - ipos[i] = pos[i] * u16vec2(stride) - u16vec2(padding); + ipos[i] = pos[i] * stride - padding; } vec4 sum[TILE_SIZE * TILE_SIZE]; - sum[0] = texelFetch(t_bias, u16vec2(gpos.z, 0), 0); + sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0); for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) { sum[i] = sum[0]; } int z4 = 0; // Since the kernel is 1x1, we only have to loop over the depth dimension. - for (uint16_t z = uint16_t(0); z < uint16_t(in_group_size); z += uint16_t(4), ++z4) { + for (int z = 0; z < in_group_size; z += 4, ++z4) { // During prepacking, the weight tensor has been permuted so that the // channel (IC) dim is along the x-axis, and the batch (OC) dim is along // the z-axis. - const vec4 ktex_0 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(0, 0)); - const vec4 ktex_1 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(1, 0)); - const vec4 ktex_2 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(2, 0)); - const vec4 ktex_3 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(3, 0)); - + const vec4 ktex_0 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(0, 0)); + const vec4 ktex_1 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(1, 0)); + const vec4 ktex_2 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(2, 0)); + const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(3, 0)); #pragma unroll for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) { - const vec4 in_tex = texelFetch(t_in, u16vec3(ipos[i], z4), 0); + const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0); // For 2x2 tile size algorithm works as follows. 
// To explain the calculations below, the contents of one in_tex and the // group of 4 texels loaded from t_kernel are shown: @@ -143,9 +141,9 @@ void main() { } for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) { - const u16vec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex]; - if (all(lessThan(u16vec3(pos, gpos.z), out_limits))) { - imageStore(t_out, u16vec3(pos, gpos.z), op(sum[i], out_min, out_max)); + const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex]; + if (all(lessThan(ivec3(pos, gpos.z), out_limits))) { + imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max)); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils_u16.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils_u16.h new file mode 100644 index 0000000000..6dc59b6303 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils_u16.h @@ -0,0 +1,19 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef INDEXING_UTILS_U16_H +#define INDEXING_UTILS_U16_H + +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +u16vec3 idx_to_u16pos_x_wise(uint idx, int size_x, int size_y) { + const uint div_by_x = idx / size_x; + return u16vec3(idx % size_x, div_by_x % size_y, div_by_x / size_y); +} + +#endif // INDEXING_UTILS_U16_H diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl index 962282348d..9750507a18 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl @@ -16,7 +16,8 @@ ${define_active_storage_type(STORAGE)} ${define_required_extensions(DTYPE)} -${define_required_extensions("int8")} +$if STORAGE == "buffer": + ${define_required_extensions("int8")} #include "indexing_utils.h" diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl index dae2f7e3ab..b8d7622f94 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl @@ -16,7 +16,8 @@ ${define_active_storage_type(STORAGE)} ${define_required_extensions(DTYPE)} -${define_required_extensions("int8")} +$if STORAGE == "buffer": + ${define_required_extensions("int8")} $if BATCH_MODE: diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 3519635ac7..64c145fb7e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -299,7 +299,7 @@ utils::uvec3 create_conv2d_global_wg_size( } else if (method == Conv2dMethod::Depthwise) { const utils::uvec3 image_extents = graph.logical_limits_of(out); return { - utils::div_up(image_extents[0u], 1u), + utils::div_up(image_extents[0u], 4u), utils::div_up(image_extents[1u], 2u), image_extents[2u]}; } else { diff --git a/backends/vulkan/runtime/vk_api/Adapter.cpp b/backends/vulkan/runtime/vk_api/Adapter.cpp index 5805d476a3..ec30650ba0 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.cpp +++ b/backends/vulkan/runtime/vk_api/Adapter.cpp @@ -256,6 +256,9 @@ std::string Adapter::stringize() const { ss << " deviceType: " << device_type << std::endl; ss << " deviceName: " << 
properties.deviceName << std::endl; +#define PRINT_BOOL(value, name) \ + ss << " " << std::left << std::setw(36) << #name << value << std::endl; + #define PRINT_PROP(struct, name) \ ss << " " << std::left << std::setw(36) << #name << struct.name \ << std::endl; @@ -298,12 +301,13 @@ std::string Adapter::stringize() const { ss << " }" << std::endl; #endif /* VK_KHR_8bit_storage */ -#ifdef VK_KHR_shader_float16_int8 ss << " Shader 16bit and 8bit Features {" << std::endl; + PRINT_BOOL(physical_device_.supports_int16_shader_types, shaderInt16) +#ifdef VK_KHR_shader_float16_int8 PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16); PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8); - ss << " }" << std::endl; #endif /* VK_KHR_shader_float16_int8 */ + ss << " }" << std::endl; const VkPhysicalDeviceMemoryProperties& mem_props = physical_device_.memory_properties; diff --git a/backends/vulkan/runtime/vk_api/Exception.cpp b/backends/vulkan/runtime/vk_api/Exception.cpp index e330c1c079..d26fbd8cb2 100644 --- a/backends/vulkan/runtime/vk_api/Exception.cpp +++ b/backends/vulkan/runtime/vk_api/Exception.cpp @@ -77,5 +77,36 @@ Error::Error(SourceLocation source_location, const char* cond, std::string msg) what_ = oss.str(); } +// +// ShaderNotSupportedError +// + +std::ostream& operator<<(std::ostream& out, const VulkanExtension result) { + switch (result) { + case VulkanExtension::SHADER_INT16: + out << "shaderInt16"; + break; + case VulkanExtension::INT16_STORAGE: + out << "VK_KHR_16bit_storage"; + break; + case VulkanExtension::INT8_STORAGE: + out << "VK_KHR_8bit_storage"; + break; + } + return out; +} + +ShaderNotSupportedError::ShaderNotSupportedError( + std::string shader_name, + VulkanExtension extension) + : shader_name_(std::move(shader_name)), extension_{extension} { + std::ostringstream oss; + oss << "Shader " << shader_name_ << " "; + oss << "not compatible with device. 
"; + oss << "Missing support for extension or physical device feature: "; + oss << extension_; + what_ = oss.str(); +} + } // namespace vkapi } // namespace vkcompute diff --git a/backends/vulkan/runtime/vk_api/Exception.h b/backends/vulkan/runtime/vk_api/Exception.h index ec2f2956a8..a65afb1bcc 100644 --- a/backends/vulkan/runtime/vk_api/Exception.h +++ b/backends/vulkan/runtime/vk_api/Exception.h @@ -78,5 +78,26 @@ class Error : public std::exception { } }; +enum class VulkanExtension : uint8_t { + SHADER_INT16, + INT16_STORAGE, + INT8_STORAGE, +}; + +class ShaderNotSupportedError : public std::exception { + public: + ShaderNotSupportedError(std::string shader_name, VulkanExtension extension); + + private: + std::string shader_name_; + VulkanExtension extension_; + std::string what_; + + public: + const char* what() const noexcept override { + return what_.c_str(); + } +}; + } // namespace vkapi } // namespace vkcompute diff --git a/backends/vulkan/runtime/vk_api/Shader.cpp b/backends/vulkan/runtime/vk_api/Shader.cpp index 29774e2f40..e560f37868 100644 --- a/backends/vulkan/runtime/vk_api/Shader.cpp +++ b/backends/vulkan/runtime/vk_api/Shader.cpp @@ -28,14 +28,20 @@ ShaderInfo::ShaderInfo( const uint32_t* const spirv_bin, const uint32_t size, std::vector layout, - const utils::uvec3 tile_size) + const utils::uvec3 tile_size, + const bool requires_shader_int16_ext, + const bool requires_16bit_storage_ext, + const bool requires_8bit_storage_ext) : src_code{ spirv_bin, size, }, kernel_name{std::move(name)}, kernel_layout{std::move(layout)}, - out_tile_size(tile_size) { + out_tile_size(tile_size), + requires_shader_int16(requires_shader_int16_ext), + requires_16bit_storage(requires_16bit_storage_ext), + requires_8bit_storage(requires_8bit_storage_ext) { } bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) { diff --git a/backends/vulkan/runtime/vk_api/Shader.h b/backends/vulkan/runtime/vk_api/Shader.h index 1e3b2a799f..d9fec65feb 100644 --- a/backends/vulkan/runtime/vk_api/Shader.h +++ b/backends/vulkan/runtime/vk_api/Shader.h @@ -62,6 +62,9 @@ struct ShaderInfo final { // Shader Metadata utils::uvec3 out_tile_size{1u, 1u, 1u}; + bool requires_shader_int16 = false; + bool requires_16bit_storage = false; + bool requires_8bit_storage = false; explicit ShaderInfo(); @@ -70,7 +73,10 @@ struct ShaderInfo final { const uint32_t*, const uint32_t, std::vector, - const utils::uvec3 tile_size); + const utils::uvec3 tile_size, + const bool requires_shader_int16_ext, + const bool requires_16bit_storage_ext, + const bool requires_8bit_storage_ext); operator bool() const { return src_code.bin != nullptr; diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py index 3d9aa6aa80..d7e3896945 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py @@ -45,8 +45,13 @@ class GeneratedOpsTest_{op_name} : public ::testing::Test {{ test_suite_template = """ TEST_P(GeneratedOpsTest_{op_name}, {case_name}) {{ {create_ref_data} +try {{ {create_and_check_out} }} +catch (const vkcompute::vkapi::ShaderNotSupportedError& e) {{ + GTEST_SKIP() << e.what(); +}} +}} """ diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index 1b6f03512b..d6e2f12884 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+# Copyright 2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -679,6 +680,9 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): for i in range(len(model_output)): model = model_output[i] ref = ref_output[i] + assert ( + ref.shape == model.shape + ), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}" assert torch.allclose( model, ref, diff --git a/build/Utils.cmake b/build/Utils.cmake index 7641e7c7dc..47bc46b18d 100644 --- a/build/Utils.cmake +++ b/build/Utils.cmake @@ -46,12 +46,20 @@ function(executorch_print_configuration_summary) message(STATUS " EXECUTORCH_BUILD_ARM_BAREMETAL : " "${EXECUTORCH_BUILD_ARM_BAREMETAL}" ) + message(STATUS " EXECUTORCH_BUILD_CADENCE : " + "${EXECUTORCH_BUILD_CADENCE}" + ) message( STATUS " EXECUTORCH_BUILD_COREML : ${EXECUTORCH_BUILD_COREML}" ) - message(STATUS " EXECUTORCH_BUILD_KERNELS_CUSTOM : " - "${EXECUTORCH_BUILD_KERNELS_CUSTOM}" + message( + STATUS + " EXECUTORCH_BUILD_CPUINFO : ${EXECUTORCH_BUILD_CPUINFO}" + ) + message( + STATUS + " EXECUTORCH_BUILD_DEVTOOLS : ${EXECUTORCH_BUILD_DEVTOOLS}" ) message(STATUS " EXECUTORCH_BUILD_EXECUTOR_RUNNER : " "${EXECUTORCH_BUILD_EXECUTOR_RUNNER}" @@ -68,7 +76,7 @@ function(executorch_print_configuration_summary) message(STATUS " EXECUTORCH_BUILD_EXTENSION_TENSOR : " "${EXECUTORCH_BUILD_EXTENSION_TENSOR}" ) - message(STATUS " EXECUTORCH_BUILD_EXTENSION_TRAINING : " + message(STATUS " EXECUTORCH_BUILD_EXTENSION_TRAINING : " "${EXECUTORCH_BUILD_EXTENSION_TRAINING}" ) message( @@ -79,22 +87,14 @@ function(executorch_print_configuration_summary) STATUS " EXECUTORCH_BUILD_GFLAGS : ${EXECUTORCH_BUILD_GFLAGS}" ) - message( - STATUS - " EXECUTORCH_BUILD_GTESTS : ${EXECUTORCH_BUILD_GTESTS}" - ) message(STATUS " EXECUTORCH_BUILD_HOST_TARGETS : " "${EXECUTORCH_BUILD_HOST_TARGETS}" ) - message( - STATUS " EXECUTORCH_BUILD_MPS : ${EXECUTORCH_BUILD_MPS}" - ) - message( - STATUS - " EXECUTORCH_BUILD_PYBIND : ${EXECUTORCH_BUILD_PYBIND}" + message(STATUS " EXECUTORCH_BUILD_KERNELS_CUSTOM : " + "${EXECUTORCH_BUILD_KERNELS_CUSTOM}" ) - message( - STATUS " EXECUTORCH_BUILD_QNN : ${EXECUTORCH_BUILD_QNN}" + message(STATUS " EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT : " + "${EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT}" ) message(STATUS " EXECUTORCH_BUILD_KERNELS_OPTIMIZED : " "${EXECUTORCH_BUILD_KERNELS_OPTIMIZED}" @@ -103,27 +103,38 @@ function(executorch_print_configuration_summary) "${EXECUTORCH_BUILD_KERNELS_QUANTIZED}" ) message( - STATUS " EXECUTORCH_BUILD_DEVTOOLS : ${EXECUTORCH_BUILD_DEVTOOLS}" + STATUS " EXECUTORCH_BUILD_MPS : ${EXECUTORCH_BUILD_MPS}" ) message( STATUS - " EXECUTORCH_BUILD_SIZE_TEST : ${EXECUTORCH_BUILD_SIZE_TEST}" + " EXECUTORCH_BUILD_NEURON : ${EXECUTORCH_BUILD_NEURON}" ) message( STATUS - " EXECUTORCH_BUILD_XNNPACK : ${EXECUTORCH_BUILD_XNNPACK}" + " EXECUTORCH_BUILD_PTHREADPOOL : ${EXECUTORCH_BUILD_PTHREADPOOL}" ) message( STATUS - " EXECUTORCH_BUILD_VULKAN : ${EXECUTORCH_BUILD_VULKAN}" + " EXECUTORCH_BUILD_PYBIND : ${EXECUTORCH_BUILD_PYBIND}" + ) + message( + STATUS " EXECUTORCH_BUILD_QNN : ${EXECUTORCH_BUILD_QNN}" ) message( STATUS - " EXECUTORCH_BUILD_PTHREADPOOL : ${EXECUTORCH_BUILD_PTHREADPOOL}" + " EXECUTORCH_BUILD_SIZE_TEST : ${EXECUTORCH_BUILD_SIZE_TEST}" ) message( STATUS - " EXECUTORCH_BUILD_CPUINFO : ${EXECUTORCH_BUILD_CPUINFO}" + " EXECUTORCH_BUILD_TESTS : ${EXECUTORCH_BUILD_TESTS}" + ) + message( + STATUS + " EXECUTORCH_BUILD_VULKAN : 
${EXECUTORCH_BUILD_VULKAN}" + ) + message( + STATUS + " EXECUTORCH_BUILD_XNNPACK : ${EXECUTORCH_BUILD_XNNPACK}" ) endfunction() @@ -193,7 +204,10 @@ function(extract_sources sources_file) elseif("${ANDROID_ABI}" STREQUAL "x86_64") set(target_platforms_arg "--target-platforms=shim//:android-x86_64") else() - message(FATAL_ERROR "Unsupported ANDROID_ABI setting ${ANDROID_ABI}. Please add it here!") + message( + FATAL_ERROR + "Unsupported ANDROID_ABI setting ${ANDROID_ABI}. Please add it here!" + ) endif() endif() execute_process( @@ -269,9 +283,12 @@ function(resolve_buck2) endif() endif() - # Update the var in the parent scope. Note that this does not modify our - # local $BUCK2 value. - set(BUCK2 "${buck2}" PARENT_SCOPE) + # Update the var in the parent scope. Note that this does not modify our local + # $BUCK2 value. + set(BUCK2 + "${buck2}" + PARENT_SCOPE + ) # The buck2 daemon can get stuck. Killing it can help. message(STATUS "Killing buck2 daemon") @@ -279,8 +296,7 @@ function(resolve_buck2) # Note that we need to use the local buck2 variable. BUCK2 is only set in # the parent scope, and can still be empty in this scope. COMMAND "${buck2} killall" - WORKING_DIRECTORY ${executorch_root} - COMMAND_ECHO STDOUT + WORKING_DIRECTORY ${executorch_root} COMMAND_ECHO STDOUT ) endfunction() diff --git a/build/resolve_buck.py b/build/resolve_buck.py index d70fe94b41..725a326ea6 100644 --- a/build/resolve_buck.py +++ b/build/resolve_buck.py @@ -11,7 +11,6 @@ import platform import stat import sys -import tempfile import urllib.request from dataclasses import dataclass @@ -68,22 +67,33 @@ class BuckInfo: archive_name="buck2-x86_64-unknown-linux-musl.zst", target_versions=[ # MUSL - "3bbde7daa94987db468d021ad625bc93dc62ba7fcb16945cb09b64aab077f284", + "edae27cfca00053d9c5f7c7be81b6b0d7d07573a50be374ce53a9d8692afa5fc", # GNU - "029b0bcc6f8e399185c1d0f574eba204934722b5", + "10334cb20cb7c321", ], ), ("linux", "aarch64"): BuckInfo( archive_name="buck2-aarch64-unknown-linux-gnu.zst", - target_versions=["49670bee56a7d8a7696409ca6fbf7551d2469787"], + target_versions=[ + # MUSL + "5d7af382acbe0dde70f0e9b0a0bc36deea906077ec1ffe80d3fa280490109051", + # GNU + "08d4382de22fab275978abc7c27c001d7823eb2f", + ], ), ("darwin", "aarch64"): BuckInfo( archive_name="buck2-aarch64-apple-darwin.zst", - target_versions=["99773fe6f7963a72ae5f7b737c02836e"], + target_versions=["f3b7a37732803ed090cd8a37f00cc000"], ), ("darwin", "x86_64"): BuckInfo( archive_name="buck2-x86_64-apple-darwin.zst", - target_versions=["3eb1ae97ea963086866b4d2d9ffa966d"], + target_versions=["9c9a583658d43e82b41f3fc9d369a9b0"], + ), + ("windows", "x86_64"): BuckInfo( + archive_name="buck2-x86_64-pc-windows-msvc.exe.zst", + target_versions=[ + "c7d378f3f307e9590f0b29a5f7f1b21b8e784f4e4bd30a0160b2a69df50d2ee0" + ], ), } @@ -135,6 +145,8 @@ def resolve_buck2(args: argparse.Namespace) -> Union[str, int]: os_family = "linux" elif sys.platform.startswith("darwin"): os_family = "darwin" + elif sys.platform.startswith("win"): + os_family = "windows" platform_key = (os_family, arch) if platform_key not in BUCK_PLATFORM_MAP: @@ -193,12 +205,12 @@ def resolve_buck2(args: argparse.Namespace) -> Union[str, int]: buck2_archive_url = f"https://github.com/facebook/buck2/releases/download/{target_buck_version}/{buck_info.archive_name}" - with tempfile.NamedTemporaryFile() as archive_file: + try: print(f"Downloading buck2 from {buck2_archive_url}...", file=sys.stderr) - urllib.request.urlretrieve(buck2_archive_url, archive_file.name) + archive_file, _ = 
urllib.request.urlretrieve(buck2_archive_url) # Extract and chmod. - with open(archive_file.name, "rb") as f: + with open(archive_file, "rb") as f: data = f.read() decompressed_bytes = zstd.decompress(data) @@ -207,6 +219,8 @@ def resolve_buck2(args: argparse.Namespace) -> Union[str, int]: file_stat = os.stat(buck2_local_path) os.chmod(buck2_local_path, file_stat.st_mode | stat.S_IEXEC) + finally: + os.remove(archive_file) return buck2_local_path diff --git a/examples/arm/executor_runner/arm_executor_runner.cpp b/examples/arm/executor_runner/arm_executor_runner.cpp index e9c758a60b..2d08f733eb 100644 --- a/examples/arm/executor_runner/arm_executor_runner.cpp +++ b/examples/arm/executor_runner/arm_executor_runner.cpp @@ -295,6 +295,9 @@ Result prepare_input_tensors( case ScalarType::Float: t.mutable_data_ptr()[j] = 1.; break; + case ScalarType::Char: + t.mutable_data_ptr()[j] = 1; + break; } } } @@ -584,12 +587,18 @@ int main(int argc, const char* argv[]) { i, j, outputs[i].toTensor().const_data_ptr()[j]); - } else { + } else if (t.scalar_type() == ScalarType::Float) { printf( "Output[%d][%d]: %f\n", i, j, outputs[i].toTensor().const_data_ptr()[j]); + } else if (t.scalar_type() == ScalarType::Char) { + printf( + "Output[%d][%d]: %d\n", + i, + j, + outputs[i].toTensor().const_data_ptr()[j]); } } #if defined(ET_EVENT_TRACER_ENABLED) diff --git a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md index 93ed098fa3..1e03993c94 100644 --- a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md +++ b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md @@ -141,7 +141,7 @@ Note: You should only use this step if the prebuilt package doesn't work for you If you need to manually build the package, run the following command in your terminal ``` # Install a compatible version of Buck2 -BUCK2_RELEASE_DATE="2024-05-15" +BUCK2_RELEASE_DATE="2024-12-16" BUCK2_ARCHIVE="buck2-aarch64-apple-darwin.zst" BUCK2=".venv/bin/buck2" diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 813a24a388..28cc502538 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -23,6 +23,7 @@ import torch from executorch.devtools.etrecord import generate_etrecord +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass from executorch.extension.llm.export.builder import DType, LLMEdgeManager @@ -778,6 +779,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 for partitioner in partitioners: logging.info(f"--> {partitioner.__class__.__name__}") + additional_passes = [] + if args.model in TORCHTUNE_DEFINED_MODELS: + additional_passes = [InitializedMutableBufferPass(["cache_pos"])] if args.generate_etrecord: if not builder_exported_to_edge.edge_manager: raise ValueError("Unable to generate etrecord due to missing edge manager.") @@ -792,7 +796,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. canonicalize_program(builder.edge_manager.exported_program()) - builder = builder.to_executorch() + builder = builder.to_executorch( + passes=additional_passes, + ) # Generate ETRecord if edge_manager_copy: @@ -810,7 +816,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. 
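The resolve_buck.py hunk above drops the NamedTemporaryFile in favor of letting urllib.request.urlretrieve pick a temporary path, then removes that file once the archive has been consumed. A self-contained sketch of that download-then-clean-up pattern (the URL and helper name are illustrative, not part of the script):

import os
import urllib.request

def fetch_and_read(url):
    # With no destination argument, urlretrieve downloads to a temporary
    # file and returns its path.
    archive_path, _ = urllib.request.urlretrieve(url)
    try:
        with open(archive_path, "rb") as f:
            return f.read()
    finally:
        # Always remove the temporary download, even if reading fails.
        os.remove(archive_path)

# data = fetch_and_read("https://example.com/some-archive.zst")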
canonicalize_program(builder.edge_manager.exported_program()) - builder = builder.to_executorch() + builder = builder.to_executorch(passes=additional_passes) if args.profile_memory: generate_memory_trace(builder.export_program, "memory_profile.json") diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index 176b597a94..3410a2af88 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -309,7 +309,6 @@ def forward( seqlen, mask: torch.Tensor, ) -> torch.Tensor: - # TODO(kimishpatel): Move this slicing logic to Attention block so that # SDPA does not have to take input_pos as arg if self.enable_dynamic_shape: diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py index 21952d8c21..57c36dabf9 100644 --- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py @@ -11,6 +11,7 @@ from executorch.examples.models.llama.llama_transformer import KVCache from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( + CustomKVCache, QuantizedCacheType, QuantizedKVCache, ) @@ -19,7 +20,6 @@ class SDPAWithQuantizedKVCacheTest(unittest.TestCase): - def _init_cache(self): self.kv_cache = KVCache( self.max_batch_size, @@ -33,6 +33,19 @@ def _init_cache(self): self.quantized_kv_cache = QuantizedKVCache.from_float( self.kv_cache, QuantizedCacheType.AffineAsymmetric ) + # Need this because first test actually has seq_len > 1 + # and vanilla kvcache cannot handle seq_len > 1, due to + # how input_pos encoding works in the current stack. 
+ # This needs fixing by making sure rest of the stack including + # custom ops or other backends can work with input_pos + # as a sequence of token positions + self.custom_kv_cache = CustomKVCache( + self.max_batch_size, + self.max_seq_len, + self.n_kv_heads, + self.head_dim, + dtype=self.dtype, + ) def _init_kv(self): kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim) @@ -59,7 +72,7 @@ def test_simple(self, is_dynamic_shape=False): self.seq_len = 3 self._init_cache() q, k, v = self._init_kv() - self.float_sdpa = SDPACustom(self.kv_cache, self.dim) + self.float_sdpa = SDPACustom(self.custom_kv_cache, self.dim) self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim) float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None) quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None) diff --git a/examples/models/llama/tests/test_simple_sdpa.py b/examples/models/llama/tests/test_simple_sdpa.py index 6e0c391960..2ee8dab8b9 100644 --- a/examples/models/llama/tests/test_simple_sdpa.py +++ b/examples/models/llama/tests/test_simple_sdpa.py @@ -29,7 +29,6 @@ def test_simple_sdpa(self): max_seq_length=max_seq_length, n_heads=n_heads, head_dim=head_dim, - transpose_cache=True, enable_dynamic_shape=False, ) sdpa = SDPA( diff --git a/examples/models/llama3_2_vision/runner/native.py b/examples/models/llama3_2_vision/runner/native.py index 9a28c94f9c..2b4d709f9b 100644 --- a/examples/models/llama3_2_vision/runner/native.py +++ b/examples/models/llama3_2_vision/runner/native.py @@ -18,13 +18,15 @@ TorchTuneLlamaRunner, ) -from executorch.extension.pybindings.portable_lib import _load_for_executorch +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch_from_buffer, +) # Load custom ops and quantized ops. from executorch.extension.pybindings import portable_lib # noqa # usort: skip # Note: import this after portable_lib -from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip +from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip from executorch.kernels import quantized # noqa @@ -43,7 +45,17 @@ def __init__(self, args): use_kv_cache=args.kv_cache, vocab_size=params["vocab_size"], ) - self.model = _load_for_executorch(args.pte) + # Save the loaded model bytes to prevent data from going out of + # scope after the `with` and getting cleaned up by Python's + # garbage collector. + self.model_bytes = None + with open(args.pte, "rb") as f: + self.model_bytes = f.read() + # Need to use _load_for_executorch_from_buffer instead of + # _load_for_executorch because the latter uses MmapDataLoader, + # which doesn't have load_into() implemented, which is needed + # for loading initialized mutable buffers. 
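A minimal sketch of the buffer-based loading shown above, assuming the portable_lib pybindings are available; the raw bytes must stay referenced for as long as the loaded module is used:

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)

def load_pte(pte_path: str):
    # Read the whole .pte into memory and keep the bytes alive alongside the
    # loaded module, since the module reads directly from this buffer.
    with open(pte_path, "rb") as f:
        model_bytes = f.read()
    module = _load_for_executorch_from_buffer(model_bytes)
    return module, model_bytes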
+ self.model = _load_for_executorch_from_buffer(self.model_bytes) self.use_kv_cache = args.kv_cache def forward( diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index dabb07e61c..c4834aee25 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -20,6 +20,9 @@ EmbeddingQuantHandler, get_quant_weight_transform, ) +from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( + replace_kv_cache_with_custom_kv_cache, +) from executorch.examples.models.llama.source_transformation.sdpa import ( replace_sdpa_with_custom_op, ) @@ -101,6 +104,7 @@ def forward(self, input_pos, embeddings): _, quantizers, _ = get_quantizer_and_quant_params(args) source_transforms = [] if llava.use_sdpa_with_kv_cache_op: + source_transforms.append(replace_kv_cache_with_custom_kv_cache) source_transforms.append(replace_sdpa_with_custom_op) source_transforms.append(quant_transform) manager = ( diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py index a24249d953..68a9e59e0c 100644 --- a/examples/models/llava/model.py +++ b/examples/models/llava/model.py @@ -14,6 +14,9 @@ import torch from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer +from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( + replace_kv_cache_with_custom_kv_cache, +) from executorch.examples.models.llama.source_transformation.sdpa import ( replace_sdpa_with_custom_op, ) @@ -62,6 +65,7 @@ def __init__( self.text_model = Transformer(self.text_model_args) # use custom op for SDPA. if use_sdpa_with_kv_cache_op: + self.text_model = replace_kv_cache_with_custom_kv_cache(self.text_model) self.text_model = replace_sdpa_with_custom_op(self.text_model) # load state dict self.text_model.load_state_dict( diff --git a/examples/qualcomm/oss_scripts/llama2/model/static_llama.py b/examples/qualcomm/oss_scripts/llama2/model/static_llama.py index 6894537b82..d1b618ed07 100755 --- a/examples/qualcomm/oss_scripts/llama2/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama2/model/static_llama.py @@ -11,9 +11,8 @@ import torch import torch.nn as nn - +import torch.nn.functional as F from executorch.examples.models.llama.llama_transformer import ( - FeedForward, ModelArgs, precompute_freqs_cis, ) @@ -58,37 +57,44 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False): def prepare_sha(self): self.wq_sha = nn.ModuleList( [ - nn.Linear(self.dim, self.head_dim, bias=False) + nn.Conv2d(self.dim, self.head_dim, 1, bias=False) for _ in range(self.n_heads) ] ) self.wk_sha = nn.ModuleList( [ - nn.Linear(self.dim, self.head_dim, bias=False) + nn.Conv2d(self.dim, self.head_dim, 1, bias=False) for _ in range(self.n_kv_heads) ] ) self.wv_sha = nn.ModuleList( [ - nn.Linear(self.dim, self.head_dim, bias=False) + nn.Conv2d(self.dim, self.head_dim, 1, bias=False) for _ in range(self.n_kv_heads) ] ) + self.wo_sha = nn.Conv2d(self.n_heads * self.head_dim, self.dim, 1, bias=False) self.forward_mha = self.forward self.forward = self.forward_sha - for i in range(self.n_heads): self.wq_sha[i].weight.data.copy_( - self.wq.weight[i * self.head_dim : (i + 1) * self.head_dim] + self.wq.weight[ + i * self.head_dim : (i + 1) * self.head_dim, :, None, None + ] ) for i in range(self.n_kv_heads): self.wk_sha[i].weight.data.copy_( - self.wk.weight[i * self.head_dim : (i + 1) * self.head_dim] + self.wk.weight[ + i * self.head_dim : (i + 1) * self.head_dim, :, None, None + ] ) 
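The weight copies above rely on a pointwise (1x1) Conv2d being numerically identical to a Linear layer once the weight gains two trailing singleton dimensions. A standalone check of that equivalence, with arbitrary sizes chosen only for illustration:

import torch
import torch.nn as nn

dim, head_dim = 16, 4
lin = nn.Linear(dim, head_dim, bias=False)
conv = nn.Conv2d(dim, head_dim, 1, bias=False)
# A 1x1 conv kernel is the linear weight with two trailing singleton dims.
conv.weight.data.copy_(lin.weight[:, :, None, None])

x = torch.randn(2, 7, dim)                        # (batch, seq_len, dim)
x_conv = x.reshape(2, 7, 1, dim).transpose(1, 3)  # (batch, dim, 1, seq_len)
y_conv = conv(x_conv).transpose(1, 3).reshape(2, 7, head_dim)
assert torch.allclose(lin(x), y_conv, atol=1e-5)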
self.wv_sha[i].weight.data.copy_( - self.wv.weight[i * self.head_dim : (i + 1) * self.head_dim] + self.wv.weight[ + i * self.head_dim : (i + 1) * self.head_dim, :, None, None + ] ) + self.wo_sha.weight.data.copy_(self.wo.weight[:, :, None, None]) def forward_sha( self, @@ -99,9 +105,22 @@ def forward_sha( k_caches: Optional[List[torch.Tensor]] = None, v_caches: Optional[List[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q = [wq_sha(hidden_states) for wq_sha in self.wq_sha] - k = [wk_sha(hidden_states) for wk_sha in self.wk_sha] - v = [wv_sha(hidden_states) for wv_sha in self.wv_sha] + bsz, seq_len, _ = hidden_states.shape + hidden_states = torch.reshape( + hidden_states, (bsz, seq_len, 1, self.dim) + ).transpose(1, 3) + q = [ + wq_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2) + for wq_sha in self.wq_sha + ] + k = [ + wk_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2) + for wk_sha in self.wk_sha + ] + v = [ + wv_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2) + for wv_sha in self.wv_sha + ] for i in range(len(q)): q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin) for i in range(len(k)): @@ -129,7 +148,11 @@ def forward_sha( output_y.append(y) y = torch.concat(output_y, dim=-1) - y = self.wo(y) + y = y.reshape(bsz, seq_len, 1, -1) + y = y.transpose(1, 3) + y = self.wo_sha(y) + y = y.transpose(1, 3) + y = y.reshape(bsz, seq_len, -1) if self.output_new_cache_only: if k_caches and v_caches: @@ -148,12 +171,12 @@ def forward( k_caches: List[torch.Tensor], v_caches: List[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz, seqlen, _ = hidden_states.shape + bsz, seq_len, _ = hidden_states.shape q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states) - q = q.view(bsz, seqlen, self.n_heads, self.head_dim) - k = k.view(bsz, seqlen, self.n_kv_heads, self.head_dim) - v = v.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + q = q.view(bsz, seq_len, self.n_heads, self.head_dim) + k = k.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + v = v.view(bsz, seq_len, self.n_kv_heads, self.head_dim) q = apply_rotary_emb_single(q, freqs_cos, freqs_sin) k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1) @@ -203,6 +226,45 @@ def forward( return y, output_kh, output_vh +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + assert args.hidden_dim is not None + self.hidden_dim: int = args.hidden_dim + self.dim: int = args.dim + self.w1 = nn.Linear(self.dim, self.hidden_dim, bias=False) + self.w2 = nn.Linear(self.hidden_dim, self.dim, bias=False) + self.w3 = nn.Linear(self.dim, self.hidden_dim, bias=False) + + def prepare_feedfoward_conv(self): + self.w1_conv = nn.Conv2d(self.dim, self.hidden_dim, 1, bias=False) + self.w2_conv = nn.Conv2d(self.hidden_dim, self.dim, 1, bias=False) + self.w3_conv = nn.Conv2d(self.dim, self.hidden_dim, 1, bias=False) + + self.forward_no_conv = self.forward + self.forward = self.forward_feedfoward_conv + + self.w1_conv.weight.data.copy_(self.w1.weight[:, :, None, None]) + self.w2_conv.weight.data.copy_(self.w2.weight[:, :, None, None]) + self.w3_conv.weight.data.copy_(self.w3.weight[:, :, None, None]) + + del self.w1 + del self.w2 + del self.w3 + + def forward_feedfoward_conv(self, x): + bsz, _, _ = x.size() + x = torch.reshape(x, (bsz, -1, self.dim, 1)) + x = x.transpose(1, 2) # Transpose right before and after Conv + x = self.w2_conv(F.silu(self.w1_conv(x)) * 
self.w3_conv(x)) + x = x.transpose(1, 2) + x = torch.reshape(x, (bsz, -1, self.dim)) + return x + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + class LlamaDecoderLayer(nn.Module): def __init__(self, config: ModelArgs, output_new_cache_only=False): super().__init__() @@ -268,6 +330,22 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True): self.register_buffer("freqs_cos", freqs_cos, persistent=False) self.register_buffer("freqs_sin", freqs_sin, persistent=False) + def prepare_output_conv(self): + def forward_output_conv(x): + bsz, _, _ = x.size() + x = torch.reshape(x, (bsz, -1, 1, self.dim)) + x = x.transpose(1, 3) # Transpose right before and after Conv + x = self.output_conv(x) + x = x.transpose(1, 3) + x = torch.reshape(x, (bsz, -1, self.vocab_size)) + return x + + self.output_conv = nn.Conv2d(self.dim, self.vocab_size, 1, bias=False) + self.output_conv.weight.data.copy_(self.output.weight[:, :, None, None]) + + del self.output + self.output = forward_output_conv + def forward( self, tokens: torch.Tensor, diff --git a/examples/qualcomm/oss_scripts/llama3_2/llama.py b/examples/qualcomm/oss_scripts/llama3_2/llama.py index 13fb99a420..a18690e941 100755 --- a/examples/qualcomm/oss_scripts/llama3_2/llama.py +++ b/examples/qualcomm/oss_scripts/llama3_2/llama.py @@ -18,7 +18,6 @@ from multiprocessing.connection import Client import torch -from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner @@ -32,10 +31,12 @@ from executorch.backends.qualcomm.utils.utils import ( capture_program, convert_linear_to_conv2d, + generate_composite_llama_program, generate_htp_compiler_spec, generate_multi_graph_program, generate_qnn_executorch_compiler_spec, get_soc_to_chipset_map, + update_spill_fill_size, ) from executorch.examples.qualcomm.oss_scripts.llama2.model.static_llama import ( LlamaModel, @@ -78,7 +79,9 @@ def _kv_calibrate( # TODO: change criteria & support batch inputs if necessary pos = torch.tensor(0, dtype=torch.int32) max_cache_len = max_seq_len - 1 - token_list = sp_model.encode(user_prompts, bos=True, eos=False) + token_list = sp_model.encode( + user_prompts, bos=True, eos=False, allowed_special="all" + ) with torch.no_grad(): while token_list[-1] != sp_model.eos_id and pos < max_cache_len: @@ -118,28 +121,30 @@ def _prefill_calibrate( max_cache_len = max_seq_len - 1 # TODO: change criteria & support batch inputs if necessary - token_list = sp_model.encode(user_prompts, bos=True, eos=False) - token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1) - last_prompt_pos = token_list.numel() - if last_prompt_pos < max_cache_len: - token_list = torch.cat( - [ - token_list, - torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32), - ], - dim=1, - ) - else: - token_list = token_list[:, :max_cache_len] + token_list = sp_model.encode( + user_prompts, bos=True, eos=False, allowed_special="all" + ) + pos = len(token_list) with torch.no_grad(): - logits, new_k_caches, new_v_caches = module( - token_list, - atten_mask, - ) - predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()] + while token_list[-1] != sp_model.eos_id and pos < max_cache_len: + tmp_token_list = torch.tensor(token_list).reshape(1, -1) + if pos < max_cache_len: + tmp_token_list = torch.cat( + [ + tmp_token_list, + torch.zeros((1, max_cache_len - pos), dtype=torch.int32), + ], + dim=1, + ) + logits, new_k_caches, new_v_caches = module( + 
tmp_token_list, + atten_mask, + ) + token_list.append(torch.argmax(logits[:, pos - 1], dim=-1).item()) + pos += 1 - print(f"calibration data:\n{sp_model.decode(predict)}") + print(f"calibration data:\n{sp_model.decode(token_list)}") def calibrate( @@ -186,41 +191,94 @@ def __init__(self, llama_model, pte_filename) -> None: tokens, atten_mask = self.get_example_inputs(use_kv_cache=False) self.inputs = (tokens, atten_mask) - def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type, sharding_type): + def _tag_ios(self, gm: torch.fx.GraphModule, fixed_point_type): if not self.has_quant_io: return # shape of k caches and v caches - input_cache_shape = { + kv_cache_shape = { + # single head, kv mode input (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]), (self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]), + # single head, kv mode output + (self.llama_meta["get_head_dim"], 1), + (1, self.llama_meta["get_head_dim"]), + # single head, bert mode + (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"] - 1), + (self.llama_meta["get_max_seq_len"] - 1, self.llama_meta["get_head_dim"]), } + io_shape = { + # kv mode + ( + self.llama_meta["get_max_batch_size"], + 1, + self.llama_meta["get_vocab_size"], + ), + # bert mode + ( + self.llama_meta["get_max_batch_size"], + self.llama_meta["get_max_seq_len"] - 1, + self.llama_meta["get_vocab_size"], + ), + } + + atten_mask_shape = { + # kv mode + (self.llama_meta["get_max_batch_size"], self.llama_meta["get_max_seq_len"]), + # bert mode + ( + self.llama_meta["get_max_seq_len"] - 1, + self.llama_meta["get_max_seq_len"] - 1, + ), + } + + freq_shape = { + # kv mode + (1, self.llama_meta["get_head_dim"] // 2), + # bert mode + ( + self.llama_meta["get_max_seq_len"] - 1, + self.llama_meta["get_head_dim"] // 2, + ), + } + + freq_op = { + # kv mode + exir_ops.edge.aten.select.int, + # bert mode + exir_ops.edge.aten.slice_copy.Tensor, + } + for n in gm.graph.nodes: - if ( - n.op == "placeholder" - and len(users := list(n.users)) == 1 - and users[0].meta["val"].size()[-2:] in input_cache_shape - ): - n.meta[QCOM_QUANTIZED_IO] = kv_type + if n.op == "placeholder": + if ( + len(users := list(n.users)) == 1 + and users[0].meta["val"].size()[-2:] in kv_cache_shape + ): + n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["kv_type"] + elif n.meta["val"].size() in io_shape: + n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] + elif n.meta["val"].size() in atten_mask_shape: + n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] elif n.op == "output": for a in n.args[0]: - # single head, kv mode - if ( - a.meta["val"].flatten().size()[0] - == self.llama_meta["get_head_dim"] - ): - a.meta[QCOM_QUANTIZED_IO] = kv_type - # single head, prefill mode - elif a.meta["val"].flatten().size()[0] == self.llama_meta[ - "get_head_dim" - ] * (self.llama_meta["get_max_seq_len"] - 1): - a.meta[QCOM_QUANTIZED_IO] = kv_type + if a.meta["val"].size()[-2:] in kv_cache_shape: + a.meta[QCOM_QUANTIZED_IO] = fixed_point_type["kv_type"] + elif a.meta["val"].size() in io_shape: + a.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] + quant_attrs = a.meta["quant_attrs"] # Tag sharding io if exir_ops.edge.llama.fallback.default in [ u.target for u in list(n.users.keys()) ] + [n.target]: - n.meta[QCOM_QUANTIZED_IO] = sharding_type + n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] + + # Tag select op as quantized tensors for freq_sin and freq_cos. 
It is caused by sharding + if n.target in freq_op and n.meta["val"].size() in freq_shape: + n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] + + return quant_attrs def quantize(self, quant_dtype, args, custom_annotations=()): self.quant_dtype = quant_dtype @@ -254,16 +312,12 @@ def quantize(self, quant_dtype, args, custom_annotations=()): def lowering_modules( self, work_space, - kv_type=torch.uint8, - sharding_type=torch.uint16, + fixed_point_type, use_fp16=False, soc_model=QcomChipset.SM8650, num_sharding=0, ): executorch_config = ExecutorchBackendConfig( - passes=[ - BuildQuantIo(), - ], # For shared buffer, user must pass the memory address # which is allocated by RPC memory to executor runner. # Therefore, won't want to pre-allocate @@ -276,7 +330,9 @@ def lowering_modules( ) with torch.no_grad(): # backend option - backend_options = generate_htp_compiler_spec(use_fp16=use_fp16) + backend_options = generate_htp_compiler_spec( + use_fp16=use_fp16, use_multi_contexts=num_sharding > 0 + ) compiler_specs = generate_qnn_executorch_compiler_spec( soc_model=soc_model, backend_options=backend_options, @@ -297,10 +353,9 @@ def lowering_modules( shares=num_sharding, ) - self._tag_kv_ios( + self.quant_attrs = self._tag_ios( edge_prog.exported_program.graph_module, - kv_type=kv_type, - sharding_type=sharding_type, + fixed_point_type=fixed_point_type, ) edge_prog_mgr = EdgeProgramManager( edge_programs={"forward": edge_prog.exported_program}, @@ -308,13 +363,18 @@ def lowering_modules( compile_config=EdgeCompileConfig(_check_ir_validity=False), ) edge_prog_mgr = edge_prog_mgr.to_backend(partitioner) + if num_sharding > 0: + update_spill_fill_size(edge_prog_mgr.exported_program()) exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) - with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file: + with open(f"{work_space}/{pte_filename}.pte", "wb") as file: exec_prog_mgr.write_to_file(file) def get_example_inputs(self, use_kv_cache=True): return self.llama_model.get_example_inputs(use_kv_cache) + def get_quant_attrs(self): + return self.quant_attrs + def compile(args, pte_filename): os.makedirs(args.artifact, exist_ok=True) @@ -371,24 +431,25 @@ def compile(args, pte_filename): for layer in llama_instance.layers: if getattr(layer.attention, "prepare_sha", None): layer.attention.prepare_sha() - - use_fp16 = False - if args.ptq != None: - kv_type = torch.uint8 + if getattr(layer.feed_forward, "prepare_feedfoward_conv", None): + layer.feed_forward.prepare_feedfoward_conv() + + use_fp16 = True + fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32} + if args.ptq: + use_fp16 = False + fixed_point_type["kv_type"] = torch.uint8 if args.ptq == "8a8w": - sharding_type = torch.uint8 + fixed_point_type["io_type"] = torch.uint8 elif args.ptq == "16a4w": - sharding_type = torch.uint16 + fixed_point_type["io_type"] = torch.uint16 else: assert args.ptq in [ "8a8w", "16a4w", ], f"No support for quant type {args.ptq}. Support 8a8w and 16a4w." 
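The precision selection above can be read as one small decision: fp16 with float32 graph IO when no --ptq flag is given, otherwise uint8 KV caches with the IO width set by the scheme. A compact, purely illustrative restatement (helper name is not part of the script):

from typing import Optional

import torch

def pick_fixed_point_types(ptq: Optional[str]):
    # Returns (use_fp16, fixed_point_type) following the script's logic.
    if not ptq:
        return True, {"kv_type": torch.float32, "io_type": torch.float32}
    io_types = {"8a8w": torch.uint8, "16a4w": torch.uint16}
    if ptq not in io_types:
        raise ValueError(f"No support for quant type {ptq}. Support 8a8w and 16a4w.")
    return False, {"kv_type": torch.uint8, "io_type": io_types[ptq]}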
quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") - else: - use_fp16 = True - kv_type = torch.float32 - sharding_type = torch.float32 + assert args.tokenizer_model is not None, "Need tokenizer model for calibration" if args.dtype_override is not None: @@ -404,7 +465,7 @@ def compile(args, pte_filename): llama_instance_list[i].eval(), pte_filename ) - if args.ptq != None: + if args.ptq: start_quantize_ts = time.time() for llama_instance in llama_instance_list: llama_instance.quantize( @@ -421,16 +482,17 @@ def compile(args, pte_filename): logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}") start_lowering_ts = time.time() + quant_attrs = None if len(llama_instance_list) == 1: llama_instance_list[0].lowering_modules( args.artifact, - kv_type=kv_type, - sharding_type=sharding_type, + fixed_point_type, use_fp16=use_fp16, soc_model=get_soc_to_chipset_map()[args.model], num_sharding=args.num_sharding, ) + quant_attrs = llama_instance_list[0].get_quant_attrs() else: sample_inputs_list = [ llama_instace.inputs for llama_instace in llama_instance_list @@ -451,12 +513,13 @@ def compile(args, pte_filename): ) for i in range(len(llama_instance_list)): - llama_instance_list[i]._tag_kv_ios( + quant_attrs = llama_instance_list[i]._tag_ios( edge_progs[i].exported_program.graph_module, - kv_type=kv_type, - sharding_type=sharding_type, + fixed_point_type, ) - backend_options = generate_htp_compiler_spec(use_fp16=use_fp16) + backend_options = generate_htp_compiler_spec( + use_fp16=use_fp16, use_multi_contexts=args.num_sharding > 0 + ) graph_names = ["prefill_forward", "kv_forward"] compiler_specs = [ generate_qnn_executorch_compiler_spec( @@ -468,15 +531,19 @@ def compile(args, pte_filename): ) for graph_name in graph_names ] + skip_node_op_set = {"llama.fallback.default"} exported_programs = [ - to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i])) + to_backend( + edge_prog.exported_program, + QnnPartitioner(compiler_specs[i], skip_node_op_set=skip_node_op_set), + ) for i, edge_prog in enumerate(edge_progs) ] + if args.num_sharding > 0: + for exported_program in exported_programs: + update_spill_fill_size(exported_program) executorch_config = ExecutorchBackendConfig( - passes=[ - BuildQuantIo(), - ], # For shared buffer, user must pass the memory address # which is allocated by RPC memory to executor runner. 
# Therefore, won't want to pre-allocate @@ -488,20 +555,124 @@ def compile(args, pte_filename): extract_delegate_segments=True, ) - prog_mgr = generate_multi_graph_program( - compiler_specs=compiler_specs[0], - exported_programs=exported_programs, - backend_config=executorch_config, - constant_methods=llama_instance_list[1].llama_meta, # kv method meta - ) - with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file: - prog_mgr.write_to_file(file) + lower_module_dict = {name: [] for name in graph_names} + call_delegate_inputs_dict = {name: [] for name in graph_names} + call_delegate_node_name_dict = {name: [] for name in graph_names} + outputs_dict = {name: [] for name in graph_names} + input_nodes_dict = {name: [] for name in graph_names} + for prog, graph_name in zip(exported_programs, graph_names): + for node in prog.graph_module.graph.nodes: + if ( + node.op == "call_function" + and "executorch_call_delegate" in node.name + ): + call_delegate_node_name_dict[graph_name].append(node.name) + call_delegate_inputs_list = [] + for arg in node.args: + if arg.op == "call_function": + while "getitem" not in arg.name: + arg = arg.args[0] + call_delegate_inputs_list.append( + (arg.args[0].name, arg.args[1]) + ) + elif arg.op == "placeholder": + call_delegate_inputs_list.append((arg.name, None)) + # No extra needs to do for get_attr node + call_delegate_inputs_dict[graph_name].append( + call_delegate_inputs_list + ) + elif node.op == "output": + for arg in node.args[0]: + outputs_dict[graph_name].append((arg.args[0].name, arg.args[1])) + + if args.num_sharding > 0: + bundle_progs_list = [] + for num in range(args.num_sharding - 1, -1, -1): + processed_bytes = [] + for prog, graph_name in zip(exported_programs, graph_names): + processed_bytes.append( + getattr( + prog.graph_module, f"lowered_module_{num}" + ).processed_bytes + ) + + call_delegate_node = [ + list(node.users.keys())[0] + for node in prog.graph_module.graph.nodes + if node.op == "get_attr" + and node.name == f"lowered_module_{num}" + ] + input_nodes_dict[graph_name] = [ + node + for node in call_delegate_node[0].args + if node.op == "placeholder" + ] + + prog_mgr, bundle_progs = generate_multi_graph_program( + compiler_specs=compiler_specs[0], + processed_bytes=processed_bytes, + input_nodes_dict=input_nodes_dict, + backend_config=executorch_config, + constant_methods=llama_instance_list[ + 1 + ].llama_meta, # kv method meta + ) + bundle_progs_list.append(bundle_progs) + for graph_name in graph_names: + lower_module_dict[graph_name].append( + prog_mgr.exported_program(graph_name).graph_module._modules.get( + "lowered_module_0" + ) + ) + + exec_prog = generate_composite_llama_program( + graph_names=graph_names, + sample_inputs_list=sample_inputs_list, + lower_module_dict=lower_module_dict, + call_delegate_node_name_dict=call_delegate_node_name_dict, + call_delegate_inputs_dict=call_delegate_inputs_dict, + outputs_dict=outputs_dict, + backend_config=executorch_config, + constant_methods=llama_instance_list[1].llama_meta, # kv method meta + ) + with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file: + exec_prog.write_to_file(file) + else: + processed_bytes = [] + input_nodes_dict = {name: [] for name in graph_names} + output_nodes_dict = {name: [] for name in graph_names} + for prog, graph_name in zip(exported_programs, graph_names): + processed_bytes.append( + prog.graph_module.lowered_module_0.processed_bytes + ) + input_nodes_dict[graph_name] = [ + node + for node in prog.graph_module.graph.nodes + if node.op == 
"placeholder" + ] + output_nodes_dict[graph_name] = [ + node + for node in prog.graph_module.graph.nodes + if node.op == "output" + ] + + prog_mgr, _ = generate_multi_graph_program( + compiler_specs=compiler_specs[0], + processed_bytes=processed_bytes, + input_nodes_dict=input_nodes_dict, + output_nodes_dict=output_nodes_dict, + backend_config=executorch_config, + constant_methods=llama_instance_list[1].llama_meta, # kv method meta + ) + with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file: + prog_mgr.write_to_file(file) end_lowering_ts = time.time() logging.info(f"Time for compiling: {end_lowering_ts - start_lowering_ts}") + return quant_attrs -def inference(args, pte_filename, pre_gen_pte=""): +def inference(args, quant_attrs, pte_filename, pre_gen_pte=""): workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama" if args.model_mode == "prefill": @@ -524,6 +695,8 @@ def inference(args, pte_filename, pre_gen_pte=""): f"--eval_mode {eval_mode}", f"--temperature {args.temperature}", f"--system_prompt '{args.system_prompt}'", + f"--logits_scale {quant_attrs['scale']}", + f"--logits_offset {quant_attrs['zero_point']}", ] ) runner_cmd = " ".join( @@ -706,16 +879,42 @@ def main(): raise RuntimeError(f"No such model_mode {args.model_mode}.") if args.pre_gen_pte: - inference(args, pte_filename, args.pre_gen_pte) + quant_attrs = json.load( + open(f"{args.pre_gen_pte}/{pte_filename}_quant_attrs.txt") + ) + inference(args, quant_attrs, pte_filename, args.pre_gen_pte) exit(f"Finish the running pre_gen_pte from {args.pre_gen_pte}") if args.compile_only: - compile(args, pte_filename) + quant_attrs = compile(args, pte_filename) + if quant_attrs: + json.dump( + { + "scale": quant_attrs["scale"], + "zero_point": quant_attrs["zero_point"], + }, + open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w"), + ) + else: + logging.warning("Quant attributes of the logit is None.") exit(f"Finish compile_only and save to {args.artifact}") try: - compile(args, pte_filename) - inference(args, pte_filename) + quant_attrs = compile(args, pte_filename) + if quant_attrs: + logging.info( + f"Logit scale: {quant_attrs['scale']}; Logit offset: {quant_attrs['zero_point']}" + ) + json.dump( + { + "scale": quant_attrs["scale"], + "zero_point": quant_attrs["zero_point"], + }, + open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w"), + ) + else: + logging.warning("Quant attributes of the logit is None.") + inference(args, quant_attrs, pte_filename) except Exception as e: if args.ip and args.port != -1: with Client((args.ip, args.port)) as conn: diff --git a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp index d05def243b..2af882580e 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp @@ -48,6 +48,8 @@ DEFINE_int32( eval_mode, 0, "0: PromptProcessor(prefill) / 1: TokenGenerator(kv) / 2: HybridMode (prefill+kv)"); +DEFINE_double(logits_scale, 0.0, "Logits scale"); +DEFINE_int32(logits_offset, 0, "Logits offset"); int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -56,6 +58,8 @@ int main(int argc, char** argv) { example::Runner runner( {FLAGS_model_path}, FLAGS_tokenizer_path.c_str(), + FLAGS_logits_scale, + FLAGS_logits_offset, FLAGS_temperature, FLAGS_eval_mode); std::vector buf; diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp 
b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp index ccf386309c..941ff97685 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp @@ -134,7 +134,7 @@ void HybridMemory::init_io() { auto init_kv = [&]() { ptr->kv_logits.resize(vocab_size_); - ptr->kv_attention_mask.resize((kv_cache_len_ + 1), -255); + ptr->kv_attention_mask.resize((kv_cache_len_ + 1), 0); ptr->k_cache.reserve(num_layers_); for (int layer = 0; layer < num_layers_; layer++) { ptr->k_cache.emplace_back(); @@ -315,9 +315,9 @@ void HybridMemory::prepare_prefill_io( for (int i = 0; i < prefill_cache_len_; ++i) { for (int j = 0; j < prefill_cache_len_; ++j) { if (i < j) { - ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = -255; - } else { ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = 0; + } else { + ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = 65535; } } } @@ -458,7 +458,7 @@ void HybridMemory::update_kv_io( // update position_ids ptr->input_pos = static_cast(pos); // update causal mask for next token - ptr->kv_attention_mask[kv_cache_len_ - pos] = 0; + ptr->kv_attention_mask[kv_cache_len_ - pos] = 65535; // update v_cache auto& v_cache_in = v_cache_in_[kv_forward_name_]; @@ -496,4 +496,13 @@ void HybridMemory::update_kv_io( } } +void HybridMemory::update_prefill_io( + int64_t cur_token, + int64_t pos, + std::vector>& output_tensors) { + (void)output_tensors; + IO* ptr = static_cast(data_ptr_.get()); + ptr->prefill_input_toks[pos] = static_cast(cur_token); +} + } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h index ca3a884887..bb107ffd77 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h @@ -43,6 +43,10 @@ class Memory { int64_t cur_token, int64_t pos, std::vector>& output_tensors) = 0; + virtual void update_prefill_io( + int64_t cur_token, + int64_t pos, + std::vector>& output_tensors) = 0; void* get_mutable_ptr(); std::vector get_input_tensors( int shard_index, @@ -97,17 +101,22 @@ class HybridMemory : public Memory { int64_t pos, std::vector>& output_tensors) override; + void update_prefill_io( + int64_t cur_token, + int64_t pos, + std::vector>& output_tensors) + override; struct IO { int32_t input_tok; int32_t input_pos; std::vector>> k_cache; std::vector> v_cache; std::vector> k_cache_out; - std::vector kv_attention_mask; - std::vector kv_logits; + std::vector kv_attention_mask; + std::vector kv_logits; std::vector prefill_input_toks; - std::vector prefill_atten_mask; - std::vector prefill_logits; + std::vector prefill_atten_mask; + std::vector prefill_logits; }; private: diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp index e87240dfdf..02a53861b8 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp @@ -40,11 +40,15 @@ std::string statsToJsonString(const Runner::Stats& stats); Runner::Runner( const std::vector& models_path, const std::string& tokenizer_path, + const float logits_scale, + const int32_t logits_offset, const float temperature, const int eval_mode) : n_bos_(1), n_eos_(1), tokenizer_path_(tokenizer_path), + logits_scale_(logits_scale), + logits_offset_(logits_offset), temperature_(temperature), eval_mode_(static_cast(eval_mode)) { for (size_t i = 0; i 
< models_path.size(); ++i) { @@ -205,13 +209,23 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) { return res; } -template -int32_t Runner::logitsToToken(const Tensor& logits_tensor) { - T* logits = logits_tensor.mutable_data_ptr(); - +int32_t Runner::logitsToToken(const Tensor& logits_tensor, int64_t pos) { + static std::vector logits_f(vocab_size_); + const uint16_t* logits = logits_tensor.data_ptr(); // Since the logits are for all tokens, get the last token probabilities - T* logits_last = logits; - return sampler_->sample(logits_last); + auto* logits_last = logits; + + // offset to the meaningful logit we want. + if (logits_tensor.sizes().data()[1] > 1) { + auto vocab_size = logits_tensor.size(2); + logits_last += pos * vocab_size; + } + + // dequantize + for (int i = 0; i < vocab_size_; i++) { + logits_f[i] = (logits_last[i] - logits_offset_) * logits_scale_; + } + return sampler_->sample(logits_f.data()); } void Runner::run_model_step( @@ -266,11 +280,11 @@ Error Runner::generate( if (!system_prompt.empty()) { prompt_.append("<|start_header_id|>system<|end_header_id|>\n\n"); prompt_.append(system_prompt); - prompt_.append("<|eot_id|>\n"); + prompt_.append("<|eot_id|>"); } prompt_.append("<|start_header_id|>user<|end_header_id|>\n\n"); prompt_.append(prompt); - prompt_.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>"); + prompt_.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); if (token_callback) { token_callback("<|begin_of_text|>"); @@ -308,32 +322,48 @@ Error Runner::generate( auto prefill_execute = [&](const std::string& method_name) { for (int i = 0; i < num_prompt_tokens; i++) { ptr->prefill_input_toks[i] = static_cast(prompt_tokens[i]); - auto piece_res = tokenizer_->decode(prompt_tokens[i], prompt_tokens[i]); - token_callback(piece_res.get()); } - // inference - run_model_step(method_name, inputs[method_name]); - Tensor& logits_tensor = output_tensors[method_name].back()[0]; - // offset to the meaningful logit we want. 
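The reworked logitsToToken above samples from uint16 logits by first applying the logits_scale / logits_offset passed to the runner. The same affine dequantization in NumPy, with made-up scale and zero-point values:

import numpy as np

def dequantize_logits(q_logits: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    # Affine dequantization: real = (quantized - zero_point) * scale.
    return (q_logits.astype(np.float32) - zero_point) * scale

q = np.array([0, 32768, 65535], dtype=np.uint16)
print(dequantize_logits(q, scale=3.1e-4, zero_point=32768))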
- float* logits = logits_tensor.mutable_data_ptr() + - (num_prompt_tokens - 1) * vocab_size_; - prev_token = prompt_tokens[num_prompt_tokens - 1]; - long sample_start_time_ms = time_in_ms(); - cur_token = sampler_->sample(logits); - stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; - stats_.first_token_ms = time_in_ms(); - stats_.prompt_eval_end_ms = time_in_ms(); - auto piece_res = tokenizer_->decode(prev_token, cur_token); - ET_CHECK(piece_res.ok()); if (token_callback) { - token_callback(piece_res.get().c_str()); + token_callback(prompt_); + } + + pos = num_prompt_tokens - 1; + cur_token = prompt_tokens[pos]; + while (pos < seq_len - 1) { + // inference + run_model_step(method_name, inputs[method_name]); + Tensor& logits_tensor = output_tensors[method_name].back()[0]; + prev_token = cur_token; + long sample_start_time_ms = time_in_ms(); + cur_token = logitsToToken(logits_tensor, pos); + stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; + + io_mem_->update_prefill_io(cur_token, ++pos, output_tensors[method_name]); + auto piece_res = tokenizer_->decode(prev_token, cur_token); + ET_CHECK(piece_res.ok()); + if (token_callback) { + token_callback(piece_res.get().c_str()); + } + + if (pos == num_prompt_tokens) { + stats_.first_token_ms = time_in_ms(); + stats_.prompt_eval_end_ms = time_in_ms(); + } + + if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) { + ET_LOG(Info, "\nReached to the end of generation"); + break; + } + // prefill model inferences once for prompt in the hybrid mode + if (eval_mode_ == EvalMode::kHybrid) { + break; + } } - pos += num_prompt_tokens; }; auto kv_execute = [&](const std::string& method_name) { ptr->input_tok = static_cast(cur_token); - ptr->kv_attention_mask[kv_cache_len_] = 0; + ptr->kv_attention_mask[kv_cache_len_] = 65535; while (pos < seq_len - 1) { // inference run_model_step(method_name, inputs[method_name]); @@ -347,10 +377,9 @@ Error Runner::generate( stats_.prompt_eval_end_ms = time_in_ms(); } } - prev_token = cur_token; long sample_start_time_ms = time_in_ms(); - cur_token = logitsToToken(logits_tensor); + cur_token = logitsToToken(logits_tensor, pos); stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; if (pos < num_prompt_tokens - 1) { diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h index 79b8370982..75ad640219 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h @@ -29,6 +29,8 @@ class Runner { explicit Runner( const std::vector& models_path, const std::string& tokenizer_path, + const float logits_scale, + const int32_t logits_offset, const float temperature, const int eval_mode); @@ -73,8 +75,9 @@ class Runner { private: template T getMetadataHelper(std::string method_name, T default_val); - template - int32_t logitsToToken(const executorch::aten::Tensor& logits_tensor); + int32_t logitsToToken( + const executorch::aten::Tensor& logits_tensor, + int64_t pos); void run_model_step( const std::string& method_name, std::vector>& inputs); @@ -90,6 +93,8 @@ class Runner { const int32_t n_eos_; std::vector> modules_; std::string tokenizer_path_; + float logits_scale_; + int32_t logits_offset_; float temperature_; std::unique_ptr tokenizer_; std::unique_ptr sampler_; diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 88247d2a27..d08e68fa73 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1575,7 
+1575,8 @@ def _find_fqn_for_placeholder( warnings.warn( "Mutation on a buffer in the model is detected. ExecuTorch assumes " "buffers that are mutated in the graph have a meaningless initial state, " - "only the shape and dtype will be serialized.", + "only the shape and dtype will be serialized, unless a pass which sets " + 'meta["et_init_buffer"] to True such as InitializedMutableBufferPass is run.', UserWarning, stacklevel=1, ) @@ -1602,6 +1603,7 @@ def placeholder( """ spec = self.node.meta["spec"] constant_tag = self.node.meta.get("constant_tag", None) + initialize_buffer = self.node.meta.get("et_init_buffer", None) is_user_input = True if isinstance(target, str) and isinstance(spec, TensorSpec): @@ -1655,7 +1657,10 @@ def placeholder( spec.storage = real_tensor.untyped_storage() # User inputs and mutable buffers are not constants, other buffers or parameters are. - spec.const = not (is_user_input or is_mutable_buffer) + if initialize_buffer and is_mutable_buffer: + spec.const = True + else: + spec.const = not (is_user_input or is_mutable_buffer) evalue = ( self._tensor_spec_to_evalue(spec, constant_tag) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index a645fa5377..0da4085914 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -9,6 +9,7 @@ import typing import unittest from contextlib import contextmanager +from copy import deepcopy from typing import List, Optional, Tuple import executorch.exir as exir @@ -31,6 +32,7 @@ from executorch.exir.error import InternalError from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.constant_prop_pass import constant_prop_pass +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.exir.print_program import pretty_print, print_program # noqa from executorch.exir.schema import ( @@ -56,6 +58,7 @@ from executorch.extension.pybindings.portable_lib import ( _load_for_executorch_from_buffer, ) +from executorch.runtime import Runtime from functorch.experimental import control_flow from torch import nn @@ -243,6 +246,56 @@ def forward(self, x): ) self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null) + def test_initialized_mutable_buffer(self): + """Test that mutable buffers can hold meaningful initialized state.""" + + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + # Mutable buffer with non-empty initial state. + self.register_buffer("cache_pos", torch.arange(0, 10)) + + def forward(self, x): + self.cache_pos.add_(1) + return self.cache_pos + + m = TestModule() + example_inputs = (torch.ones(10),) + ep = torch.export.export(m, example_inputs) + edge = to_edge( + ep, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + ), + ) + + # Save a copy of the edge program since to_executorch is + # stateful to some degree. 
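The deepcopy on the next line exists because to_executorch mutates the edge manager, so the test lowers a copy without the pass as a control. As a compact standalone sketch (not part of the patch) of the same flow, using a hypothetical `Counter` module instead of the test's `TestModule`; only the ExecuTorch APIs shown come from this diff, everything else is illustrative:

```python
# Minimal sketch, assuming a hypothetical Counter module; mirrors the flow the
# surrounding test exercises but is not part of the patch itself.
from copy import deepcopy

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass


class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Mutable buffer whose initial value (42) we want to survive serialization.
        self.register_buffer("count", torch.tensor([42], dtype=torch.long))

    def forward(self, x):
        self.count.add_(1)
        return self.count + x


ep = torch.export.export(Counter(), (torch.zeros(1, dtype=torch.long),))
edge = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False))
edge_no_pass = deepcopy(edge)  # to_executorch is stateful, so copy before lowering twice

# With the pass: "count" is emitted with its registered initial value.
et_with_pass = edge.to_executorch(
    config=ExecutorchBackendConfig(passes=[InitializedMutableBufferPass(["count"])])
)
# Without the pass: only shape/dtype are kept and the buffer starts out as zeros.
et_without_pass = edge_no_pass.to_executorch()
```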
+ edge_copy = deepcopy(edge) + et_config = ExecutorchBackendConfig( + passes=[InitializedMutableBufferPass(["cache_pos"])], + ) + et_program_init_pass = edge.to_executorch(config=et_config) + et_program_regular = edge_copy.to_executorch() + + runtime = Runtime.get() + program_init_pass = runtime.load_program(et_program_init_pass.buffer) + method_init_pass = program_init_pass.load_method("forward") + + program_regular = runtime.load_program(et_program_regular.buffer) + method_regular = program_regular.load_method("forward") + + # Test that the mutable buffer is initialized. + torch.allclose( + method_init_pass.execute((example_inputs))[0], torch.arange(1, 11) + ) + # Test that the mutable buffer is uninitialized and starts with default zeros, + # we test equality with torch.ones because of the mutation += 1 in the model forward. + torch.allclose( + method_regular.execute((example_inputs))[0], + torch.ones(10, dtype=torch.int64), + ) + def test_int_list_input(self): class M(torch.nn.Module): def forward(self, x, y, z): diff --git a/exir/passes/init_mutable_pass.py b/exir/passes/init_mutable_pass.py new file mode 100644 index 0000000000..e6778259d0 --- /dev/null +++ b/exir/passes/init_mutable_pass.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List + +from executorch.exir.pass_base import ExportPass + + +class InitializedMutableBufferPass(ExportPass): + """ + If a buffer has a name that within a specified list, set meta["et_init_buffer"] + to True, which provides the mutable buffer with an initialized state. + + As an example, a module with `self.register_buffer("cache_pos", torch.arange(10))` + when patterns = ["cache_pos"] would have its initial state set instead of being + left uninitialized by default. + """ + + def __init__(self, patterns: List[str]) -> None: + super().__init__() + self.patterns = patterns + + def placeholder(self, name: str, arg, meta): + for pattern in self.patterns: + if pattern in name: + meta["et_init_buffer"] = True + + return super().placeholder(name, arg, meta) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 26043da591..f8a209d86b 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -27,6 +27,7 @@ from executorch.exir.backend.utils import format_delegated_graph from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig +from executorch.exir.pass_manager import PassType from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass @@ -425,21 +426,27 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag return self - def to_executorch(self) -> "LLMEdgeManager": + def to_executorch( + self, passes: Optional[List[PassType]] = None + ) -> "LLMEdgeManager": """ Lower the model to executorch and get an ExecutorchProgram. """ assert self.edge_manager, "Need to run export_to_edge() first" + to_executorch_passes = [ + # If there are Linear operations left in the graph, let's execute + # them with the optimized op_linear rather than materializing a + # transpose followed by a regular op_mm. 
+ ConvertToLinearPass(), + QuantFusionPass(), + ] + if passes: + to_executorch_passes.extend(passes) + self.export_program = self.edge_manager.to_executorch( ExecutorchBackendConfig( extract_delegate_segments=True, - passes=[ - # If there are Linear operations left in the graph, let's execute - # them with the optimized op_linear rather than materializing a - # transpose followed by a regular op_mm. - ConvertToLinearPass(), - QuantFusionPass(), - ], + passes=to_executorch_passes, memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), ) diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py index 82ee1febf4..3ecf0b2b4b 100644 --- a/extension/llm/modules/test/test_attention.py +++ b/extension/llm/modules/test/test_attention.py @@ -11,6 +11,8 @@ import torch from executorch.exir import EdgeCompileConfig, to_edge +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass from executorch.extension.llm.modules.attention import ( MultiHeadAttention as ETMultiHeadAttention, ) @@ -31,6 +33,7 @@ def setUp(self): self.num_kv_heads = 8 self.head_dim = 64 self.max_seq_len = 128 + self.encoder_max_seq_len = 128 self.rope_base = 500_000 self.scale_factor = 32 @@ -84,16 +87,26 @@ def setUp(self): max_seq_len=self.max_seq_len, ) self.et_mha.load_state_dict(self.tt_mha.state_dict()) + # Common inputs. seq_len = 10 self.x = torch.randn(1, seq_len, self.embed_dim) + self.y = torch.randn(1, seq_len, self.embed_dim) self.input_pos = torch.arange(seq_len).unsqueeze(0) # shape [1, seq_len] - seq_len_dim = torch.export.Dim("seq_len", min=1, max=100) - self.dynamic_shapes = ( - {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC}, - {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC}, - {0: torch.export.Dim.STATIC, 1: seq_len_dim}, - ) + self.seq_len_dim = torch.export.Dim("seq_len", min=1, max=self.max_seq_len) + self.dynamic_shapes = { + "x": { + 0: torch.export.Dim.STATIC, + 1: self.seq_len_dim, + 2: torch.export.Dim.STATIC, + }, + "y": { + 0: torch.export.Dim.STATIC, + 1: self.seq_len_dim, + 2: torch.export.Dim.STATIC, + }, + "input_pos": {0: torch.export.Dim.STATIC, 1: self.seq_len_dim}, + } self.causal_mask = torch.tril( torch.ones( size=(self.max_seq_len, self.max_seq_len), @@ -108,13 +121,13 @@ def test_attention_eager(self): assert_close(et_res, tt_res) # test with kv cache - self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20) - self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20) + self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) + self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) et_res = self.et_mha(self.x, self.x) # Self attention. tt_res = self.tt_mha(self.x, self.x) # Self attention. - self.assertTrue(torch.allclose(et_res, tt_res)) + assert_close(et_res, tt_res) self.et_mha.reset_cache() self.tt_mha.reset_cache() @@ -125,7 +138,7 @@ def test_attention_eager(self): self.x, self.x, input_pos=self.input_pos ) # Self attention with input pos. - self.assertTrue(torch.allclose(et_res, tt_res)) + assert_close(et_res, tt_res) # test kv cache read. Input pos can be [10, 11, ..., 19] next_input_pos = torch.arange(10, 20).unsqueeze(0) @@ -142,12 +155,12 @@ def test_attention_export(self): # Self attention. 
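The hunks above and below switch the attention tests to the dict form of `dynamic_shapes`, keyed by forward-argument name with one entry per dimension. A minimal standalone sketch (not part of the patch) of that export pattern, using a hypothetical `ToyAttention` module and assuming a PyTorch recent enough to provide `torch.export.Dim.STATIC` (as these tests already do):

```python
# Sketch of dict-form dynamic_shapes: keys match forward() parameter names,
# values map dimension index -> Dim (dynamic) or Dim.STATIC. ToyAttention is a
# stand-in, not the MHA module under test.
import torch


class ToyAttention(torch.nn.Module):
    def forward(self, x, y, input_pos):
        # Stand-in computation; only the export call matters here.
        return x + y + input_pos.unsqueeze(-1).to(x.dtype)


seq_len_dim = torch.export.Dim("seq_len", min=1, max=128)
dynamic_shapes = {
    "x": {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
    "y": {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
    "input_pos": {0: torch.export.Dim.STATIC, 1: seq_len_dim},
}

x = torch.randn(1, 10, 32)
y = torch.randn(1, 10, 32)
input_pos = torch.arange(10).unsqueeze(0)
ep = torch.export.export(
    ToyAttention(), (x, y, input_pos), dynamic_shapes=dynamic_shapes, strict=True
)
# Any seq_len in [1, 128] is now accepted along dim 1 of all three inputs.
out = ep.module()(
    torch.randn(1, 3, 32), torch.randn(1, 3, 32), torch.arange(3).unsqueeze(0)
)
print(out.shape)  # torch.Size([1, 3, 32])
```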
# test with kv cache - self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) - self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) + self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) + self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) with torch.no_grad(): et_mha_ep = torch.export.export( self.et_mha, - (self.x, self.x), + (self.x, self.y), kwargs={"input_pos": self.input_pos}, dynamic_shapes=self.dynamic_shapes, strict=True, @@ -164,8 +177,8 @@ def test_attention_aoti(self): # Self attention. # test with kv cache - self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) - self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) + self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) + self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) with torch.no_grad(): so = torch._export.aot_compile( self.et_mha, @@ -187,14 +200,13 @@ def test_attention_aoti(self): def test_attention_executorch(self): # Self attention. - # TODO: Fix kv cache - # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) - # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) + self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) + self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) with torch.no_grad(): et_mha_ep = torch.export.export( self.et_mha, - (self.x, self.x), + (self.x, self.y), kwargs={"input_pos": self.input_pos}, dynamic_shapes=self.dynamic_shapes, strict=True, @@ -202,9 +214,15 @@ def test_attention_executorch(self): et_program = to_edge( et_mha_ep, compile_config=EdgeCompileConfig( - _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg] + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], + _check_ir_validity=False, ), - ).to_executorch() + ).to_executorch( + config=ExecutorchBackendConfig( + passes=[InitializedMutableBufferPass(["cache_pos"])], + ) + ) + runtime = Runtime.get() program = runtime.load_program(et_program.buffer) method = program.load_method("forward") @@ -215,32 +233,121 @@ def test_attention_executorch(self): def test_attention_torch_cond_eager(self): # Different from vanilla torchtune MHA, we rewrite the if condition with torch.cond. We need to make sure they are giving the same results regarding the if condition. - # For the first run of MHA we provide `y` (self.x) but for the second run it will be a tensor full of nan. + # For the first run of MHA we provide `y` but for the second run it will be a tensor full of nan. self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) - # mask mask = self.causal_mask[self.input_pos, :] - # First run - et_res = self.et_mha( - self.x, self.x, mask=mask, input_pos=self.input_pos - ) # Self attention with input pos. - tt_res = self.tt_mha( - self.x, self.x, mask=mask, input_pos=self.input_pos - ) # Self attention with input pos. + # First run. + et_res = self.et_mha(self.x, self.y, mask=mask, input_pos=self.input_pos) + tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos) - self.assertTrue(torch.allclose(et_res, tt_res)) + assert_close(et_res, tt_res) - # Second run test kv cache read. Input pos is [10, 11, ..., 19] + # Second run tests kv cache read. 
Input pos is [10, 11, ..., 19] next_input_pos = torch.arange(10, 20).unsqueeze(0) empty_y = torch.full_like(self.x, torch.nan) mask = self.causal_mask[next_input_pos, :] - et_res = self.et_mha( + et_res = self.et_mha(self.x, empty_y, mask=mask, input_pos=next_input_pos) + tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos) + + assert_close(et_res, tt_res) + + def test_attention_torch_cond_export(self): + self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) + self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) + mask = self.causal_mask[self.input_pos, :] + dynamic_shapes = { + **self.dynamic_shapes, + **{ + "mask": { + 0: torch.export.Dim.STATIC, + 1: self.seq_len_dim, + 2: torch.export.Dim.STATIC, + } + }, + } + with torch.no_grad(): + et_mha_ep = torch.export.export( + self.et_mha, + (self.x, self.y), + kwargs={ + "mask": mask, + "input_pos": self.input_pos, + }, + dynamic_shapes=dynamic_shapes, + strict=True, + ) + + # First run. + et_res = et_mha_ep.module()(self.x, self.y, mask=mask, input_pos=self.input_pos) + tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos) + + assert_close(et_res, tt_res) + + # Second run tests kv cache read. Input pos is [10, 11, ..., 19] + next_input_pos = torch.arange(10, 20).unsqueeze(0) + empty_y = torch.full_like(self.y, torch.nan) + mask = self.causal_mask[next_input_pos, :] + et_res = et_mha_ep.module()( self.x, empty_y, mask=mask, input_pos=next_input_pos - ) # Self attention with input pos. - tt_res = self.tt_mha( - self.x, None, mask=mask, input_pos=next_input_pos - ) # Self attention with input pos. + ) + tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos) assert_close(et_res, tt_res) + + def test_attention_torch_cond_executorch(self): + self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) + self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) + mask = self.causal_mask[self.input_pos, :] + dynamic_shapes = { + **self.dynamic_shapes, + **{ + "mask": { + 0: torch.export.Dim.STATIC, + 1: self.seq_len_dim, + 2: torch.export.Dim.STATIC, + } + }, + } + with torch.no_grad(): + et_mha_ep = torch.export.export( + self.et_mha, + (self.x, self.y), + kwargs={ + "mask": mask, + "input_pos": self.input_pos, + }, + dynamic_shapes=dynamic_shapes, + strict=True, + ) + et_program = to_edge( + et_mha_ep, + compile_config=EdgeCompileConfig( + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], + _check_ir_validity=False, + ), + ).to_executorch( + config=ExecutorchBackendConfig( + passes=[InitializedMutableBufferPass(["cache_pos"])], + ) + ) + + # First run. + runtime = Runtime.get() + program = runtime.load_program(et_program.buffer) + method = program.load_method("forward") + et_res = method.execute((self.x, self.y, mask, self.input_pos)) + tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos) + + assert_close(et_res[0], tt_res) + + # Second run tests kv cache read. 
Input pos is [10, 11, ..., 19] + next_input_pos = torch.arange(10, 20).unsqueeze(0) + empty_y = torch.full_like(self.y, torch.nan) + mask = self.causal_mask[next_input_pos, :] + et_res = method.execute((self.x, empty_y, mask, next_input_pos)) + tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos) + + assert_close(et_res[0], tt_res) diff --git a/extension/llm/modules/test/test_kv_cache.py b/extension/llm/modules/test/test_kv_cache.py new file mode 100644 index 0000000000..6029a03882 --- /dev/null +++ b/extension/llm/modules/test/test_kv_cache.py @@ -0,0 +1,217 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import tempfile +import unittest +from typing import Callable, Tuple + +import torch +from executorch.exir import EdgeCompileConfig, to_edge +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass + +from executorch.extension.export_util.utils import save_pte_program +from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache + +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch_from_buffer, +) +from executorch.runtime import Runtime +from torch.testing import assert_close +from torchtune.modules.kv_cache import KVCache + + +def generate_cache_inputs( + seq_len: int, + batch_size: int = 1, + num_kv_heads: int = 64, + head_dim: int = 8, +) -> Tuple[torch.Tensor, ...]: + """Helper to generate k_val and v_val for both et and tt caches.""" + k_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim) + v_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim) + + # For torchtune, the kv cache takes in transposed k and v. + k_val_trans = k_val.transpose(1, 2) + v_val_trans = v_val.transpose(1, 2) + + return (k_val, v_val, k_val_trans, v_val_trans) + + +class KVCacheTest(unittest.TestCase): + def setUp(self): + self.batch_size = 1 + self.max_seq_len = 10 + self.num_kv_heads = 1 # For testing purposes, usually this is 64. + self.head_dim = 8 + self.dtype = torch.float + + self.tt_kv_cache = KVCache( + batch_size=self.batch_size, + max_seq_len=self.max_seq_len, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + dtype=self.dtype, + ) + self.et_kv_cache = InferenceKVCache( + batch_size=self.batch_size, + max_seq_len=self.max_seq_len, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + dtype=self.dtype, + transpose_cache=False, + ) + + def _test_kv_cache(self, et_cache_module: Callable): + """ + Given an executorch kv cache anywhere along the export chain, compare it's results + against torchtune and run basic tests. + """ + prefill_seq_len = 3 + k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs( + prefill_seq_len, self.batch_size, self.num_kv_heads, self.head_dim + ) + + et_res = et_cache_module(k_val, v_val) + tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans) + tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2)) + + # Check torchtune matches executorch. + assert_close(et_res, tt_res_transposed) + + # Check the values are correct, all rows in the seq_len dim should be + # filled with 1s up to and including the 3rd. 
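The per-row check that follows verifies exactly this fill pattern. As a standalone illustration (not part of the patch), the same behavior can be seen by driving torchtune's `KVCache` directly with the dimensions from `setUp`: after a 3-token prefill of ones, sequence slots 0 through 2 hold ones and slot 3 is still zero.

```python
import torch
from torchtune.modules.kv_cache import KVCache

# Same dimensions as setUp: a 10-slot cache with one kv head of width 8.
cache = KVCache(
    batch_size=1, max_seq_len=10, num_kv_heads=1, head_dim=8, dtype=torch.float
)
# torchtune's cache takes [batch, num_heads, seq_len, head_dim] inputs, hence the
# transposes in generate_cache_inputs above.
k = torch.ones(1, 1, 3, 8)
v = torch.ones(1, 1, 3, 8)
k_out, _ = cache.update(k, v)  # returns the full cache, not just the new slots
# Slots 0-2 along the sequence dim now hold the prefilled ones, slot 3 is untouched.
print(k_out[0, 0, :4, 0])  # expected: tensor([1., 1., 1., 0.])
```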
+ et_k_cache = et_res[0] + for i in range(prefill_seq_len): + self.assertTrue(et_k_cache[0][i][0][0] == 1) + self.assertTrue(et_k_cache[0][prefill_seq_len][0][0] == 0) + + """Case 2: Token-by-token (seq_len = 0)""" + seq_len = 1 + k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs( + seq_len, self.batch_size, self.num_kv_heads, self.head_dim + ) + + et_res = et_cache_module(k_val, v_val) + tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans) + tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2)) + + # Check torchtune matches executorch. + assert_close(tt_res_transposed, et_res) + + # All rows should be filled with 1s up to 3 + 1th row. + et_k_cache = et_res[0] + for i in range(prefill_seq_len + 1): + self.assertTrue(et_k_cache[0][i][0][0] == 1) + + self.assertTrue(et_k_cache[0][prefill_seq_len + 1][0][0] == 0) + + def export_kv_cache( + self, + kv_cache: torch.nn.Module, + ) -> torch.export.ExportedProgram: + # Wrapper since torch.export only exports forward(). + class EtCacheWrapper(torch.nn.Module): + def __init__(self, kv_cache: torch.nn.Module): + super().__init__() + self.kv_cache = kv_cache + + def forward(self, k_val: torch.Tensor, v_val: torch.Tensor): + return self.kv_cache.update(k_val, v_val) + + dim = torch.export.Dim("seq_len_dim", min=1, max=self.max_seq_len) + exported_kv_cache = torch.export.export( + EtCacheWrapper(self.et_kv_cache), + ( + torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim), + torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim), + ), # 3 as example prefill seq_len. + dynamic_shapes={ + "k_val": { + 0: torch.export.Dim.STATIC, + 1: dim, + 2: torch.export.Dim.STATIC, + 3: torch.export.Dim.STATIC, + }, + "v_val": { + 0: torch.export.Dim.STATIC, + 1: dim, + 2: torch.export.Dim.STATIC, + 3: torch.export.Dim.STATIC, + }, + }, + ) + return exported_kv_cache + + def test_kv_cache_eager(self): + self._test_kv_cache(self.et_kv_cache.update) + + def test_kv_cache_export(self): + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) + self._test_kv_cache(exported_kv_cache.module()) + + def test_kv_cache_edge(self): + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) + edge_program = to_edge( + exported_kv_cache, + compile_config=EdgeCompileConfig( + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], + _check_ir_validity=False, + ), + ) + self._test_kv_cache(edge_program._edge_programs["forward"].module()) + + def test_kv_cache_executorch(self): + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) + edge_program = to_edge( + exported_kv_cache, + compile_config=EdgeCompileConfig( + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], + _check_ir_validity=False, + ), + ) + et_config = ExecutorchBackendConfig( + passes=[InitializedMutableBufferPass(["cache_pos"])], + ) + et_program = edge_program.to_executorch(config=et_config) + + runtime = Runtime.get() + program = runtime.load_program(et_program.buffer) + method = program.load_method("forward") + + # Since method.execute expects a tuple of args. 
+ def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor: + return method.execute((k_val, v_val)) + + self._test_kv_cache(wrapped_callable) + + def test_kv_cache_executorch_from_file(self): + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) + edge_program = to_edge( + exported_kv_cache, + compile_config=EdgeCompileConfig( + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], + _check_ir_validity=False, + ), + ) + et_config = ExecutorchBackendConfig( + passes=[InitializedMutableBufferPass(["cache_pos"])], + ) + et_program = edge_program.to_executorch(config=et_config) + + with tempfile.TemporaryDirectory() as tempdir: + pte_path = save_pte_program(et_program, "test_et_kv_cache", tempdir) + with open(pte_path, "rb") as f: + model_bytes = f.read() + loaded_et_program = _load_for_executorch_from_buffer(model_bytes) + + # Since method.execute expects a tuple of args. + def wrapped_callable( + k_val: torch.Tensor, v_val: torch.Tensor + ) -> torch.Tensor: + return loaded_et_program.forward((k_val, v_val)) + + self._test_kv_cache(wrapped_callable) diff --git a/install_requirements.py b/install_requirements.py index 3fca161a2c..b5dbd0dddf 100644 --- a/install_requirements.py +++ b/install_requirements.py @@ -132,7 +132,7 @@ def python_is_compatible(): # NOTE: If a newly-fetched version of the executorch repo changes the value of # NIGHTLY_VERSION, you should re-run this script to install the necessary # package versions. -NIGHTLY_VERSION = "dev20241218" +NIGHTLY_VERSION = "dev20250104" # The pip repository that hosts nightly torch packages. TORCH_NIGHTLY_URL = "https://download.pytorch.org/whl/nightly/cpu" diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml index 0024c676ef..833b37cfac 100644 --- a/kernels/aten/functions.yaml +++ b/kernels/aten/functions.yaml @@ -257,6 +257,8 @@ - op: mean.out +- op: mean.dtype_out + - op: min.dim_min - op: min.unary_out @@ -407,8 +409,6 @@ - op: upsample_bilinear2d.vec_out -- op: upsample_nearest2d.out - - op: upsample_nearest2d.vec_out - op: var.correction_out diff --git a/kernels/portable/cpu/op_mean.cpp b/kernels/portable/cpu/op_mean.cpp index aeb0d7f8ca..6730404dde 100644 --- a/kernels/portable/cpu/op_mean.cpp +++ b/kernels/portable/cpu/op_mean.cpp @@ -66,6 +66,14 @@ Tensor& mean_dim_out( return out; } +Tensor& mean_dtype_out( + KernelRuntimeContext& ctx, + const Tensor& in, + optional dtype, + Tensor& out) { + return mean_dim_out(ctx, in, ArrayRef(), false, dtype, out); +} + } // namespace native } // namespace executor } // namespace torch diff --git a/kernels/portable/cpu/op_upsample_nearest2d.cpp b/kernels/portable/cpu/op_upsample_nearest2d.cpp new file mode 100644 index 0000000000..93a88588d8 --- /dev/null +++ b/kernels/portable/cpu/op_upsample_nearest2d.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +namespace torch { +namespace executor { +namespace native { + +using exec_aten::ArrayRef; +using exec_aten::optional; +using exec_aten::SizesType; + +namespace { +template +void upsample_nearest2d_kernel_impl( + const Tensor& in, + const float scale_h, + const float scale_w, + Tensor& out) { + const auto in_data = in.const_data_ptr(); + auto out_data = out.mutable_data_ptr(); + + auto in_plane = in_data; + for (auto n = 0; n < out.size(0); n++) { + for (auto c = 0; c < out.size(1); c++) { + for (auto h = 0; h < out.size(2); h++) { + for (auto w = 0; w < out.size(3); w++) { + const auto in_h = + nearest_neighbor_compute_source_index(scale_h, h, in.sizes()[2]); + const auto in_w = + nearest_neighbor_compute_source_index(scale_w, w, in.sizes()[3]); + + *out_data = in_plane[in_h * in.strides()[2] + in_w * in.strides()[3]]; + out_data++; + } + } + + in_plane += in.strides()[1]; + } + } +} +} // namespace + +Tensor& upsample_nearest2d_vec_out( + KernelRuntimeContext& ctx, + const Tensor& in, + const exec_aten::OptionalArrayRef output_size, + const exec_aten::OptionalArrayRef scale_factors, + Tensor& out) { + // Preconditions (checked in check_..._args): + // In and out tensors have same dtype. + // In and out tensors are rank 4 and have same dim[0] and dim[1]. + // In and out tensors are default dim order (NCHW). + ET_KERNEL_CHECK( + ctx, + check_upsample_nearest2d_args(in, output_size, scale_factors, out), + InvalidArgument, + out); + + double scale_h, scale_w; + + ET_KERNEL_CHECK_MSG( + ctx, + resize_upsample_2d( + in, output_size, scale_factors, scale_h, scale_w, out) == Error::Ok, + InvalidArgument, + out, + "Failed to resize output tensor"); + + const auto kernel_scale_h = area_pixel_compute_scale( + in.sizes()[2], out.sizes()[2], false, scale_h); + const auto kernel_scale_w = area_pixel_compute_scale( + in.sizes()[3], out.sizes()[3], false, scale_w); + + ET_SWITCH_REAL_TYPES( + in.scalar_type(), ctx, "upsample_nearest2d.out", CTYPE, [&]() { + upsample_nearest2d_kernel_impl( + in, kernel_scale_h, kernel_scale_w, out); + }); + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/cpu/util/reduce_util.cpp b/kernels/portable/cpu/util/reduce_util.cpp index 08237e07b9..884c10b813 100644 --- a/kernels/portable/cpu/util/reduce_util.cpp +++ b/kernels/portable/cpu/util/reduce_util.cpp @@ -386,6 +386,7 @@ bool check_mean_dim_args( check_reduction_args(in, dim_list, keepdim, dtype, out)); if (dtype) { + ET_LOG(Info, "dtype is %hhd", static_cast(dtype.value())); ET_LOG_AND_RETURN_IF_FALSE(torch::executor::isFloatingType(dtype.value())); ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == dtype.value()); } else { diff --git a/kernels/portable/cpu/util/upsample_util.cpp b/kernels/portable/cpu/util/upsample_util.cpp index 2082b632d9..15cf28b663 100644 --- a/kernels/portable/cpu/util/upsample_util.cpp +++ b/kernels/portable/cpu/util/upsample_util.cpp @@ -46,6 +46,14 @@ bool check_upsample_bilinear2d_args( return check_upsample_2d_common_args(in, output_size, scale_factors, out); } +bool check_upsample_nearest2d_args( + const Tensor& in, + const exec_aten::OptionalArrayRef& output_size, + const exec_aten::OptionalArrayRef& scale_factors, + Tensor& out) { + return check_upsample_2d_common_args(in, output_size, scale_factors, out); +} + Error resize_upsample_2d( const Tensor& in, const exec_aten::OptionalArrayRef& output_size, diff --git a/kernels/portable/cpu/util/upsample_util.h 
b/kernels/portable/cpu/util/upsample_util.h index 86ae8e77ec..785b6f1cc8 100644 --- a/kernels/portable/cpu/util/upsample_util.h +++ b/kernels/portable/cpu/util/upsample_util.h @@ -28,6 +28,12 @@ bool check_upsample_bilinear2d_args( const exec_aten::OptionalArrayRef& scale_factors, Tensor& out); +bool check_upsample_nearest2d_args( + const Tensor& in, + const exec_aten::OptionalArrayRef& output_size, + const exec_aten::OptionalArrayRef& scale_factors, + Tensor& out); + Error resize_upsample_2d( const Tensor& in, const exec_aten::OptionalArrayRef& output_size, @@ -127,5 +133,17 @@ inline void compute_source_index_and_lambda( } } +// Ported from aten/src/ATen/native/UpSample.h +inline int64_t nearest_neighbor_compute_source_index( + const float scale, + int64_t dst_index, + int64_t input_size) { + // Index computation matching OpenCV INTER_NEAREST + // which is buggy and kept for BC + const int64_t src_index = + std::min(static_cast(floorf(dst_index * scale)), input_size - 1); + return src_index; +} + } // namespace executor } // namespace torch diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index ed203f9487..3221b8fe34 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -577,6 +577,11 @@ - arg_meta: null kernel_name: torch::executor::mean_dim_out +- op: mean.dtype_out + kernels: + - arg_meta: null + kernel_name: torch::executor::mean_dtype_out + - op: min.dim_min kernels: - arg_meta: null @@ -922,6 +927,11 @@ - arg_meta: null kernel_name: torch::executor::upsample_bilinear2d_vec_out +- op: upsample_nearest2d.vec_out + kernels: + - arg_meta: null + kernel_name: torch::executor::upsample_nearest2d_vec_out + - op: var.correction_out kernels: - arg_meta: null diff --git a/kernels/portable/test/TARGETS b/kernels/portable/test/TARGETS index 6255c37cbc..adf6636be4 100644 --- a/kernels/portable/test/TARGETS +++ b/kernels/portable/test/TARGETS @@ -21,6 +21,7 @@ runtime.cxx_library( deps = [ "//executorch/extension/aten_util:aten_bridge", "//executorch/kernels/portable/cpu:op_upsample_bilinear2d", + "//executorch/kernels/portable/cpu:op_upsample_nearest2d", "//executorch/runtime/core/exec_aten:lib", ], external_deps = [ @@ -40,3 +41,16 @@ python_unittest( "//caffe2:torch", ], ) + +python_unittest( + name = "op_upsample_nearest2d_test", + srcs = [ + "op_upsample_nearest2d_test.py", + ], + preload_deps = [ + ":aot_ops_test_lib", + ], + deps = [ + "//caffe2:torch", + ], +) diff --git a/kernels/portable/test/op_upsample_nearest2d_test.py b/kernels/portable/test/op_upsample_nearest2d_test.py new file mode 100644 index 0000000000..23a6bb54fb --- /dev/null +++ b/kernels/portable/test/op_upsample_nearest2d_test.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
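The new Python test below exercises the portable kernel through `torch.ops.et_test` and compares against eager `interpolate`. The index rule it depends on, `nearest_neighbor_compute_source_index` above, can be written out as a pure-Python reference (not part of the patch); `nearest_source_index` and `upsample_nearest2d_reference` are illustrative names, and the sketch only covers the output-size path (scale = input_size / output_size, align_corners false).

```python
# Reference for the legacy "nearest" source-index rule: floor(dst * scale)
# clamped to the input bounds, checked against eager PyTorch.
import math

import torch


def nearest_source_index(scale: float, dst_index: int, input_size: int) -> int:
    return min(int(math.floor(dst_index * scale)), input_size - 1)


def upsample_nearest2d_reference(x: torch.Tensor, out_h: int, out_w: int) -> torch.Tensor:
    n, c, in_h, in_w = x.shape
    scale_h, scale_w = in_h / out_h, in_w / out_w
    out = torch.empty(n, c, out_h, out_w, dtype=x.dtype)
    for h in range(out_h):
        for w in range(out_w):
            ih = nearest_source_index(scale_h, h, in_h)
            iw = nearest_source_index(scale_w, w, in_w)
            out[:, :, h, w] = x[:, :, ih, iw]
    return out


x = torch.randn(1, 3, 5, 7)
expected = torch.nn.functional.interpolate(x, size=(9, 11), mode="nearest")
assert torch.equal(upsample_nearest2d_reference(x, 9, 11), expected)
```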
+ +# pyre-unsafe + +import itertools +import unittest + +from typing import Optional, Sequence + +import torch + + +class UpsampleNearest2dTest(unittest.TestCase): + def run_upsample_test( + self, + inp: torch.Tensor, + output_size: Optional[Sequence[int]] = None, + scale_factors: Optional[Sequence[float]] = None, + atol=1e-7, + ) -> None: + aten_result = torch.nn.functional.interpolate( + inp, + size=output_size, + mode="nearest", + scale_factor=scale_factors, + ) + et_result = torch.zeros_like(aten_result) + et_result = torch.ops.et_test.upsample_nearest2d( + inp, + output_size=output_size, + scale_factors=scale_factors, + out=et_result, + ) + self.assertTrue( + torch.allclose(et_result, aten_result, atol=atol), + msg=f"ET: {et_result} \n ATen: {aten_result} \n Error: {et_result.to(torch.float) - aten_result.to(torch.float)}", + ) + + def test_upsample_nearest2d_aten_parity_f32(self): + N = [1, 2] + C = [1, 3] + H = [1, 3, 50, 1001] + W = [1, 2, 62, 1237] + OUT_H = [5, 21] + OUT_W = [7, 31] + + for n, c, h, w, out_h, out_w in itertools.product(N, C, H, W, OUT_H, OUT_W): + input = torch.randn(n, c, h, w) + self.run_upsample_test(input, output_size=(out_h, out_w)) + self.run_upsample_test(input, scale_factors=(out_h / h, out_w / w)) + + def test_upsample_nearest2d_aten_parity_u8(self): + N = [1, 2] + C = [1, 3] + H = [1, 3, 50, 1001] + W = [1, 2, 62, 1237] + OUT_H = [5, 21] + OUT_W = [7, 31] + + for n, c, h, w, out_h, out_w in itertools.product(N, C, H, W, OUT_H, OUT_W): + input = torch.randint(0, 255, (n, c, h, w), dtype=torch.uint8) + self.run_upsample_test(input, output_size=(out_h, out_w), atol=1) + self.run_upsample_test( + input, + scale_factors=(out_h / h, out_w / w), + atol=2, + ) diff --git a/kernels/portable/test/register_ops_aot_for_test.cpp b/kernels/portable/test/register_ops_aot_for_test.cpp index c8b63e736a..40d4bb79ed 100644 --- a/kernels/portable/test/register_ops_aot_for_test.cpp +++ b/kernels/portable/test/register_ops_aot_for_test.cpp @@ -47,6 +47,31 @@ Tensor& upsample_bilinear2d_vec_out_no_context( return ret; } + +Tensor& upsample_nearest2d_vec_out( + KernelRuntimeContext& ctx, + const Tensor& in, + const exec_aten::OptionalArrayRef output_size, + const exec_aten::OptionalArrayRef scale_factors, + Tensor& out); + +Tensor& upsample_nearest2d_vec_out_no_context( + const Tensor& in, + const exec_aten::OptionalArrayRef output_size, + const exec_aten::OptionalArrayRef scale_factors, + Tensor& out) { + KernelRuntimeContext ctx; + auto& ret = + upsample_nearest2d_vec_out(ctx, in, output_size, scale_factors, out); + + if (ctx.failure_state() != Error::Ok) { + throw std::runtime_error( + std::string("Kernel failed with error: ") + + std::to_string((int)ctx.failure_state())); + } + + return ret; +} // NOLINTEND(facebook-hte-ConstantArgumentPassByValue, // facebook-hte-ParameterMightThrowOnCopy) @@ -54,6 +79,9 @@ TORCH_LIBRARY(et_test, m) { m.def( "upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)", WRAP_TO_ATEN(upsample_bilinear2d_vec_out_no_context, 4)); + m.def( + "upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) 
out) -> Tensor(a!)", + WRAP_TO_ATEN(upsample_nearest2d_vec_out_no_context, 3)); } } // namespace native diff --git a/kernels/test/op_mean_test.cpp b/kernels/test/op_mean_test.cpp index 9821cb6b47..c5ba00b20e 100644 --- a/kernels/test/op_mean_test.cpp +++ b/kernels/test/op_mean_test.cpp @@ -9,7 +9,7 @@ #include // Declares the operator #include #include -#include +#include #include #include #include @@ -22,6 +22,7 @@ using exec_aten::ArrayRef; using exec_aten::optional; using exec_aten::ScalarType; using exec_aten::Tensor; +using executorch::runtime::Error; using torch::executor::testing::TensorFactory; class OpMeanOutTest : public OperatorTest { @@ -36,6 +37,13 @@ class OpMeanOutTest : public OperatorTest { context_, self, dim, keepdim, dtype, out); } + Tensor& op_mean_dtype_out( + const Tensor& self, + optional dtype, + Tensor& out) { + return torch::executor::aten::mean_outf(context_, self, dtype, out); + } + template void test_mean_dim_out_invalid_dimensions() { TensorFactory tf_in; @@ -466,3 +474,68 @@ TEST_F(OpMeanOutTest, DynamicShapeUnbound) { op_mean_out(x, ArrayRef{1}, false, ScalarType::Float, out); EXPECT_TENSOR_CLOSE(out, expected_result); } + +TEST_F(OpMeanOutTest, DTypeOutFloatValid) { + TensorFactory tf; + + Tensor x = tf.make( + {10, 10}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + Tensor expected_result = tf.make({}, {1.0}); + + Tensor out = tf.zeros({}); + Tensor ret = op_mean_dtype_out(x, ScalarType::Float, out); + EXPECT_TENSOR_CLOSE(out, expected_result); +} + +TEST_F(OpMeanOutTest, DTypeOutFloatToBoolInvalid) { + TensorFactory tf; + + Tensor x = tf.make( + {10, 10}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + Tensor expected_result = tf.make({}, {1.0}); + + Tensor out = tf.zeros({}); + + ET_EXPECT_KERNEL_FAILURE( + context_, op_mean_dtype_out(x, ScalarType::Bool, out)); +} + +TEST_F(OpMeanOutTest, DTypeOutFloatInfinity) { + TensorFactory tf; + + Tensor x = tf.make({2, 1}, {INFINITY, INFINITY}); + Tensor expected_result = tf.make({}, {INFINITY}); + + Tensor out = tf.zeros({}); + + Tensor ret = op_mean_dtype_out(x, ScalarType::Float, out); + EXPECT_TENSOR_CLOSE(out, expected_result); +} + +TEST_F(OpMeanOutTest, DTypeOutFloatNAN) { + TensorFactory tf; + + Tensor x = tf.make({2, 1}, {NAN, INFINITY}); + Tensor expected_result = tf.make({}, {NAN}); + + Tensor out = tf.zeros({}); + + Tensor ret = op_mean_dtype_out(x, ScalarType::Float, out); + EXPECT_TENSOR_CLOSE(out, expected_result); +} diff --git a/kernels/test/op_upsample_nearest2d_test.cpp b/kernels/test/op_upsample_nearest2d_test.cpp new file mode 100644 index 0000000000..1230168800 
--- /dev/null +++ b/kernels/test/op_upsample_nearest2d_test.cpp @@ -0,0 +1,408 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include // Declares the operator +#include +#include +#include +#include +#include + +#include + +#include + +using exec_aten::optional; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::testing::SupportedFeatures; +using torch::executor::testing::TensorFactory; + +#ifdef USE_ATEN_LIB +template +using OptionalArrayRef = std::optional>; +#else +using exec_aten::OptionalArrayRef; +#endif + +class OpUpsampleNearest2dTest : public OperatorTest { + protected: + Tensor& op_upsample_nearest2d_out( + const Tensor& in, + const OptionalArrayRef output_size, + const OptionalArrayRef scale_factors, + Tensor& out) { + return torch::executor::aten::upsample_nearest2d_outf( + context_, in, output_size, scale_factors, out); + } + + template + void test_upsample_nearest2d_dtype() { + TensorFactory tf; + + const auto input = tf.make({1, 1, 2, 2}, {1, 2, 3, 4}); + std::array output_size = {4, 4}; + auto out = tf.zeros({1, 1, 4, 4}); + + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + true, + {}, + out); + + const auto expected = + tf.make({1, 1, 4, 4}, {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4}); + + EXPECT_TENSOR_CLOSE(out, expected); + } +}; + +TEST_F(OpUpsampleNearest2dTest, SmokeTest) { + TensorFactory tf; + + const auto input = tf.make( + {1, 1, 2, 2}, + { + 0.1, + 0.2, + 1.1, + 1.2, + }); + std::array output_size = {4, 4}; + auto out = tf.zeros({1, 1, 4, 4}); + + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out); + + const auto expected = tf.make( + {1, 1, 4, 4}, + { + 0.1, + 0.1, + 0.2, + 0.2, + 0.1, + 0.1, + 0.2, + 0.2, + 1.1, + 1.1, + 1.2, + 1.2, + 1.1, + 1.1, + 1.2, + 1.2, + }); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleNearest2dTest, SmokeTestScale) { + TensorFactory tf; + + const auto input = tf.make( + {1, 1, 2, 2}, + { + 0.1, + 0.2, + 1.1, + 1.2, + }); + auto out = tf.zeros({1, 1, 4, 4}); + std::array scale_factors = {2, 2}; + + op_upsample_nearest2d_out( + input, + {}, + OptionalArrayRef({scale_factors.data(), scale_factors.size()}), + out); + + const auto expected = tf.make( + {1, 1, 4, 4}, + { + 0.1, + 0.1, + 0.2, + 0.2, + 0.1, + 0.1, + 0.2, + 0.2, + 1.1, + 1.1, + 1.2, + 1.2, + 1.1, + 1.1, + 1.2, + 1.2, + }); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleNearest2dTest, UpsampleSimpleFractional) { + TensorFactory tf; + + const auto input = tf.make( + {1, 1, 2, 2}, + { + 0.1, + 0.2, + 1.1, + 1.2, + }); + std::array output_size = {5, 9}; + auto out = tf.zeros({1, 1, 5, 9}); + + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out); + + const auto expected = tf.make( + {1, 1, 5, 9}, {0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, + 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, + 0.2, 0.2, 0.2, 1.1, 1.1, 1.1, 1.1, 1.1, 1.2, 1.2, 1.2, 1.2, + 1.1, 1.1, 1.1, 1.1, 1.1, 1.2, 1.2, 1.2, 1.2}); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleNearest2dTest, UpsampleSimpleFractionalScale) { + TensorFactory tf; + + const auto input = tf.make( + {1, 1, 2, 2}, + { + 0.1, + 0.2, + 1.1, + 1.2, + }); + auto out = tf.zeros({1, 1, 5, 
9}); + std::array scale_factors = {5 / 2.0, 9 / 2.0}; + + op_upsample_nearest2d_out( + input, + {}, + OptionalArrayRef({scale_factors.data(), scale_factors.size()}), + out); + + const auto expected = tf.make( + {1, 1, 5, 9}, {0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, + 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, + 0.2, 0.2, 0.2, 1.1, 1.1, 1.1, 1.1, 1.1, 1.2, 1.2, 1.2, 1.2, + 1.1, 1.1, 1.1, 1.1, 1.1, 1.2, 1.2, 1.2, 1.2}); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleNearest2dTest, MultiBatchAndChannel) { + TensorFactory tf; + const auto input = tf.make( + {2, 2, 2, 2}, + { + 0.1, + 0.2, + 1.1, + 1.2, + 2.1, + 2.2, + 3.1, + 3.2, + 4.1, + 4.2, + 5.1, + 5.2, + 6.1, + 6.2, + 7.1, + 7.2, + }); + std::array output_size = {4, 4}; + auto out = tf.zeros({2, 2, 4, 4}); + + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out); + + const auto expected = tf.make( + {2, 2, 4, 4}, + { + 0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2, 1.1, 1.1, 1.2, 1.2, 1.1, + 1.1, 1.2, 1.2, 2.1, 2.1, 2.2, 2.2, 2.1, 2.1, 2.2, 2.2, 3.1, 3.1, + 3.2, 3.2, 3.1, 3.1, 3.2, 3.2, 4.1, 4.1, 4.2, 4.2, 4.1, 4.1, 4.2, + 4.2, 5.1, 5.1, 5.2, 5.2, 5.1, 5.1, 5.2, 5.2, 6.1, 6.1, 6.2, 6.2, + 6.1, 6.1, 6.2, 6.2, 7.1, 7.1, 7.2, 7.2, 7.1, 7.1, 7.2, 7.2, + }); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleNearest2dTest, DType) { +#define TEST_ENTRY(ctype, dtype) \ + test_upsample_nearest2d_dtype(); \ + ET_FORALL_REAL_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +} + +TEST_F(OpUpsampleNearest2dTest, MismatchedOutputSizeDies) { + if (SupportedFeatures::get()->output_resize) { + GTEST_SKIP() + << "The current kernel supports implicitly resizing output tensor"; + } + TensorFactory tf; + + const auto input = tf.ones({1, 1, 1, 2}); + std::array output_size = {1, 4}; + auto out = tf.zeros({1, 1, 1, 5}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out)); +} + +TEST_F(OpUpsampleNearest2dTest, InvalidInputRankDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 2}); + std::array output_size = {1, 4}; + auto out = tf.zeros({1, 1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out)); +} + +TEST_F(OpUpsampleNearest2dTest, InvalidOutputRankDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 2}); + std::array output_size = {1, 4}; + auto out = tf.zeros({1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out)); +} + +TEST_F(OpUpsampleNearest2dTest, MissingOutputSizeOrScaleDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 2}); + auto out = tf.zeros({1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, op_upsample_nearest2d_out(input, {}, {}, out)); +} + +TEST_F(OpUpsampleNearest2dTest, BothOutputSizeAndScaleDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 1, 2}); + std::array output_size = {1, 4}; + std::array scale_factors = {1, 2}; + auto out = tf.zeros({1, 1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + OptionalArrayRef( + {scale_factors.data(), scale_factors.size()}), + out)); +} + +TEST_F(OpUpsampleNearest2dTest, MismatchedDTypeDies) { + TensorFactory tf; + 
TensorFactory tf2; + + const auto input = tf.ones({1, 1, 2}); + std::array output_size = {1, 4}; + auto out = tf2.zeros({1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out)); +} + +TEST_F(OpUpsampleNearest2dTest, ComputedOutputSizeMatchesExpected) { + // Computed output sizes (from input size * scales) must match PyTorch + // eager-mode - multiplied as double and cast (truncated) to an integral type. + // See + // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/UpSample.cpp + TensorFactory tf; + + // Test case format: { in_h, in_w, scale_h, scale_w, out_h, out_w } + std::vector> + test_cases = { + {10, 10, 9.99999, 9.55, 99, 95}, + {10, 10, 9.99999999, 0.1, 99, 1}, + }; + + for (const auto& test_case : test_cases) { + const auto [in_h, in_w, scale_h, scale_w, out_h, out_w] = test_case; + + const auto input = tf.ones({1, 1, in_h, in_w}); + std::array scale_factors = {scale_h, scale_w}; + auto out = tf.zeros({1, 1, out_h, out_w}); + + op_upsample_nearest2d_out( + input, + {}, + OptionalArrayRef({scale_factors.data(), scale_factors.size()}), + out); + + const auto expected = tf.ones({1, 1, out_h, out_w}); + + EXPECT_TENSOR_EQ(out, expected); + } +} + +TEST_F(OpUpsampleNearest2dTest, ZeroComputedOutputSizeDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 1, 2}); + std::array scale_factors = {1, 0.25}; + auto out = tf.zeros({1, 1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_nearest2d_out( + input, + {}, + OptionalArrayRef( + {scale_factors.data(), scale_factors.size()}), + out)); +} diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index d4f5769dbe..db14173050 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -347,6 +347,7 @@ def define_common_targets(): _common_op_test("op_unbind_copy_test", ["aten", "portable"]) _common_op_test("op_unsqueeze_copy_test", ["aten", "portable"]) _common_op_test("op_upsample_bilinear2d_test", ["aten", "portable"]) + _common_op_test("op_upsample_nearest2d_test", ["aten", "portable"]) _common_op_test("op_var_test", ["aten", "portable"]) _common_op_test("op_view_copy_test", ["aten", "portable"]) _common_op_test("op_where_test", ["aten", "portable"]) diff --git a/pyproject.toml b/pyproject.toml index ec08f58185..8b7ee5304b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ # 3 - Alpha # 4 - Beta # 5 - Production/Stable - "Development Status :: 3 - Alpha", + "Development Status :: 4 - Beta", "Intended Audience :: Developers", "Intended Audience :: Education", "Intended Audience :: Science/Research", diff --git a/runtime/backend/backend_init_context.h b/runtime/backend/backend_init_context.h index 051266662c..49bf4adbf9 100644 --- a/runtime/backend/backend_init_context.h +++ b/runtime/backend/backend_init_context.h @@ -20,6 +20,7 @@ class BackendInitContext final { public: explicit BackendInitContext( MemoryAllocator* runtime_allocator, + EventTracer* event_tracer = nullptr, const char* method_name = nullptr) : runtime_allocator_(runtime_allocator), method_name_(method_name) {} @@ -31,6 +32,15 @@ class BackendInitContext final { return runtime_allocator_; } + /** + * Returns a pointer (null if not installed) to an instance of EventTracer to + * do profiling/debugging logging inside the delegate backend. Users will need + * access to this pointer to use any of the event tracer APIs. 
+ */ + EventTracer* event_tracer() { + return event_tracer_; + } + /** Get the loaded method name from ExecuTorch runtime. Usually it's * "forward", however, if there are multiple methods in the .pte file, it can * be different. One example is that we may have prefill and decode methods in @@ -44,6 +54,7 @@ class BackendInitContext final { private: MemoryAllocator* runtime_allocator_ = nullptr; + EventTracer* event_tracer_ = nullptr; const char* method_name_ = nullptr; }; diff --git a/runtime/core/exec_aten/util/dim_order_util.h b/runtime/core/exec_aten/util/dim_order_util.h index 4c5858fb5e..a1481ee387 100644 --- a/runtime/core/exec_aten/util/dim_order_util.h +++ b/runtime/core/exec_aten/util/dim_order_util.h @@ -9,6 +9,8 @@ #pragma once #include +#include +#include #include #include @@ -257,6 +259,36 @@ ET_NODISCARD inline Error stride_to_dim_order( } return Error::Ok; } + +/** + * Print a string representation of an ArrayRef of tensor sizes into a + * user-provided string buffer. If the user buffer is too small, the string + * will be truncated. The output is of the format (1,2,3,4). + * + * Note that we cannot use ArrayRef here due to a circular dependency (see + * above comments). + */ +template +inline void sizes_to_string( + char* output, + size_t output_size, + SizesType* sizes, + size_t rank) { + auto remaining_size = output_size; + for (auto i = 0; remaining_size > 0 && i < rank; i++) { + snprintf( + output, + remaining_size, + "%s%zd", + i == 0 ? "(" : ",", + static_cast(sizes[i])); + auto len = strlen(output); + output += len; + remaining_size -= len; + } + snprintf(output, remaining_size, ")"); +} + } // namespace runtime } // namespace executorch diff --git a/runtime/core/portable_type/half.h b/runtime/core/portable_type/half.h index fa40a80782..267d17bdba 100644 --- a/runtime/core/portable_type/half.h +++ b/runtime/core/portable_type/half.h @@ -15,7 +15,7 @@ #include #if defined(__GNUC__) || defined(__clang__) -#if defined(__aarch64__) +#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) #ifndef __ARM_V8_ONLY__ #define NATIVE_FP16 1 #endif // __ARM_V8_ONLY__ diff --git a/runtime/core/portable_type/tensor_impl.cpp b/runtime/core/portable_type/tensor_impl.cpp index 689a37cb5d..f029777874 100644 --- a/runtime/core/portable_type/tensor_impl.cpp +++ b/runtime/core/portable_type/tensor_impl.cpp @@ -90,24 +90,46 @@ Error TensorImpl::internal_resize_contiguous(ArrayRef new_sizes) { if (dim_ == 0) { return Error::Ok; } + switch (shape_dynamism_) { case TensorShapeDynamism::STATIC: - ET_CHECK_OR_RETURN_ERROR( - std::equal(sizes_, sizes_ + dim_, new_sizes.begin()), - NotSupported, - "Attempted to resize a static tensor"); + if (!std::equal(sizes_, sizes_ + dim_, new_sizes.begin())) { +#ifdef ET_LOG_ENABLED + std::array old_sizes_str, new_sizes_str; + + executorch::runtime::sizes_to_string( + old_sizes_str.data(), + old_sizes_str.size(), + sizes().data(), + sizes().size()); + executorch::runtime::sizes_to_string( + new_sizes_str.data(), + new_sizes_str.size(), + new_sizes.data(), + new_sizes.size()); +#endif + + ET_CHECK_OR_RETURN_ERROR( + false, + NotSupported, + "Attempted to resize a static tensor. Expected shape %s, but received %s.", + old_sizes_str.data(), + new_sizes_str.data()) + } + break; case TensorShapeDynamism::DYNAMIC_BOUND: // TODO(T175194371): Unbounded dynamic tensor resizing is not yet // supported: treat them as upper-bounded. 
case TensorShapeDynamism::DYNAMIC_UNBOUND: { const auto new_numel = compute_numel(new_sizes.data(), dim_); + ET_CHECK_OR_RETURN_ERROR( new_numel <= numel_bound_, NotSupported, - "Attempted to resize a bounded tensor with capacity of %zu elements to %zu elements.", - new_numel, - numel_bound_); + "Attempted to resize a bounded tensor with a maximum capacity of %zu elements to %zu elements.", + numel_bound_, + new_numel); if (strides_ && dim_order_) { auto error = diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index eb0a8c8bb6..db9417dd88 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -628,6 +628,7 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) { const auto& delegate = *delegates->Get(i); BackendInitContext backend_init_context( method_allocator, + /*event_tracer=*/event_tracer_, /*method_name=*/serialization_plan_->name()->c_str()); Error err = BackendDelegate::Init( delegate, program_, backend_init_context, &delegates_[i]); @@ -1061,7 +1062,7 @@ Error Method::execute_instruction() { // We know that instr_args_as_KernelCall is non-null because it was // checked at init time. auto op_index = instruction->instr_args_as_KernelCall()->op_index(); - auto op = serialization_plan_->operators()->Get(op_index); + ET_UNUSED auto op = serialization_plan_->operators()->Get(op_index); ET_LOG( Error, "KernelCall failed at instruction %zu:%zu in operator %s.%s: 0x%x", diff --git a/runtime/executor/method_meta.cpp b/runtime/executor/method_meta.cpp index f43398d0ab..5be486b4d8 100644 --- a/runtime/executor/method_meta.cpp +++ b/runtime/executor/method_meta.cpp @@ -116,6 +116,14 @@ Result MethodMeta::input_tag(size_t index) const { index, num_inputs); auto input_index = s_plan_->inputs()->Get(index); + size_t num_values = s_plan_->values()->size(); + ET_CHECK_OR_RETURN_ERROR( + input_index >= 0 && input_index < num_values, + InvalidProgram, + "internal value index %d out of range [0,%zu) for input %zu", + input_index, + num_values, + index); auto serialization_value = s_plan_->values()->Get(input_index); return get_tag(serialization_value, index); } @@ -132,6 +140,7 @@ Result MethodMeta::input_tensor_meta(size_t index) const { (size_t)tag.get(), index); auto input_index = s_plan_->inputs()->Get(index); + // input_index was already validated by input_tag(). auto tensor_value = s_plan_->values()->Get(input_index)->val_as_Tensor(); return TensorInfo( Span( @@ -156,8 +165,16 @@ Result MethodMeta::output_tag(size_t index) const { "index %zu out of range. num_outputs: %zu", index, num_outputs); - auto input_index = s_plan_->outputs()->Get(index); - auto serialization_value = s_plan_->values()->Get(input_index); + auto output_index = s_plan_->outputs()->Get(index); + size_t num_values = s_plan_->values()->size(); + ET_CHECK_OR_RETURN_ERROR( + output_index >= 0 && output_index < num_values, + InvalidProgram, + "internal value index %d out of range [0,%zu) for output %zu", + output_index, + num_values, + index); + auto serialization_value = s_plan_->values()->Get(output_index); return get_tag(serialization_value, index); } @@ -173,6 +190,7 @@ Result MethodMeta::output_tensor_meta(size_t index) const { (size_t)tag.get(), index); auto output_index = s_plan_->outputs()->Get(index); + // output_index was already validated by output_tag(). 
   auto tensor_value = s_plan_->values()->Get(output_index)->val_as_Tensor();
   return TensorInfo(
diff --git a/runtime/kernel/operator_registry.cpp b/runtime/kernel/operator_registry.cpp
index 78aa0a5173..94d230e8f8 100644
--- a/runtime/kernel/operator_registry.cpp
+++ b/runtime/kernel/operator_registry.cpp
@@ -74,7 +74,8 @@ Error register_kernels_internal(const Span kernels) {
     return Error::Internal;
   }
   // for debugging purpose
-  const char* lib_name = et_pal_get_shared_library_name(kernels.data());
+  ET_UNUSED const char* lib_name =
+      et_pal_get_shared_library_name(kernels.data());
 
   for (const auto& kernel : kernels) {
     // Linear search. This is fine if the number of kernels is small.
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index 9f2a383402..f5ddae06b6 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -1235,6 +1235,12 @@ ATEN_OPS = (
             "//executorch/kernels/portable/cpu/util:upsample_util",
         ],
     ),
+    op_target(
+        name = "op_upsample_nearest2d",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:upsample_util",
+        ],
+    ),
     op_target(
         name = "op_var",
         deps = [
diff --git a/test/size_test.cpp b/test/size_test.cpp
index 88b605c3bf..1fab1e914e 100644
--- a/test/size_test.cpp
+++ b/test/size_test.cpp
@@ -94,7 +94,7 @@ int main(int argc, char** argv) {
   // It assumes the outputs are all tensors.
   for (size_t i = 0; i < method->outputs_size(); i++) {
     auto output_tensor = output_list[i].toTensor();
-    auto data_output = output_tensor.const_data_ptr();
+    [[maybe_unused]] auto data_output = output_tensor.const_data_ptr();
     for (size_t j = 0; j < output_list[i].toTensor().numel(); ++j) {
       ET_LOG(Info, "%f", data_output[j]);
     }
diff --git a/third-party/prelude b/third-party/prelude
index 4e9e6d50b8..851d3f09c4 160000
--- a/third-party/prelude
+++ b/third-party/prelude
@@ -1 +1 @@
-Subproject commit 4e9e6d50b8b461564a7e351ff60b87fe59d7e53b
+Subproject commit 851d3f09c452937fc5adef27e2c50f7f304f1646
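The `ComputedOutputSizeMatchesExpected` test earlier in the diff relies on output sizes being computed as input_size * scale in double precision and then truncated, matching PyTorch eager mode. A small standalone sketch (not part of the patch) of that rule, assuming current eager `interpolate` behavior; `compute_output_size` is an illustrative helper, not an ExecuTorch API:

```python
import torch


def compute_output_size(input_size: int, scale: float) -> int:
    # Multiply in double precision and truncate toward zero, mirroring the
    # behavior described in the kernel test's comment.
    return int(input_size * scale)


x = torch.ones(1, 1, 10, 10)
for scale_h, scale_w in [(9.99999, 9.55), (9.99999999, 0.1)]:
    out = torch.nn.functional.interpolate(
        x, scale_factor=(scale_h, scale_w), mode="nearest"
    )
    expected = (compute_output_size(10, scale_h), compute_output_size(10, scale_w))
    assert out.shape[-2:] == expected
    print(out.shape)  # torch.Size([1, 1, 99, 95]) then torch.Size([1, 1, 99, 1])
```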