diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt index ec597b9c72..335a8930dc 100644 --- a/.ci/docker/ci_commit_pins/pytorch.txt +++ b/.ci/docker/ci_commit_pins/pytorch.txt @@ -1 +1 @@ -f5b99976adcbb01fd71bd0a39ea15bdac6c9e48a +6ca9ae4f8693639c395544327f7e362441a58c79 diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index 6900013564..7def3cb318 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -44,13 +44,15 @@ install_pip_dependencies() { } fix_conda_ubuntu_libstdcxx() { + cat /etc/issue # WARNING: This is a HACK from PyTorch core to be able to build PyTorch on 22.04. - # The issue still exists with the latest conda 23.10.0-1 at the time of writing - # (2023/11/16). + # Specifically, ubuntu-20+ all comes lib libstdc++ newer than 3.30+, but anaconda + # is stuck with 3.29. So, remove libstdc++6.so.3.29 as installed by + # https://anaconda.org/anaconda/libstdcxx-ng/files?version=11.2.0 # # PyTorch sev: https://github.com/pytorch/pytorch/issues/105248 # Ref: https://github.com/pytorch/pytorch/blob/main/.ci/docker/common/install_conda.sh - if grep -e "[12][82].04.[623]" /etc/issue >/dev/null; then + if grep -e "2[02].04." /etc/issue >/dev/null; then rm "/opt/conda/envs/py_${PYTHON_VERSION}/lib/libstdc++.so.6" fi } diff --git a/.ci/scripts/utils.sh b/.ci/scripts/utils.sh index 04d3307220..5ba8c57cdc 100644 --- a/.ci/scripts/utils.sh +++ b/.ci/scripts/utils.sh @@ -19,7 +19,7 @@ install_executorch() { which pip # Install executorch, this assumes that Executorch is checked out in the # current directory - pip install . --no-build-isolation + pip install . --no-build-isolation -v # Just print out the list of packages for debugging pip list } diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index 77725298a2..d93e9e9cef 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -2,3 +2,5 @@ ciflow_push_tags: - ciflow/nightly - ciflow/trunk +- ciflow/binaries +- ciflow/binaries/all diff --git a/.github/workflows/_unittest.yml b/.github/workflows/_unittest.yml index b03136e380..a149fde3aa 100644 --- a/.github/workflows/_unittest.yml +++ b/.github/workflows/_unittest.yml @@ -33,8 +33,11 @@ jobs: conda activate "${CONDA_ENV}" BUILD_TOOL=${{ matrix.build-tool }} + # Setup MacOS dependencies as there is no Docker support on MacOS atm - PYTHON_EXECUTABLE=python EXECUTORCH_BUILD_PYBIND=ON bash .ci/scripts/setup-linux.sh "${BUILD_TOOL}" + PYTHON_EXECUTABLE=python \ + EXECUTORCH_BUILD_PYBIND=ON \ + .ci/scripts/setup-linux.sh "${BUILD_TOOL}" # Run pytest with coverage pytest -n auto --cov=./ --cov-report=xml @@ -59,8 +62,13 @@ jobs: BUILD_TOOL=${{ matrix.build-tool }} bash .ci/scripts/setup-conda.sh + # Setup MacOS dependencies as there is no Docker support on MacOS atm - PYTHON_EXECUTABLE=python ${CONDA_RUN} EXECUTORCH_BUILD_PYBIND=ON bash .ci/scripts/setup-macos.sh "${BUILD_TOOL}" + PYTHON_EXECUTABLE=python \ + EXECUTORCH_BUILD_PYBIND=ON \ + CMAKE_ARGS="-DEXECUTORCH_BUILD_COREML=ON -DEXECUTORCH_BUILD_MPS=ON -DEXECUTORCH_BUILD_XNNPACK=ON" \ + ${CONDA_RUN} --no-capture-output \ + .ci/scripts/setup-macos.sh "${BUILD_TOOL}" # Run pytest with coverage ${CONDA_RUN} pytest -n auto --cov=./ --cov-report=xml diff --git a/.github/workflows/app-build.yml b/.github/workflows/app-build.yml deleted file mode 100644 index 6a5b02b4fc..0000000000 --- a/.github/workflows/app-build.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: Build ExecuTorch demo apps - -on: - push: - branches: - - main - - release/* - pull_request: - paths: - - .ci/docker/** - - .github/workflows/app-build.yml - - install_requirements.sh - - backends/apple/** - - build/build_apple_frameworks.sh - - build/test_ios_ci.sh - - examples/demo-apps/** - - extension/module/** - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} - cancel-in-progress: true - -jobs: - test-demo-ios: - name: test-demo-ios - uses: pytorch/test-infra/.github/workflows/macos_job.yml@main - with: - runner: macos-latest-xlarge - python-version: '3.11' - submodules: 'true' - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - timeout: 90 - script: | - WORKSPACE=$(pwd) - pushd "${WORKSPACE}/pytorch/executorch" - BUILD_TOOL=cmake - - bash .ci/scripts/setup-conda.sh - # Setup MacOS dependencies as there is no Docker support on MacOS atm - GITHUB_RUNNER=1 PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/setup-macos.sh "${BUILD_TOOL}" - # Build and test iOS Demo App - PYTHON_EXECUTABLE=python ${CONDA_RUN} bash build/test_ios_ci.sh - popd diff --git a/.github/workflows/apple.yml b/.github/workflows/apple.yml new file mode 100644 index 0000000000..0a4a06aa70 --- /dev/null +++ b/.github/workflows/apple.yml @@ -0,0 +1,92 @@ +name: Apple + +on: + push: + branches: + - main + - release/* + pull_request: + paths: + - .ci/docker/** + - .github/workflows/app-build.yml + - install_requirements.sh + - backends/apple/** + - build/build_apple_frameworks.sh + - build/create_frameworks.sh + - build/test_ios_ci.sh + - examples/demo-apps/** + - extension/apple/** + - extension/module/** + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +jobs: + test-demo-ios: + name: test-demo-ios + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + runner: macos-latest-xlarge + python-version: '3.11' + submodules: 'true' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + WORKSPACE=$(pwd) + pushd "${WORKSPACE}/pytorch/executorch" + BUILD_TOOL=cmake + + .ci/scripts/setup-conda.sh + + # Setup MacOS dependencies as there is no Docker support on MacOS atm + GITHUB_RUNNER=1 PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ + .ci/scripts/setup-macos.sh "${BUILD_TOOL}" + + # Build and test iOS Demo App + PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ + build/test_ios_ci.sh + + popd + + build-frameworks-ios: + name: build-frameworks-ios + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + runner: macos-latest-xlarge + python-version: '3.11' + submodules: 'true' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + upload-artifact: executorch.zip + timeout: 90 + script: | + WORKSPACE=$(pwd) + pushd "${WORKSPACE}/pytorch/executorch" + BUILD_TOOL=cmake + VERSION="0.1.0" + OUTPUT="executorch-${VERSION}" + + .ci/scripts/setup-conda.sh + + # Setup MacOS dependencies as there is no Docker support on MacOS atm + GITHUB_RUNNER=1 PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ + .ci/scripts/setup-macos.sh "${BUILD_TOOL}" + + # Install CoreML Backend Requirements + PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ + backends/apple/coreml/scripts/install_requirements.sh + + # Install MPS Backend Requirements + PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ + backends/apple/mps/install_requirements.sh + + # Build iOS Frameworks + PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ + build/build_apple_frameworks.sh --output="${OUTPUT}" --coreml --mps --portable --xnnpack + + # Bundle iOS Frameworks + cp LICENSE "${OUTPUT}" + zip -r "${RUNNER_TEMP}/artifacts/${OUTPUT}.zip" "${OUTPUT}" + + popd diff --git a/.github/workflows/build-wheels-linux.yml b/.github/workflows/build-wheels-linux.yml new file mode 100644 index 0000000000..a2f86b219f --- /dev/null +++ b/.github/workflows/build-wheels-linux.yml @@ -0,0 +1,57 @@ +# From https://github.com/pytorch/test-infra/wiki/Using-Nova-Reusable-Build-Workflows +name: Build Linux Wheels + +on: + pull_request: + paths: + - build/packaging/** + - .github/workflows/build-wheels-linux.yml + push: + branches: + - nightly + - release/* + tags: + # NOTE: Binary build pipelines should only get triggered on release candidate builds + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + - ciflow/binaries/* + workflow_dispatch: + +jobs: + generate-matrix: + uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main + with: + package-type: wheel + os: linux + test-infra-repository: pytorch/test-infra + test-infra-ref: main + with-cuda: disabled + with-rocm: disabled + + build: + needs: generate-matrix + permissions: + id-token: write + contents: read + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/executorch + pre-script: build/packaging/pre_build_script.sh + post-script: build/packaging/post_build_script.sh + smoke-test-script: build/packaging/smoke_test.py + package-name: executorch + name: ${{ matrix.repository }} + uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main + with: + repository: ${{ matrix.repository }} + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: ${{ matrix.pre-script }} + post-script: ${{ matrix.post-script }} + package-name: ${{ matrix.package-name }} + smoke-test-script: ${{ matrix.smoke-test-script }} + trigger-event: ${{ github.event_name }} diff --git a/.github/workflows/build-wheels-m1.yml b/.github/workflows/build-wheels-m1.yml new file mode 100644 index 0000000000..dbc74433ff --- /dev/null +++ b/.github/workflows/build-wheels-m1.yml @@ -0,0 +1,58 @@ +# From https://github.com/pytorch/test-infra/wiki/Using-Nova-Reusable-Build-Workflows +name: Build M1 Wheels + +on: + pull_request: + paths: + - build/packaging/** + - .github/workflows/build-wheels-m1.yml + push: + branches: + - nightly + - release/* + tags: + # NOTE: Binary build pipelines should only get triggered on release candidate builds + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + - ciflow/binaries/* + workflow_dispatch: + +jobs: + generate-matrix: + uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main + with: + package-type: wheel + os: macos-arm64 + test-infra-repository: pytorch/test-infra + test-infra-ref: main + with-cuda: disabled + with-rocm: disabled + + build: + needs: generate-matrix + permissions: + id-token: write + contents: read + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/executorch + pre-script: build/packaging/pre_build_script.sh + post-script: build/packaging/post_build_script.sh + smoke-test-script: build/packaging/smoke_test.py + package-name: executorch + name: ${{ matrix.repository }} + uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@main + with: + repository: ${{ matrix.repository }} + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: ${{ matrix.pre-script }} + post-script: ${{ matrix.post-script }} + package-name: ${{ matrix.package-name }} + runner-type: macos-m1-stable + smoke-test-script: ${{ matrix.smoke-test-script }} + trigger-event: ${{ github.event_name }} diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 7e4dba0b84..7f70f0cfef 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -206,7 +206,9 @@ jobs: # build module for executorch.extension.pybindings.portable_lib BUILD_TOOL=${{ matrix.build-tool }} - PYTHON_EXECUTABLE=python EXECUTORCH_BUILD_PYBIND=ON bash .ci/scripts/setup-linux.sh "${BUILD_TOOL}" + PYTHON_EXECUTABLE=python \ + EXECUTORCH_BUILD_PYBIND=ON \ + bash .ci/scripts/setup-linux.sh "${BUILD_TOOL}" # see if we can import the module successfully python -c "from executorch.extension.pybindings import portable_lib; print('success!')" diff --git a/.gitmodules b/.gitmodules index ce839f0b10..e21abf3bae 100644 --- a/.gitmodules +++ b/.gitmodules @@ -55,3 +55,7 @@ [submodule "backends/vulkan/third-party/Vulkan-Headers"] path = backends/vulkan/third-party/Vulkan-Headers url = https://github.com/KhronosGroup/Vulkan-Headers +[submodule "third-party/lm-evaluation-harness"] + path = third-party/lm-evaluation-harness + url = https://github.com/EleutherAI/lm-evaluation-harness + branch = v0.4.1 diff --git a/CMakeLists.txt b/CMakeLists.txt index 185214c275..778b7886cc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -98,7 +98,7 @@ endif() # data into sections so they can be properly gc'd. -s: strip symbol. # -fno-exceptions -fno-rtti: disables exceptions and runtime type. set(CMAKE_CXX_FLAGS_RELEASE - "-O2 -ffunction-sections -fdata-sections -fno-exceptions -fno-rtti") + "-ffunction-sections -fdata-sections -fno-exceptions -fno-rtti") if(NOT APPLE) set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -s") endif() @@ -206,6 +206,13 @@ else() set(CMAKE_TOOLCHAIN_IOS OFF) endif() +# Detect if an Android toolchain is set. +if(CMAKE_TOOLCHAIN_FILE MATCHES ".*android\.toolchain\.cmake$") + set(CMAKE_TOOLCHAIN_ANDROID ON) +else() + set(CMAKE_TOOLCHAIN_ANDROID OFF) +endif() + # EXECUTORCH_BUILD_HOST_TARGETS: Option to control the building of host-only # tools like `flatc`, along with example executables like `executor_runner` and # libraries that it uses, like `gflags`. Disabling this can be helpful when @@ -328,11 +335,6 @@ if(EXECUTORCH_BUILD_GTESTS) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/googletest) endif() -option(EXECUTORCH_BUILD_ANDROID_JNI "Build Android JNI" OFF) -if(EXECUTORCH_BUILD_ANDROID_JNI) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/android) -endif() - if(EXECUTORCH_BUILD_SDK) set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON @@ -340,6 +342,10 @@ if(EXECUTORCH_BUILD_SDK) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sdk) endif() +if(EXECUTORCH_BUILD_EXTENSION_APPLE) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/apple) +endif() + if(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/data_loader) endif() @@ -361,6 +367,12 @@ if(EXECUTORCH_BUILD_VULKAN) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/vulkan) endif() +option(EXECUTORCH_BUILD_ANDROID_JNI "Build Android JNI" OFF) +if(EXECUTORCH_BUILD_ANDROID_JNI) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/examples/models/llama2/runner) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/android) +endif() + option(EXECUTORCH_BUILD_QNN "Build the backends/qualcomm directory" OFF) if(EXECUTORCH_BUILD_QNN) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/qualcomm) diff --git a/backends/apple/coreml/CMakeLists.txt b/backends/apple/coreml/CMakeLists.txt index b9cc309ff3..167a16f1ba 100644 --- a/backends/apple/coreml/CMakeLists.txt +++ b/backends/apple/coreml/CMakeLists.txt @@ -4,6 +4,10 @@ cmake_minimum_required(VERSION 3.19) project(executorch_coreml_backend) +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() + # Source root directory for executorch. if(NOT EXECUTORCH_ROOT) set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) @@ -117,8 +121,3 @@ set( TARGET coremldelegate APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-receiver-expr" ) - -set_property( - TARGET coremldelegate - PROPERTY CXX_STANDARD 17 -) diff --git a/backends/apple/coreml/quantizer/coreml_quantizer.py b/backends/apple/coreml/quantizer/coreml_quantizer.py new file mode 100644 index 0000000000..ad596f7dfc --- /dev/null +++ b/backends/apple/coreml/quantizer/coreml_quantizer.py @@ -0,0 +1,7 @@ +# Copyright © 2024 Apple Inc. All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. + +from coremltools.optimize.torch.quantization._coreml_quantizer import ( # noqa: FLAKE8 F401 + CoreMLQuantizer, +) diff --git a/backends/apple/coreml/runtime/delegate/asset.h b/backends/apple/coreml/runtime/delegate/asset.h index c838cf8d6d..e6cd371cd7 100644 --- a/backends/apple/coreml/runtime/delegate/asset.h +++ b/backends/apple/coreml/runtime/delegate/asset.h @@ -8,6 +8,7 @@ #import #import +#import #import #import diff --git a/backends/apple/coreml/scripts/generate_test_models.sh b/backends/apple/coreml/scripts/generate_test_models.sh index be71b56e31..7beca63726 100755 --- a/backends/apple/coreml/scripts/generate_test_models.sh +++ b/backends/apple/coreml/scripts/generate_test_models.sh @@ -23,8 +23,8 @@ cd "$EXECUTORCH_ROOT_PATH" MODELS=("add" "mul" "mv3") for MODEL in "${MODELS[@]}" do - # TODO: Don't use the script in examples directory. - python3 -m examples.apple.coreml.scripts.export_and_delegate --model_name "$MODEL" --save_processed_bytes + # TODO: Don't use the script in examples directory. + python3 -m examples.apple.coreml.scripts.export --model_name "$MODEL" --save_processed_bytes mv -f "$MODEL""_coreml_all.pte" "$COREML_DIR_PATH/runtime/test/models" mv -f "$MODEL""_coreml_all.bin" "$COREML_DIR_PATH/runtime/test/models" done diff --git a/backends/apple/coreml/setup.md b/backends/apple/coreml/setup.md index 86afc50f00..c01f6e2d23 100644 --- a/backends/apple/coreml/setup.md +++ b/backends/apple/coreml/setup.md @@ -11,18 +11,18 @@ This is a tutorial for setting up the Core ML backend. ``` cd executorch -./backends/apple/coreml/scripts/install_requirements.sh +./backends/apple/coreml/scripts/install_requirements.sh -``` +``` -3. Run the example script to validate that the **Core ML** backend is set up correctly. +3. Run the example script to validate that the **Core ML** backend is set up correctly. ``` cd executorch # Saves add_coreml_all.pte in the current directory if successful. -python3 -m examples.apple.coreml.scripts.export_and_delegate --model_name add +python3 -m examples.apple.coreml.scripts.export --model_name add ``` @@ -66,6 +66,6 @@ coreml_backend.xcframework - Accelerate.framework - CoreML.framework - libsqlite3.tbd -``` +``` -6. The target could now run a **Core ML** delegated **Program**. +6. The target could now run a **Core ML** delegated **Program**. diff --git a/backends/apple/coreml/test/test_coreml_quantizer.py b/backends/apple/coreml/test/test_coreml_quantizer.py new file mode 100644 index 0000000000..67eee3593f --- /dev/null +++ b/backends/apple/coreml/test/test_coreml_quantizer.py @@ -0,0 +1,112 @@ +# Copyright © 2024 Apple Inc. All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. + +from typing import Tuple + +import numpy as np +import pytest + +import torch + +from coremltools.optimize.torch.quantization.quantization_config import ( + LinearQuantizerConfig, + QuantizationScheme, +) + +from executorch.backends.apple.coreml.quantizer.coreml_quantizer import CoreMLQuantizer +from torch._export import capture_pre_autograd_graph +from torch.ao.quantization.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) + + +class TestCoreMLQuantizer: + @staticmethod + def quantize_and_compare( + model, + example_inputs: Tuple[torch.Tensor], + quantization_type: str, + ) -> None: + assert quantization_type in {"PTQ", "QAT"} + + pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_inputs) + + quantization_config = LinearQuantizerConfig.from_dict( + { + "global_config": { + "quantization_scheme": QuantizationScheme.symmetric, + "milestones": [0, 0, 10, 10], + "activation_dtype": torch.quint8, + "weight_dtype": torch.qint8, + "weight_per_channel": True, + } + } + ) + quantizer = CoreMLQuantizer(quantization_config) + + if quantization_type == "PTQ": + prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer) + elif quantization_type == "QAT": + prepared_graph = prepare_qat_pt2e(pre_autograd_aten_dialect, quantizer) + + prepared_graph(*example_inputs) + converted_graph = convert_pt2e(prepared_graph) + + model_output = model(*example_inputs).detach().numpy() + quantized_output = converted_graph(*example_inputs).detach().numpy() + np.testing.assert_allclose(quantized_output, model_output, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("quantization_type", ("PTQ", "QAT")) + def test_conv_relu(self, quantization_type): + SHAPE = (1, 3, 256, 256) + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=3, out_channels=16, kernel_size=3, padding=1 + ) + self.relu = torch.nn.ReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = self.conv(x) + return self.relu(a) + + model = Model() + + example_inputs = (torch.randn(SHAPE),) + self.quantize_and_compare( + model, + example_inputs, + quantization_type, + ) + + @pytest.mark.parametrize("quantization_type", ("PTQ", "QAT")) + def test_linear(self, quantization_type): + SHAPE = (1, 5) + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + model = Model() + + example_inputs = (torch.randn(SHAPE),) + self.quantize_and_compare( + model, + example_inputs, + quantization_type, + ) + + +if __name__ == "__main__": + test_runner = TestCoreMLQuantizer() + test_runner.test_conv_relu("PTQ") + test_runner.test_linear("QAT") diff --git a/backends/arm/arm_quantizer.py b/backends/arm/arm_quantizer.py new file mode 100644 index 0000000000..1062c62428 --- /dev/null +++ b/backends/arm/arm_quantizer.py @@ -0,0 +1,453 @@ +from __future__ import annotations + +import copy +import functools + +from typing import Any, Callable, Dict, List, Optional, Set + +import torch +import torch._dynamo as torchdynamo +import torch.nn.functional as F +from torch.ao.quantization.fake_quantize import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, +) +from torch.ao.quantization.observer import ( + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, + PlaceholderObserver, +) + +from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor + +from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer + +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + _convert_scalars_to_attrs, + OP_TO_ANNOTATOR, + OperatorConfig, + OperatorPatternType, + propagate_annotation, + QuantizationConfig, +) + +from torch.fx import Node + + +__all__ = [ + "XNNPACKQuantizer", + "get_symmetric_quantization_config", +] + + +def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph: + gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs) + gm.graph.eliminate_dead_code() + return gm.graph + + +def _get_linear_patterns(input_size: List[int]): + in_channels = input_size[-1] + out_channels = 8 # hard coding but this should not matter + weight = torch.ones((out_channels, in_channels)) + bias = torch.ones((out_channels,)) + act = torch.ones(input_size) + + def linear_op(act, weight, bias=None): + return F.linear(act, weight, bias) + + pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias)) + pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight)) + return [pattern_w_bias, pattern_wo_bias] + + +def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]: + supported_operators: Dict[str, List[OperatorPatternType]] = { + # Both conv and linear should be able to handle relu + hardtanh fusion since + # those are clamp ops + "conv2d": [ + [torch.nn.Conv2d, torch.nn.ReLU], + [torch.nn.Conv2d, F.relu], + [F.conv2d, torch.nn.ReLU], + [F.conv2d, F.relu], + ], + "linear": [[torch.nn.Linear], [F.linear]], + "add": [[torch.add]], + "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]], + "adaptive_avg_pool2d": [ + [torch.nn.AdaptiveAvgPool2d], + [F.adaptive_avg_pool2d], + ], + } + return copy.deepcopy(supported_operators) + + +def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]: + supported_config_and_operators: List[OperatorConfig] = [] + for quantization_config in [ + get_symmetric_quantization_config(), + get_symmetric_quantization_config(is_qat=True), + get_symmetric_quantization_config(is_per_channel=True), + get_symmetric_quantization_config(is_per_channel=True, is_qat=True), + ]: + ops = _supported_symmetric_quantized_operators() + for pattern_list in ops.values(): + supported_config_and_operators.append( + OperatorConfig(quantization_config, pattern_list) + ) + return copy.deepcopy(supported_config_and_operators) + + +@functools.lru_cache +def get_symmetric_quantization_config( + is_per_channel: bool = False, + is_qat: bool = False, + is_dynamic: bool = False, + act_qmin: int = -128, + act_qmax: int = 127, + weight_qmin: int = -127, + weight_qmax: int = 127, +): + extra_args: Dict[str, Any] = {"eps": 2**-12} + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=act_qmin, + quant_max=act_qmax, + qscheme=torch.per_tensor_affine, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args, + ), + ) + weight_qscheme = ( + torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric + ) + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + MinMaxObserver + ) + if is_qat: + # TODO: qat + per channel? + weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize + elif is_per_channel: + weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + + extra_args: Dict[str, Any] = {"eps": 2**-12} + if is_qat: + if weight_qscheme == torch.per_tensor_symmetric: + extra_args["observer"] = MovingAverageMinMaxObserver + else: + extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=weight_qmin, + quant_max=weight_qmax, + qscheme=weight_qscheme, + ch_axis=0, + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + bias_quantization_spec = None + if is_dynamic: + quantization_config = QuantizationConfig( + act_quantization_spec, + None, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + else: + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + return quantization_config + + +def _get_supported_config_and_operators() -> List[OperatorConfig]: + return _get_supported_symmetric_config_and_operators() + + +def _get_module_name_filter(module_name: str): + """Get the module_name_filter function for a given module name, the filter accepts + a node and checks if the node comes from a module that has certain module name + + For example: + node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 + + + >> module_name_filter = _get_module_name_filter("blocks.sub") + >> print(module_name_filter(node)) + True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" + """ + + def module_name_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + # get_attr nodes doesn't have nn_module_stack? + nn_module_stack = n.meta.get("nn_module_stack", {}) + names = [n[len("L['self'].") :] for n, klass in nn_module_stack.values()] + return module_name in names + + return module_name_filter + + +def _get_module_type_filter(tp: Callable): + """Get the module_type_filter function for a given module type, the filter accepts + a node and checks if the node comes from a module that has certain module type + + For example: + node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear + + + >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule + >> print(module_type_filter(node)) + True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well) + """ + + def module_type_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + nn_module_stack = n.meta.get("nn_module_stack", {}) + types = [t for _, t in nn_module_stack.values()] + return tp in types + + return module_type_filter + + +def _get_not_module_type_or_name_filter( + tp_list: List[Callable], module_name_list: List[str] +) -> Callable[[Node], bool]: + module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] + module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + + def not_module_type_or_name_filter(n: Node) -> bool: + return not any(f(n) for f in module_type_filters + module_name_list_filters) + + return not_module_type_or_name_filter + + +class XNNPACKQuantizer(Quantizer): + supported_config_and_operators = _get_supported_config_and_operators() + STATIC_QAT_ONLY_OPS = [ + "conv_bn_relu", + "conv_bn", + ] + + # static quantization ops (both PTQ and QAT) + # Preserve the order that fusions come before singular ops + STATIC_OPS = [ + "linear_relu", + "linear", + "conv_relu", + "conv", + "adaptive_avg_pool2d", + # TODO: move this to BoltNNQuantizer? + "gru_io_only", + "max_pool2d", + "add_relu", + "add", + "mul_relu", + "mul", + "cat", + ] + + DYNAMIC_OPS = [ + "linear", + ] + + def __init__(self): + super().__init__() + self.global_config: Optional[QuantizationConfig] = None + self.operator_type_config: Dict[ + torch._ops.OpOverloadPacket, Optional[QuantizationConfig] + ] = {} + self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {} + self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {} + + @classmethod + def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: + op_configs: Set[QuantizationConfig] = set({}) + for spec, _ in cls.supported_config_and_operators: + op_configs.add(spec) + return list(op_configs) + + @classmethod + def get_supported_operator_for_quantization_config( + cls, quantization_config: Optional[QuantizationConfig] + ) -> List[OperatorPatternType]: + if quantization_config is None: + all_ops = [] + for _, ops in cls.supported_config_and_operators: + all_ops.extend(ops) + return all_ops + + for config, ops in cls.supported_config_and_operators: + # note: this assumes each entry in cls.supported_spec_and_operators + # corresponds to one spec, e.g. we don't have + # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] + # where the first and second entry have the same spec but did not + # merge the op list + if config == quantization_config: + return ops + return [] + + def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer: + self.global_config = quantization_config + return self + + def set_operator_type( + self, + operator_type: torch._ops.OpOverloadPacket, + quantization_config: QuantizationConfig, + ) -> XNNPACKQuantizer: + self.operator_type_config[operator_type] = quantization_config + return self + + def set_module_type( + self, module_type: Callable, quantization_config: QuantizationConfig + ): + """Set quantization_config for a submodule with type: `module_type`, for example: + quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator + patterns in the submodule with this module type with the given `quantization_config` + """ + self.module_type_config[module_type] = quantization_config + return self + + def set_module_name( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ): + """Set quantization_config for a submodule with name: `module_name`, for example: + quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator + patterns in the submodule with this module name with the given `quantization_config` + """ + assert ( + quantization_config is not None + ), " quantization_config == None is not supported yet" + self.module_name_config[module_name] = quantization_config + return self + + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + """Transforms scalar values to tensor attributes""" + return _convert_scalars_to_attrs(model) + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + # hacked for handling dynamic linear quant. will fix later. + if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr] + model = self._annotate_for_dynamic_quantization_config(model) + else: + model = self._annotate_for_static_quantization_config(model) + propagate_annotation(model) + return model + + def _annotate_all_static_patterns( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> torch.fx.GraphModule: + # TODO: implement the support for None to be canceling out previous annotations + if quantization_config is None: + return model + + if quantization_config.is_qat: + for op in self.STATIC_QAT_ONLY_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + for op in self.STATIC_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + return model + + def _annotate_all_dynamic_patterns( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> torch.fx.GraphModule: + # TODO: implement the support for None to be canceling out previous annotations + if quantization_config is None: + return model + + for op in self.DYNAMIC_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + return model + + def _annotate_for_static_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + module_name_list = list(self.module_name_config.keys()) + for module_name, config in self.module_name_config.items(): + self._annotate_all_static_patterns( + model, config, _get_module_name_filter(module_name) + ) + + tp_list = list(self.module_type_config.keys()) + for module_type, config in self.module_type_config.items(): + self._annotate_all_static_patterns( + model, config, _get_module_type_filter(module_type) + ) + + self._annotate_all_static_patterns( + model, + self.global_config, + _get_not_module_type_or_name_filter(tp_list, module_name_list), + ) + return model + + def _annotate_for_dynamic_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + module_name_list = list(self.module_name_config.keys()) + for module_name, config in self.module_name_config.items(): + self._annotate_all_dynamic_patterns( + model, config, _get_module_name_filter(module_name) + ) + + tp_list = list(self.module_type_config.keys()) + for module_type, config in self.module_type_config.items(): + self._annotate_all_dynamic_patterns( + model, config, _get_module_type_filter(module_type) + ) + + self._annotate_all_dynamic_patterns( + model, + self.global_config, + _get_not_module_type_or_name_filter(tp_list, module_name_list), + ) + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + @classmethod + def get_supported_operators(cls) -> List[OperatorConfig]: + return cls.supported_config_and_operators diff --git a/backends/arm/arm_quantizer_utils.py b/backends/arm/arm_quantizer_utils.py new file mode 100644 index 0000000000..7316e0fbad --- /dev/null +++ b/backends/arm/arm_quantizer_utils.py @@ -0,0 +1,1031 @@ +import itertools +import operator +from dataclasses import dataclass +from typing import Callable, Dict, List, NamedTuple, Optional + +import torch +import torch.nn.functional as F +from torch._subclasses import FakeTensor +from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix +from torch.ao.quantization.pt2e.export_utils import _WrapperModule +from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions +from torch.ao.quantization.pt2e.utils import ( + _conv1d_bn_example_inputs, + _conv2d_bn_example_inputs, + get_aten_graph_module, +) +from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + SharedQuantizationSpec, +) + +from torch.ao.quantization.quantizer.utils import ( + _annotate_input_qspec_map, + _annotate_output_qspec, +) +from torch.fx import Node +from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( + SubgraphMatcherWithNameNodeMap, +) +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + + +__all__ = [ + "OperatorConfig", + "OperatorPatternType", + "QuantizationConfig", + "get_input_act_qspec", + "get_output_act_qspec", + "get_weight_qspec", + "get_bias_qspec", + "OP_TO_ANNOTATOR", + "propagate_annotation", +] + + +# In the absence of better name, just winging it with QuantizationConfig +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: Optional[QuantizationSpec] + output_activation: Optional[QuantizationSpec] + weight: Optional[QuantizationSpec] + bias: Optional[QuantizationSpec] + # TODO: remove, since we can use observer_or_fake_quant_ctr to express this + is_qat: bool = False + + +OperatorPatternType = List[Callable] +OperatorPatternType.__module__ = ( + "torch.ao.quantization.quantizer.xnnpack_quantizer_utils" +) + +AnnotatorType = Callable[ + [ + torch.fx.GraphModule, + Optional[QuantizationConfig], + Optional[Callable[[Node], bool]], + ], + Optional[List[List[Node]]], +] +OP_TO_ANNOTATOR: Dict[str, AnnotatorType] = {} + + +def register_annotator(op: str): + def decorator(annotator: AnnotatorType): + OP_TO_ANNOTATOR[op] = annotator + + return decorator + + +class OperatorConfig(NamedTuple): + # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]] + # Basically we are mapping a quantization config to some list of patterns. + # a pattern is defined as a list of nn module, function or builtin function names + # e.g. [nn.Conv2d, torch.relu, torch.add] + # We have not resolved whether fusion can be considered internal details of the + # quantizer hence it does not need communication to user. + # Note this pattern is not really informative since it does not really + # tell us the graph structure resulting from the list of ops. + config: QuantizationConfig + operators: List[OperatorPatternType] + + +def _is_annotated(nodes: List[Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False + """ + annotated = False + for node in nodes: + annotated = annotated or ( + "quantization_annotation" in node.meta + and node.meta["quantization_annotation"]._annotated + ) + return annotated + + +def _mark_nodes_as_annotated(nodes: List[Node]): + for node in nodes: + if node is not None: + if "quantization_annotation" not in node.meta: + node.meta["quantization_annotation"] = QuantizationAnnotation() + node.meta["quantization_annotation"]._annotated = True + + +def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + if quantization_config.input_activation is None: + return None + quantization_spec: QuantizationSpec = quantization_config.input_activation + assert quantization_spec.qscheme in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ] + return quantization_spec + + +def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + if quantization_config.output_activation is None: + return None + quantization_spec: QuantizationSpec = quantization_config.output_activation + assert quantization_spec.qscheme in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ] + return quantization_spec + + +def get_weight_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + assert quantization_config is not None + if quantization_config.weight is None: + return None + quantization_spec: QuantizationSpec = quantization_config.weight + if quantization_spec.qscheme not in [ + torch.per_tensor_symmetric, + torch.per_channel_symmetric, + ]: + raise ValueError( + f"Unsupported quantization_spec {quantization_spec} for weight" + ) + return quantization_spec + + +def get_bias_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + assert quantization_config is not None + if quantization_config.bias is None: + return None + quantization_spec: QuantizationSpec = quantization_config.bias + assert ( + quantization_spec.dtype == torch.float + ), "Only float dtype for bias is supported for bias right now" + return quantization_spec + + +@register_annotator("linear") +def _annotate_linear( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target != torch.ops.aten.linear.default: + continue + if filter_fn and not filter_fn(node): + continue + act_node = node.args[0] + weight_node = node.args[1] + bias_node = None + if len(node.args) > 2: + bias_node = node.args[2] + + if _is_annotated([node]) is False: # type: ignore[list-item] + _annotate_input_qspec_map( + node, + act_node, + input_act_qspec, + ) + _annotate_input_qspec_map( + node, + weight_node, + weight_qspec, + ) + nodes_to_mark_annotated = [node, weight_node] + if bias_node: + _annotate_input_qspec_map( + node, + bias_node, + bias_qspec, + ) + nodes_to_mark_annotated.append(bias_node) + _annotate_output_qspec(node, output_act_qspec) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + annotated_partitions.append(nodes_to_mark_annotated) + + return annotated_partitions + + +@register_annotator("linear_relu") +def _annotate_linear_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_linear_node = node.args[0] + if ( + not isinstance(maybe_linear_node, Node) + or maybe_linear_node.op != "call_function" + or maybe_linear_node.target != torch.ops.aten.linear.default + ): + continue + + linear_node = maybe_linear_node + input_qspec_map = {} + input_act = linear_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = input_act_qspec + + weight = linear_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = weight_qspec + + # adding weight node to the partition as well + partition = [relu_node, linear_node, weight] + bias = linear_node.args[2] if len(linear_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = bias_qspec + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + linear_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("conv") +def _annotate_conv( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + annotated_partitions = [] + for n in gm.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + ]: + continue + conv_node = n + + input_qspec_map = {} + input_act = conv_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = get_input_act_qspec(quantization_config) + + weight = conv_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = get_weight_qspec(quantization_config) + + # adding weight node to the partition as well + partition = [conv_node, conv_node.args[1]] + + bias = conv_node.args[2] if len(conv_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = get_bias_qspec(quantization_config) + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=get_output_act_qspec(quantization_config), + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("conv_relu") +def _annotate_conv_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + annotated_partitions = [] + for n in gm.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = n + maybe_conv_node = n.args[0] + if ( + not isinstance(maybe_conv_node, Node) + or maybe_conv_node.op != "call_function" + or maybe_conv_node.target + not in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + ] + ): + continue + conv_node = maybe_conv_node + + input_qspec_map = {} + input_act = conv_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = get_input_act_qspec(quantization_config) + + weight = conv_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = get_weight_qspec(quantization_config) + + # adding weight node to the partition as well + partition = [relu_node, conv_node, conv_node.args[1]] + bias = conv_node.args[2] if len(conv_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = get_bias_qspec(quantization_config) + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, _annotated=True + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("conv_bn") +def _annotate_conv_bn( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + """ + Find conv + batchnorm parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False) + + +@register_annotator("conv_bn_relu") +def _annotate_conv_bn_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + """ + Find conv + batchnorm + relu parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True) + + +def _get_pattern(conv_fn: Callable, relu_is_inplace: bool, has_relu: bool): + def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv): + conv = conv_fn(x, conv_weight, conv_bias) + bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True) + if has_relu: + output = F.relu_(bn) if relu_is_inplace else F.relu(bn) + else: + output = bn + return output, { + "input": x, + "conv": conv, + "weight": conv_weight, + "bias": conv_bias, + "output": output, + } + + return _WrapperModule(_conv_bn) + + +def _do_annotate_conv_bn( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]], + has_relu: bool, +) -> List[List[Node]]: + """ + Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern, + return a list of annotated partitions. + + The output of the pattern must include a dictionary from string name to node + for the following names: "input", "conv", "weight", "bias", and "output". + """ + + # Needed for matching, otherwise the matches gets filtered out due to unused + # nodes returned by batch norm + gm.graph.eliminate_dead_code() + gm.recompile() + + matches = [] + combinations = [ + (F.conv1d, _conv1d_bn_example_inputs), + (F.conv2d, _conv2d_bn_example_inputs), + ] + + # Add `is_cuda` and `relu_is_inplace` dimensions + combinations = itertools.product( + combinations, + [True, False] if torch.cuda.is_available() else [False], # is_cuda + [True, False] if has_relu else [False], # relu_is_inplace + ) + + # Match against all conv dimensions and cuda variants + for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations: + pattern = _get_pattern(conv_fn, relu_is_inplace, has_relu) + pattern = get_aten_graph_module(pattern, example_inputs, is_cuda) + pattern.graph.eliminate_dead_code() + pattern.recompile() + matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True) + matches.extend(matcher.match(gm.graph)) + + # Annotate nodes returned in the matches + annotated_partitions = [] + for match in matches: + name_node_map = match.name_node_map + input_node = name_node_map["input"] + conv_node = name_node_map["conv"] + weight_node = name_node_map["weight"] + bias_node = name_node_map["bias"] + output_node = name_node_map["output"] + + # TODO: annotate the uses of input, weight, and bias separately instead + # of assuming they come from a single conv node. This is not possible today + # because input may have multiple users, and we can't rely on the conv node + # always being the first user. This was the case in models with skip + # connections like resnet18 + + # Validate conv args + if conv_node.args[0] is not input_node: + raise ValueError("Conv arg did not contain input node ", input_node) + if conv_node.args[1] is not weight_node: + raise ValueError("Conv arg did not contain weight node ", weight_node) + if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node: + raise ValueError("Conv arg did not contain bias node ", bias_node) + + # Skip if the partition is already annotated or is filtered out by the user + partition = [conv_node, weight_node] + if bias_node is not None: + partition.append(bias_node) + if _is_annotated(partition): + continue + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + # Annotate conv inputs and pattern output + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + if bias_node is not None: + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + output_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("gru_io_only") +def _annotate_gru_io_only( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn) + gru_partitions = list(itertools.chain.from_iterable(gru_partitions.values())) + annotated_partitions = [] + for gru_partition in gru_partitions: + annotated_partitions.append(gru_partition.nodes) + output_nodes = gru_partition.output_nodes + input_nodes = gru_partition.input_nodes + # skip annotation if it is already annotated + if _is_annotated(input_nodes + output_nodes): + continue + # inside each GRU partition, we should be able to annotate each linear + # subgraph + input_act = input_nodes[0] + input_act_user = next(iter(input_act.users.keys())) + assert isinstance(input_act, Node) + assert isinstance(input_act_user, Node) + input_act_user.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: get_input_act_qspec(quantization_config), + }, + _annotated=True, + ) + + hidden_state = input_nodes[1] + hidden_state_user = next(iter(hidden_state.users.keys())) + assert isinstance(hidden_state, Node) + assert isinstance(hidden_state_user, Node) + hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + hidden_state: get_input_act_qspec(quantization_config), + }, + _annotated=True, + ) + + assert len(output_nodes) == 2, "expecting GRU to have two outputs" + for output in output_nodes: + output.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), + _annotated=True, + ) + nodes_to_mark_annotated = list(gru_partition.nodes) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + return annotated_partitions + + +@register_annotator("max_pool2d") +def _annotate_max_pool2d( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + module_partitions = get_source_partitions( + gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d], filter_fn + ) + maxpool_partitions = list(itertools.chain.from_iterable(module_partitions.values())) + annotated_partitions = [] + for maxpool_partition in maxpool_partitions: + annotated_partitions.append(maxpool_partition.nodes) + output_node = maxpool_partition.output_nodes[0] + maxpool_node = None + for n in maxpool_partition.nodes: + if n.target == torch.ops.aten.max_pool2d.default: + maxpool_node = n + assert ( + maxpool_node is not None + ), "XNNPACKQuantizer only works with torch.ops.aten.max_pool2d.default, " + "please make sure you are exporting the model correctly" + if _is_annotated([output_node, maxpool_node]): # type: ignore[list-item] + continue + + input_act = maxpool_node.args[0] # type: ignore[union-attr] + assert isinstance(input_act, Node) + + # only annotate maxpool when the output of the input node is annotated + if ( + "quantization_annotation" not in input_act.meta + or not input_act.meta["quantization_annotation"]._annotated + or input_act.meta["quantization_annotation"].output_qspec is None + ): + continue + # input and output of maxpool will share quantization parameter with input of maxpool + act_qspec = SharedQuantizationSpec(input_act) + # act_qspec = get_act_qspec(quantization_config) + maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation( # type: ignore[union-attr] + input_qspec_map={ + input_act: act_qspec, + }, + _annotated=True, + ) + output_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=act_qspec, + _annotated=True, + ) + return annotated_partitions + + +@register_annotator("adaptive_avg_pool2d") +def _annotate_adaptive_avg_pool2d( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + """Always annotate adaptive_avg_pool2d op""" + module_partitions = get_source_partitions( + gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn + ) + partitions = list(itertools.chain.from_iterable(module_partitions.values())) + annotated_partitions = [] + for partition in partitions: + pool_node = partition.output_nodes[0] + if ( + pool_node.op != "call_function" + or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default + ): + raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator") + + if _is_annotated([pool_node]): + continue + + annotated_partitions.append(partition.nodes) + input_act = pool_node.args[0] + assert isinstance(input_act, Node) + + # only annotate input output sharing operator + # when the output of the input node is annotated + if ( + "quantization_annotation" not in input_act.meta + or not input_act.meta["quantization_annotation"]._annotated + or input_act.meta["quantization_annotation"].output_qspec is None + ): + input_act_qspec = get_input_act_qspec(quantization_config) + else: + input_act_qspec = SharedQuantizationSpec(input_act) + + # output sharing with input + output_act_qspec = SharedQuantizationSpec((input_act, pool_node)) + pool_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: input_act_qspec, + }, + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +def _is_input_large_scalar(node: Node, gm: torch.fx.GraphModule): + """Check if input is a large scalar value. So that we can skip quantization for the node + since histc op (in HistogramObserver) only works for values up to certain upper bound + """ + if node.op == "get_attr": + tensor = getattr(gm, node.target) # type: ignore[arg-type] + # torch.histc works until this upper bound + HISTC_UPPER_BOUND = 3.4028235e15 + return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND + return False + + +def _is_input_non_float_tensor(node: Node): + """Check if the input is not a float tensor, so that we can skip quantization for the node + since observers only works with float Tensors + """ + if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): + return True + return node.meta["val"].dtype != torch.float32 + + +@register_annotator("add_relu") +def _annotate_add_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + fused_partitions = find_sequential_partitions( + gm, [torch.add, torch.nn.ReLU], filter_fn=filter_fn + ) + annotated_partitions = [] + for fused_partition in fused_partitions: + add_partition, relu_partition = fused_partition + annotated_partitions.append(add_partition.nodes + relu_partition.nodes) + if len(relu_partition.output_nodes) > 1: + raise ValueError("Relu partition has more than one output node") + relu_node = relu_partition.output_nodes[0] + if len(add_partition.output_nodes) > 1: + raise ValueError("add partition has more than one output node") + add_node = add_partition.output_nodes[0] + + if _is_annotated([relu_node, add_node]): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = add_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = add_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + input_qspec_map[input_act1] = input_act_qspec + + add_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +@register_annotator("add") +def _annotate_add( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + add_partitions = get_source_partitions( + gm.graph, [operator.add, torch.add, operator.iadd], filter_fn + ) + add_partitions = list(itertools.chain.from_iterable(add_partitions.values())) + annotated_partitions = [] + for add_partition in add_partitions: + annotated_partitions.append(add_partition.nodes) + add_node = add_partition.output_nodes[0] + if _is_annotated([add_node]): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = add_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = add_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + input_qspec_map[input_act1] = input_act_qspec + + add_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +@register_annotator("mul_relu") +def _annotate_mul_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + fused_partitions = find_sequential_partitions( + gm, [torch.mul, torch.nn.ReLU], filter_fn=filter_fn + ) + annotated_partitions = [] + for fused_partition in fused_partitions: + mul_partition, relu_partition = fused_partition + annotated_partitions.append(mul_partition.nodes + relu_partition.nodes) + if len(relu_partition.output_nodes) > 1: + raise ValueError("Relu partition has more than one output node") + relu_node = relu_partition.output_nodes[0] + if len(mul_partition.output_nodes) > 1: + raise ValueError("mul partition has more than one output node") + mul_node = mul_partition.output_nodes[0] + + if _is_annotated([relu_node, mul_node]): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = mul_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = mul_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + input_qspec_map[input_act1] = input_act_qspec + + mul_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +@register_annotator("mul") +def _annotate_mul( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + mul_partitions = get_source_partitions( + gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn + ) + mul_partitions = list(itertools.chain.from_iterable(mul_partitions.values())) + annotated_partitions = [] + for mul_partition in mul_partitions: + annotated_partitions.append(mul_partition.nodes) + mul_node = mul_partition.output_nodes[0] + if _is_annotated([mul_node]): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = mul_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = mul_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + input_qspec_map[input_act1] = input_act_qspec + + mul_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +# TODO: remove Optional in return type, fix annotated_partitions logic +@register_annotator("cat") +def _annotate_cat( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn) + cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values())) + annotated_partitions = [] + for cat_partition in cat_partitions: + cat_node = cat_partition.output_nodes[0] + if _is_annotated([cat_node]): + continue + + if cat_node.target != torch.ops.aten.cat.default: + # TODO: change this to AnnotationException + raise Exception( + f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}" + " please check if you are calling the correct capture API" + ) + + annotated_partitions.append(cat_partition.nodes) + + input_act_qspec = get_input_act_qspec(quantization_config) + inputs = cat_node.args[0] + + input_qspec_map = {} + input_act0 = inputs[0] + if isinstance(input_act0, Node): + input_qspec_map[input_act0] = input_act_qspec + + shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) + for input_act in inputs[1:]: + input_qspec_map[input_act] = shared_with_input0_qspec + + output_act_qspec = shared_with_input0_qspec + + cat_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +def _is_share_obs_or_fq_op(op: Callable) -> bool: + return op in [ + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.mean.default, + torch.ops.aten.mean.dim, + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dim, + # TODO: remove? + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.view_copy.default, + torch.ops.aten.view.default, + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.flatten.using_ints, + ] + + +def propagate_annotation(model: torch.fx.GraphModule) -> None: + for n in model.graph.nodes: + if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target): + continue + + prev_node = n.args[0] + if not isinstance(prev_node, Node): + continue + + quantization_annotation = prev_node.meta.get("quantization_annotation", None) + if not quantization_annotation: + continue + + output_qspec = quantization_annotation.output_qspec + if not output_qspec: + continue + + # make sure current node is not annotated + if ( + "quantization_annotation" in n.meta + and n.meta["quantization_annotation"]._annotated + ): + continue + + shared_qspec = SharedQuantizationSpec(prev_node) + # propagate the previous output_qspec to the current node + n.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + prev_node: shared_qspec, + }, + output_qspec=shared_qspec, + _annotated=True, + ) + + +# TODO: make the list of ops customizable +def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for n in model.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.mul.Tensor, + ]: + continue + args = list(n.args) + new_args = [] + for i in range(len(args)): + if isinstance(args[i], torch.fx.Node): + new_args.append(args[i]) + continue + prefix = "_tensor_constant_" + get_new_attr_name = get_new_attr_name_with_prefix(prefix) + tensor_constant_name = get_new_attr_name(model) + float_tensor = torch.tensor(float(args[i])) + model.register_buffer(tensor_constant_name, float_tensor) + fake_mode = n.meta["val"].fake_mode + with model.graph.inserting_before(n): + get_attr_node = model.graph.create_node( + "get_attr", tensor_constant_name, (), {} + ) + get_attr_node.meta["val"] = fake_mode.from_tensor( + float_tensor, static_shapes=True + ) + new_args.append(get_attr_node) + n.args = tuple(new_args) + model.recompile() + return model diff --git a/backends/arm/operators/op_addmm.py b/backends/arm/operators/op_addmm.py index 959b2034f2..cc49e5c382 100644 --- a/backends/arm/operators/op_addmm.py +++ b/backends/arm/operators/op_addmm.py @@ -65,7 +65,15 @@ def define_node( stride_attr = [1, 1] dilation_attr = [1, 1] - input_zp = -128 if is_quant_node else 0 + input_zp = 0 + if is_quant_node: + input_node = node.all_input_nodes[1] + # rank > 2 linear layer + if input_node.target == exir_ops.edge.aten.view_copy.default: + quant_node = input_node.all_input_nodes[0] + else: + quant_node = input_node + input_zp = get_quant_node_args(quant_node)[1] attr.ConvAttribute( pad=pad_attr, stride=stride_attr, diff --git a/backends/arm/operators/op_placeholder.py b/backends/arm/operators/op_placeholder.py index f935922bc2..597839021f 100644 --- a/backends/arm/operators/op_placeholder.py +++ b/backends/arm/operators/op_placeholder.py @@ -38,9 +38,14 @@ def process_placeholder( if consumer_node.target in dq_q_ops: _, weight_node_scale, weight_node_zp, _, _, _ = getNodeArgs(consumer_node) + int8_max = np.iinfo(np.int8).max + int8_min = np.iinfo(np.int8).min parameter_values_quantized = ( - (parameter_values / weight_node_scale.number) + weight_node_zp.number - ).astype(np.int8) + ((parameter_values / weight_node_scale.number) + weight_node_zp.number) + .round() + .clip(int8_min, int8_max) + .astype(np.int8) + ) tosa_graph.addConst( inputs[0].shape, ts.DType.INT8, @@ -63,8 +68,10 @@ def process_placeholder( weight_node_scale, weight_node_zp = get_quant_node_args(weight_node) bias_values_quantized = ( - parameter_values / (input_node_scale * weight_node_scale) - ).astype(np.int32) + (parameter_values / (input_node_scale * weight_node_scale)) + .round() + .astype(np.int32) + ) tosa_graph.addConst( inputs[0].shape, @@ -86,8 +93,8 @@ def process_placeholder( weight_node_scale, _ = get_quant_node_args(weight_node) bias_scales = input_node_scale * weight_node_scale - parameter_values_quantized = (parameter_values / bias_scales).astype( - np.int32 + parameter_values_quantized = ( + (parameter_values / bias_scales).round().astype(np.int32) ) tosa_graph.addConst( diff --git a/backends/arm/test/arm_tosa_reference.py b/backends/arm/test/arm_tosa_reference.py index 4a6bbf44f9..0abec37cc5 100644 --- a/backends/arm/test/arm_tosa_reference.py +++ b/backends/arm/test/arm_tosa_reference.py @@ -41,8 +41,6 @@ SUPPORTED_BI_TEST_LIST = [ "simple_add", "simple_add_broadcast", - "simple_linear", - "simple_linear_rank4", "simple_conv2d_3x3_1x3x256x256_stride1", "simple_conv2d_1x1_1x2x128x128_stride1", "simple_conv2d_2x2_1x1x14x14_stride2", @@ -250,9 +248,8 @@ def tosa_run_test(op, profile=TosaProfile.MI): # noqa: C901 # Need to dequant back to FP32 for running comparison with Torch output if profile is TosaProfile.BI: tosa_output = ( - np.round(tosa_output - output_quantization_zp) - * output_quantization_scale - ) + tosa_output - output_quantization_zp + ) * output_quantization_scale ## Read the Torch Output torch_file = open(TORCH_OUT_PATH + "/torch_output.npy", "rb") diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index 961507a863..bd33ee405a 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -40,8 +40,8 @@ def test_mv2_tosa_MI(self): backend=ArmBackendSelector.TOSA, ) .export() - .check(list(self.all_operators)) .to_edge() + .check(list(self.all_operators)) .partition() .to_executorch() .run_method() @@ -59,8 +59,8 @@ def test_mv2_tosa_BI(self): ) .quantize() .export() - .check(list(self.all_operators)) .to_edge() + .check(list(self.all_operators)) .partition() .to_executorch() .run_method() @@ -78,8 +78,8 @@ def test_mv2_u55_BI(self): ) .quantize() .export() - .check(list(self.all_operators)) .to_edge() + .check(list(self.all_operators)) .partition() .to_executorch() ) diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index a7fa07dd05..0791bab5eb 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -88,7 +88,7 @@ def _test_add_tosa_BI_pipeline( .to_executorch() ) if TOSA_REF_MODEL_INSTALLED: - tester.run_method().compare_outputs() + tester.run_method().compare_outputs(qtol=1) else: logger.warning( "TOSA ref model tool not installed, skip numerical correctness tests" @@ -118,8 +118,6 @@ def test_add_tosa_MI(self): test_data = (torch.randn(4, 4, 4),) self._test_add_tosa_MI_pipeline(self.Add(), test_data) - # TODO: Will this type of parametrization be supported? pytest seem - # have issue with it. @parameterized.expand( [ (torch.ones(5),), # test_data diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py new file mode 100644 index 0000000000..8cbd41e012 --- /dev/null +++ b/backends/arm/test/ops/test_linear.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import shutil +import unittest + +from typing import Tuple + +import torch +from executorch.backends.arm.test.test_models import TosaProfile +from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester +from parameterized import parameterized + +# TODO: fixme! These globs are a temporary workaround. Reasoning: +# Running the jobs in _unittest.yml will not work since that environment don't +# have the vela tool, nor the tosa_reference_model tool. Hence, we need a way to +# run what we can in that env temporarily. Long term, vela and tosa_reference_model +# should be installed in the CI env. +TOSA_REF_MODEL_INSTALLED = shutil.which("tosa_reference_model") +VELA_INSTALLED = shutil.which("vela") + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +seed = 42 +torch.manual_seed(seed) +test_data_suite = [ + # (test_name, test_data, out_features) + ( + "model_linear_rank1_zeros", + torch.zeros(10, 10), + 10, + ), + ( + "model_linear_rank1_ones", + torch.ones(10, 10), + 10, + ), + ( + "model_linear_rank1_negative_ones", + torch.ones(10, 10) * (-1), + 10, + ), + ( + "model_linear_rank1_rand", + torch.rand(10, 10), + 10, + ), + ( + "model_linear_rank1_negative_large_rand", + torch.rand(10, 10) * (-100), + 10, + ), + ( + "model_linear_rank1_large_randn", + torch.randn(10, 10) * 100, + 10, + ), + ( + "model_linear_rank4_zeros", + torch.zeros(5, 10, 25, 20), + 30, + ), + ( + "model_linear_rank4_ones", + torch.ones(5, 10, 25, 20), + 30, + ), + ( + "model_linear_rank4_negative_ones", + torch.ones(5, 10, 25, 20) * (-1), + 30, + ), + ( + "model_linear_rank4_rand", + torch.rand(5, 10, 25, 20), + 30, + ), + ( + "model_linear_rank4_negative_large_rand", + torch.rand(5, 10, 25, 20) * (-100), + 30, + ), + ( + "model_linear_rank4_large_randn", + torch.randn(5, 10, 25, 20) * 100, + 30, + ), +] + + +class TestLinear(unittest.TestCase): + class Linear(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int = 3, + bias: bool = True, + ): + super().__init__() + self.fc = torch.nn.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + ) + + def forward(self, x): + return self.fc(x) + + def _test_linear_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + tester = ( + ArmTester( + module, + inputs=test_data, + profile=TosaProfile.MI, + backend=ArmBackendSelector.TOSA, + ) + .export() + .check_count({"torch.ops.aten.addmm.default": 1}) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + ) + if TOSA_REF_MODEL_INSTALLED: + tester.run_method().compare_outputs() + else: + logger.warning( + "TOSA ref model tool not installed, skip numerical correctness tests" + ) + + def _test_linear_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + tester = ( + ArmTester( + module, + inputs=test_data, + profile=TosaProfile.BI, + backend=ArmBackendSelector.TOSA, + ) + .quantize() + .export() + .check_count({"torch.ops.aten.addmm.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + ) + if TOSA_REF_MODEL_INSTALLED: + tester.run_method().compare_outputs(qtol=True) + else: + logger.warning( + "TOSA ref model tool not installed, skip numerical correctness tests" + ) + + def _test_linear_tosa_u55_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + inputs=test_data, + profile=TosaProfile.BI, + backend=ArmBackendSelector.ETHOS_U55, + ) + .quantize() + .export() + .check_count({"torch.ops.aten.addmm.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + ) + + @parameterized.expand(test_data_suite) + def test_linear_tosa_MI( + self, + test_name: str, + test_data: torch.Tensor, + out_features: int, + ): + in_features = test_data.shape[-1] + test_data = (test_data,) + self._test_linear_tosa_MI_pipeline( + self.Linear( + in_features=in_features, + out_features=out_features, + ), + test_data, + ) + + @parameterized.expand(test_data_suite) + def test_linear_tosa_BI( + self, + test_name: str, + test_data: torch.Tensor, + out_features: int, + ): + in_features = test_data.shape[-1] + test_data = (test_data,) + self._test_linear_tosa_BI_pipeline( + self.Linear(in_features=in_features, out_features=out_features), test_data + ) + + @parameterized.expand(test_data_suite) + @unittest.skip("This does not work as of now") + def test_linear_tosa_u55_BI( + self, + test_name: str, + test_data: torch.Tensor, + out_features: int, + ): + in_features = test_data.shape[-1] + test_data = (test_data,) + self._test_linear_tosa_u55_BI_pipeline( + self.Linear( + in_features=in_features, + out_features=out_features, + ), + test_data, + ) diff --git a/backends/arm/test/test_models.py b/backends/arm/test/test_models.py index fbe8806d9b..e7842358da 100644 --- a/backends/arm/test/test_models.py +++ b/backends/arm/test/test_models.py @@ -125,38 +125,6 @@ def __init__(self): def forward(self, x, y): return x + y - @register_test - class simple_linear(torch.nn.Module): - inputs = { - TosaProfile.BI: (torch.rand(1, 2),), - TosaProfile.MI: (torch.rand(1, 2),), - } - - def __init__(self): - super().__init__() - torch.manual_seed(seed) - self.fc = torch.nn.Linear(2, 3) - - def forward(self, x): - x = self.fc(x) - return x - - @register_test - class simple_linear_rank4(torch.nn.Module): - inputs = { - TosaProfile.BI: (torch.rand(5, 10, 25, 20),), - TosaProfile.MI: (torch.rand(5, 10, 25, 20),), - } - - def __init__(self): - super().__init__() - torch.manual_seed(42) - self.fc = torch.nn.Linear(20, 30) - - def forward(self, x): - x = self.fc(x) - return x - """Currenly we compare the quantized result directly with the floating point result, to avoid a noticable precision difference due to wide random numerical distribution, generate small random value range for convolution testing instead for now""" diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 35fd85d66b..6dea3b04b6 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from enum import Enum -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch from executorch.backends.arm.arm_backend import ( @@ -15,6 +15,7 @@ from executorch.backends.arm.arm_partitioner import ArmPartitioner from executorch.backends.arm.test.tosautil.tosa_test_utils import ( + QuantizationParams, TosaProfile, TosaTestUtils, ) @@ -32,6 +33,7 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) +from torch.export import ExportedProgram class ArmBackendSelector(Enum): @@ -61,6 +63,7 @@ def __init__( TosaProfile.BI or TosaProfile.MI """ self.tosa_test_util = None + self.is_quantized = profile == TosaProfile.BI if backend == ArmBackendSelector.TOSA: self.tosa_test_util = TosaTestUtils(profile=profile) # The spec below tiggers arm_backend.py to output two files: @@ -119,54 +122,121 @@ def run_method( ), "self.tosa_test_util is not initialized, cannot use run_method()" inputs_to_run = inputs or self.inputs - # TODO: we can't possible need to use all these stages?? - export_stage = self.stages[ - self.stage_name(Export) - ] # this is what XNNpack use to get quant params - toedge_stage = self.stages[ - self.stage_name(ToEdge) - ] # this is what get_input_quantization_params use to get quant params - partition_stage = self.stages[ - self.stage_name(Partition) - ] # this is what tosa_ref_dump_inputs use.... - - # TODO: I'd prefer to use this TOSA buffer instead of output.tosa, - # generated by arm_backend.py. The issue is that we're still depending - # on desc.json, which is created from TosaSerializer class, not from - # the serialized TOSA buffer. Leave this here for review purposes. - # ts_serialized = self._get_serialized_tosa_buffer( # unused - # partition_stage.artifact - # ) - - # This is where the torch reference output is calculated and set - # TODO: This sets self.quantization_scale, which is duplicates - # self.tosa_test_util.quantization.output.scales (?). Fixme. - ( - self.reference_output, - self.quantization_scale, - ) = self._calculate_reference_output(export_stage.artifact, inputs_to_run) - - # Convert the torch inputs to something TOSA ref model can use - tensor_names_and_inputs_np = self.tosa_test_util.convert_inputs_to_tosa( - partition_stage.artifact, toedge_stage.artifact, inputs_to_run + export_stage = self.stages[self.stage_name(Export)] + + (input_names, qp_input) = self._get_input_params(export_stage.artifact) + (output_name, qp_output) = self._get_output_param(export_stage.artifact) + + # Calculate the reference output using the original module or the quant + # module. self.quantization_scale is used by compare_outputs() to + # calculate the tolerance + self.quantization_scale = None if qp_output is None else qp_output.scale + if self.is_quantized: + module_for_ref = self.stages[self.stage_name(Quantize)].artifact + else: + module_for_ref = self.original_module + self.reference_output = self._calculate_reference_output( + module_for_ref, inputs_to_run ) # Run the TOSA ref model to get the output tensor, which will be # compared to the torch output in compare_outputs() self.stage_output = self.tosa_test_util.run_tosa_ref_model( - tensor_names_and_inputs_np + params_input=(input_names, qp_input), + param_output=(output_name, qp_output), + inputs=inputs_to_run, ) return self - def _get_serialized_tosa_buffer(self, partition_stage: Partition) -> bytes: + def _get_input_params( + self, program: ExportedProgram + ) -> Tuple[str, Union[List[QuantizationParams], List[None]]]: """ - This is just a prototype... - Todo: - * The "_0" indicates that there are many lowered modules. Loop it! - * There's probably a better way to get this buffer. An API? Yes, - it seems the serialize stage does this for you... + Get name and optionally quantization parameters for the inputs to this + model. + + Args: + program (ExportedProgram): The program to get input parameters from + Returns: + Tuple[str, Optional[QuantizationParams]]: A tuple containing the + input node names and their quantization parameters. + """ + input_names = [] + # E.g. bias and weights are 'placeholders' as well. This is used to + # get only the use inputs. + usr_inputs = program.graph_signature.user_inputs + for node in program.graph.nodes: + if node.op == "placeholder" and node.name in usr_inputs: + input_names.append(node.name) + continue + + if self.is_quantized: + quant_params = [] + for node in program.graph.nodes: + if ( + node.target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + and node.args[0].name in input_names + ): + qp = QuantizationParams( + node_name=node.args[0].name, scale=node.args[1], zp=node.args[2] + ) + quant_params.append(qp) + if len(quant_params) == len( + input_names + ): # break early if we have all the inputs quantized parameters + break + assert len(quant_params) != 0, "Quantization paramerters not found" + return (input_names, quant_params) + else: + return (input_names, len(input_names) * [None]) # return a list of None's + + def _get_output_param( + self, program: ExportedProgram + ) -> Tuple[str, Union[QuantizationParams, None]]: """ - return partition_stage._edge_programs[ - "forward" - ]._graph_module.lowered_module_0.processed_bytes + Get name and optionally quantization parameters for the inputs to this + model. + + Args: + program (ExportedProgram): The program to get output parameters from. + Returns: + Tuple[str, Optional[QuantizationParams]]: A tuple containing the + output node name and its quantization parameters. + """ + output_node = None + for node in program.graph.nodes: + if node.op == "output": + output_node = node + break + + if self.is_quantized: + quant_params = None + for node in program.graph.nodes: + if ( + node.target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default + and node == output_node.args[0][0] + ): + quant_params = QuantizationParams( + node_name=node.args[0].name, scale=node.args[1], zp=node.args[2] + ) + break # break early, there's only one output node + assert quant_params is not None, "Quantization paramerters not found" + return (output_node.name, quant_params) + else: + return (output_node.name, None) + + @staticmethod + def _calculate_reference_output( + module: Union[torch.fx.GraphModule, torch.nn.Module], inputs + ) -> torch.Tensor: + """ + Note: I'd prefer to use the base class method here, but since it use the + exported program, I can't. The partitioner stage clears the state_dict + of the exported program, which causes an issue when evaluating the + module. + """ + + return module.forward(*inputs) diff --git a/backends/arm/test/tosautil/tosa_test_utils.py b/backends/arm/test/tosautil/tosa_test_utils.py index 48bbb6311f..ff39761559 100644 --- a/backends/arm/test/tosautil/tosa_test_utils.py +++ b/backends/arm/test/tosautil/tosa_test_utils.py @@ -10,34 +10,27 @@ import subprocess import tempfile -from collections import namedtuple -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple import numpy as np import torch -from executorch.backends.arm.test.arm_tosa_reference import ( - get_input_quantization_params, # TODO: remove this dependecy - get_output_quantization_param, # TODO: remove this dependecy - tosa_ref_dump_inputs, # TODO: remove this dependecy -) from executorch.backends.arm.test.test_models import TosaProfile -from executorch.backends.xnnpack.test.tester.tester import Partition, ToEdge logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) class QuantizationParams: - __slots__ = ["zps", "scales"] + __slots__ = ["node_name", "zp", "scale"] - def __init__(self, zps: Union[Dict, List[int]], scales: Union[Dict, List[float]]): - self.zps = zps - self.scales = scales + # todo: zps and scales can be per tensors or per channel => a list?? + def __init__(self, node_name: str, zp: int, scale: float): + self.node_name = node_name # not need I think, but good for error check + self.zp = zp + self.scale = scale -Quantization = namedtuple("Quantization", ["input", "output"]) - """ This class is used to work with TOSA artifacts. """ @@ -55,9 +48,6 @@ def __init__( ) self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model" self.profile = profile or TosaProfile.MI - input_quant = QuantizationParams(zps={}, scales={}) - output_quant = QuantizationParams(zps={}, scales={}) - self.quantization = Quantization(input=input_quant, output=output_quant) assert os.path.exists( self.intermediate_path ), f"TOSA artifact path don't exist! Path: {self.intermediate_path}" @@ -106,73 +96,11 @@ def dbg_dump_readble_tosa_file(self) -> None: self._run_cmd(cmd_flatc) return - def convert_inputs_to_tosa( - self, - partition_stage: Partition, - toedge_stage: ToEdge, - inputs_to_run: Tuple[torch.Tensor], - ) -> List[Tuple[np.ndarray, str]]: - """ - Convert input tensors to numpy and save them to disk as .npy files. The - TOSA reference model will use these files as input.... - - Args: - partition_stage (Partition): The partition stage. - toedge_stage (ToEdge): The toedge stage. - inputs_to_run (Tuple[torch.Tensor]): The input tensors to convert. - - Returns: - List[Tuple[np.ndarray, str]]: A list of tuples, where each tuple contain - a numpy array and the name of the tensor. - - Todo: - * I'd like to use some other common function instead of - get_input_quantization_params() and - get_output_quantization_param(). - * I'd like to get rid of the call to tosa_ref_dump_inputs as well. - All this function is doing is to convert to numpy and save to disk - """ - - if self.profile == TosaProfile.BI: - # TODO: Unclear to me why we need to pass toedge_stage here. Ideally - # we shouldn't get the quantization params here at all, but rather - # from the Quantizer... - ( - self.quantization.input.scales, - self.quantization.input.zps, - ) = get_input_quantization_params(toedge_stage) - ( - self.quantization.output.scales, - self.quantization.output.zps, - ) = get_output_quantization_param(toedge_stage) - - # TODO: I think it should be possible to get this input data from - # somewhere else. Why do I need to call this just to get a npy file, - # which is just a quantized version of the input..? - np_data_and_tensor_names = tosa_ref_dump_inputs( - partition_stage, - inputs_to_run, - self.intermediate_path, - self.quantization.input.scales, - self.quantization.input.zps, - self.profile, - save_on_disk=False, # If True - this one produces arg0_1.npy, which is just a quant version of the input - # inputs_to_run -> convert to numpy -> do "manual" quantization -> save to arg0_1.npy (TODO: remove this comment) - ) - else: - np_data_and_tensor_names = tosa_ref_dump_inputs( - partition_stage, - inputs_to_run, - self.intermediate_path, - {}, - {}, - save_on_disk=False, - ) - - return np_data_and_tensor_names - def run_tosa_ref_model( - self, tensor_names_and_inputs: List[Tuple[np.array, str]] + self, + params_input: Tuple[List[str], List[QuantizationParams]], + param_output: Tuple[str, QuantizationParams], + inputs: Tuple[torch.Tensor], ) -> torch.Tensor: """ Run TOSA reference model using the tosa_refence_model program. @@ -184,18 +112,20 @@ def run_tosa_ref_model( These two files are created by arm_backend.py as part of partition stage - 3. An IFM file containing input data, saved as .npy. This file is - created by tosa_ref_dump_inputs() - All these files are saved on disk in self.intermediate_path. Args: - tensor_names_and_inputs (List[Tuple[np.array, str]]): A list of tuples - where each tuple contains inputs (as numpy array) and the name of - the tensor. + params_input (Tuple[List[str], List[QuantizationParams]]): A tuple + containing a list of input node names and a list of their + quantization parameters (if model is quantized). + param_output (Tuple[str, QuantizationParams]): A tuple containing + the output node name and its quantization parameters (if + model is quantized). + inputs (Tuple[torch.Tensor]): The input data to run the TOSA Returns: - torch.Tensor: The output of the TOSA reference model, as a torch tensor. + torch.Tensor: The output of the TOSA reference model, as a torch + tensor. Here's a sample desc.json file: { @@ -228,11 +158,26 @@ def run_tosa_ref_model( ), f"desc_file_path: {desc_file_path} does not exist" # Save the input data to disk as a .npy file, since that's what the TOSA - # reference model expects. Name of the file is must match the name in + # reference model expects. Name of the file must match the name in # desc.json, which is the tensor name from the graph + .npy - for tensor_name, data in tensor_names_and_inputs: - file_path = os.path.join(self.intermediate_path, tensor_name + ".npy") - np.save(file_path, data, allow_pickle=False) + for input_name, quant_param, data in zip( + params_input[0], params_input[1], inputs + ): + data_np = data.detach().numpy() + if self.profile is TosaProfile.BI: + assert ( + quant_param.node_name == input_name + ), "These quantization params do not match the input tensor name" + int8_max = np.iinfo(np.int8).max + int8_min = np.iinfo(np.int8).min + data_np = ( + ((data_np / np.float32(quant_param.scale)) + quant_param.zp) + .round() + .clip(int8_min, int8_max) + .astype(np.int8) + ) + file_path = os.path.join(self.intermediate_path, input_name + ".npy") + np.save(file_path, data_np, allow_pickle=False) # Run the TOSA reference model via command line, this will produce a # .npy file with the result (aka OFM). @@ -252,12 +197,11 @@ def run_tosa_ref_model( if self.profile is TosaProfile.BI: # Need to dequant back to FP32 for comparison with torch output - assert self.quantization.output.scales is not None - assert self.quantization.output.zps is not None - tosa_ref_output = ( - np.round(tosa_ref_output - self.quantization.output.zps) - * self.quantization.output.scales - ) + quant_param = param_output[1] + assert ( + quant_param is not None + ), "There are no qunatization parameters, check output parameters" + tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale # tosa_output is a numpy array, convert to torch tensor for comparison tosa_ref_output = torch.from_numpy(tosa_ref_output.astype("float32")) diff --git a/backends/vulkan/CMakeLists.txt b/backends/vulkan/CMakeLists.txt index 834f2704fc..24855545fe 100644 --- a/backends/vulkan/CMakeLists.txt +++ b/backends/vulkan/CMakeLists.txt @@ -8,7 +8,7 @@ # # This file should be formatted with # ~~~ -# cmake-format --first-comment-is-literal=True CMakeLists.txt +# cmake-format --first-comment-is-literal=True -i CMakeLists.txt # ~~~ # It should also be cmake-lint clean. # @@ -32,7 +32,10 @@ if(NOT FLATC_EXECUTABLE) set(FLATC_EXECUTABLE flatc) endif() -# Include this file to access target_link_options_shared_lib +# Include this file to access target_link_options_shared_lib This is required to +# provide access to target_link_options_shared_lib which allows libraries to be +# linked with the --whole-archive flag. This is required for libraries that +# perform dynamic registration via static initialization. include(${EXECUTORCH_ROOT}/build/Utils.cmake) # ATen Vulkan Libs @@ -48,12 +51,19 @@ set(COMMON_INCLUDES ${VULKAN_API_HEADERS} ${EXECUTORCH_ROOT}/..) file(GLOB_RECURSE vulkan_graph_cpp ${RUNTIME_PATH}/graph/*) add_library(vulkan_graph_lib STATIC ${vulkan_graph_cpp}) - target_include_directories(vulkan_graph_lib PRIVATE ${COMMON_INCLUDES}) +target_link_libraries(${LIBRARY_NAME} vulkan_api_lib) +target_compile_options(vulkan_graph_lib PRIVATE ${VULKAN_CXX_FLAGS}) +# Link this library with --whole-archive due to dynamic operator registrations +target_link_options_shared_lib(vulkan_graph_lib) -target_link_libraries(vulkan_graph_lib vulkan_shader_lib) +# Due to dynamic registrations, these libraries must be explicitly linked +set(VULKAN_STANDARD_OPS_LIBS vulkan_graph_lib vulkan_graph_shaderlib) -target_compile_options(vulkan_graph_lib PRIVATE ${VULKAN_CXX_FLAGS}) +# vulkan_graph_shaderlib + +set(VULKAN_GRAPH_SHADERS_PATH ${RUNTIME_PATH}/graph/ops/glsl/) +vulkan_shader_library(${VULKAN_GRAPH_SHADERS_PATH} vulkan_graph_shaderlib) # Generate Files from flatc @@ -85,19 +95,12 @@ target_include_directories( file(GLOB vulkan_backend_cpp ${RUNTIME_PATH}/*.cpp) add_library(vulkan_backend ${vulkan_backend_cpp}) - -target_include_directories(vulkan_backend PRIVATE ${SCHEMA_INCLUDE_DIR}) -target_include_directories(vulkan_backend PRIVATE ${COMMON_INCLUDES}) - -target_link_libraries(vulkan_backend PRIVATE vulkan_graph_lib) -target_link_libraries(vulkan_backend PRIVATE vulkan_schema) -target_link_libraries(vulkan_backend PRIVATE executorch) - +target_include_directories(vulkan_backend PRIVATE ${SCHEMA_INCLUDE_DIR} + ${COMMON_INCLUDES}) +target_link_libraries(vulkan_backend PRIVATE vulkan_graph_lib vulkan_schema + executorch) target_compile_options(vulkan_backend PRIVATE ${VULKAN_CXX_FLAGS}) - -# This is required to ensure that vulkan_backend gets linked with -# --whole-archive since backends are registered via static variables that would -# otherwise be discarded +# Link this library with --whole-archive due to dynamic backend registration target_link_options_shared_lib(vulkan_backend) # Executor Runner @@ -105,28 +108,38 @@ target_link_options_shared_lib(vulkan_backend) if(NOT CMAKE_TOOLCHAIN_FILE MATCHES ".*iOS\.cmake$") set(VULKAN_RUNNER_SRCS ${_executor_runner__srcs}) list(TRANSFORM VULKAN_RUNNER_SRCS PREPEND "${EXECUTORCH_ROOT}/") + add_executable(vulkan_executor_runner ${VULKAN_RUNNER_SRCS}) - target_link_libraries(vulkan_executor_runner ${_executor_runner_libs}) - target_link_libraries(vulkan_executor_runner vulkan_schema) - target_link_libraries(vulkan_executor_runner vulkan_backend) + target_link_libraries( + vulkan_executor_runner ${_executor_runner_libs} vulkan_schema + vulkan_backend ${VULKAN_STANDARD_OPS_LIBS}) target_compile_options(vulkan_executor_runner PUBLIC ${VULKAN_CXX_FLAGS}) add_library(vulkan_executor_runner_lib STATIC ${VULKAN_RUNNER_SRCS}) - target_link_libraries(vulkan_executor_runner_lib ${_executor_runner_libs}) - target_link_libraries(vulkan_executor_runner_lib vulkan_schema) - target_link_libraries(vulkan_executor_runner_lib vulkan_backend) + target_link_libraries(vulkan_executor_runner_lib ${_executor_runner_libs} + vulkan_schema vulkan_backend) target_compile_options(vulkan_executor_runner_lib PUBLIC ${VULKAN_CXX_FLAGS}) endif() # Test targets if(EXECUTORCH_BUILD_GTESTS) + set(TEST_UTILS_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/test/utils) + file(GLOB TEST_UTILS_CPP ${CMAKE_CURRENT_SOURCE_DIR}/test/utils/*.cpp) + + set(TEST_SHADERS_PATH ${CMAKE_CURRENT_SOURCE_DIR}/test/glsl) + vulkan_shader_library(${TEST_SHADERS_PATH} test_shaderlib) + # vulkan_compute_api_test - set(TEST_CPP ${CMAKE_CURRENT_SOURCE_DIR}/test/vulkan_compute_api_test.cpp) - add_executable(vulkan_compute_api_test ${TEST_CPP}) - target_include_directories(vulkan_compute_api_test PRIVATE ${COMMON_INCLUDES}) - target_link_libraries(vulkan_compute_api_test vulkan_api_lib) - target_link_libraries(vulkan_compute_api_test vulkan_graph_lib) - target_link_libraries(vulkan_compute_api_test gtest_main) + set(COMPUTE_API_TEST_CPP + ${CMAKE_CURRENT_SOURCE_DIR}/test/vulkan_compute_api_test.cpp) + + add_executable(vulkan_compute_api_test ${COMPUTE_API_TEST_CPP} + ${TEST_UTILS_CPP}) + target_include_directories(vulkan_compute_api_test + PRIVATE ${COMMON_INCLUDES} ${TEST_UTILS_HEADERS}) + target_link_libraries( + vulkan_compute_api_test PRIVATE gtest_main ${VULKAN_STANDARD_OPS_LIBS} + test_shaderlib) target_compile_options(vulkan_compute_api_test PRIVATE ${VULKAN_CXX_FLAGS}) endif() diff --git a/backends/vulkan/TARGETS b/backends/vulkan/TARGETS index 86733510a3..5bd6cf12f6 100644 --- a/backends/vulkan/TARGETS +++ b/backends/vulkan/TARGETS @@ -3,7 +3,7 @@ load(":targets.bzl", "define_common_targets") oncall("executorch") -define_common_targets() +define_common_targets(is_fbcode = True) runtime.python_library( name = "vulkan_preprocess", diff --git a/backends/vulkan/cmake/ATenVulkan.cmake b/backends/vulkan/cmake/ATenVulkan.cmake index 6db7868150..41a28823f6 100644 --- a/backends/vulkan/cmake/ATenVulkan.cmake +++ b/backends/vulkan/cmake/ATenVulkan.cmake @@ -8,7 +8,7 @@ # # This file should be formatted with # ~~~ -# cmake-format --first-comment-is-literal=True CMakeLists.txt +# cmake-format --first-comment-is-literal=True -i ATenVulkan.cmake # ~~~ # It should also be cmake-lint clean. # @@ -22,20 +22,6 @@ if(NOT VULKAN_THIRD_PARTY_PATH) set(VULKAN_THIRD_PARTY_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../third-party) endif() -# Shader Codegen - -# Trigger Shader code generation -set(USE_VULKAN ON) -set(VULKAN_CODEGEN_CMAKE_PATH ${PYTORCH_PATH}/cmake/VulkanCodegen.cmake) -if(NOT EXISTS ${VULKAN_CODEGEN_CMAKE_PATH}) - message( - FATAL_ERROR - "Cannot perform SPIR-V codegen because " ${VULKAN_CODEGEN_CMAKE_PATH} - " does not exist. Please make sure that submodules are initialized" - " and updated.") -endif() -include(${PYTORCH_PATH}/cmake/VulkanCodegen.cmake) - # Source paths and compile settings set(ATEN_PATH ${PYTORCH_PATH}/aten/src) @@ -65,19 +51,56 @@ list(APPEND VULKAN_API_HEADERS ${VOLK_PATH}) list(APPEND VULKAN_API_HEADERS ${VMA_PATH}) target_include_directories(vulkan_api_lib PRIVATE ${VULKAN_API_HEADERS}) - target_compile_options(vulkan_api_lib PRIVATE ${VULKAN_CXX_FLAGS}) -# vulkan_shader_lib +# Find GLSL compiler executable + +if(ANDROID) + if(NOT ANDROID_NDK) + message(FATAL_ERROR "ANDROID_NDK not set") + endif() + + set(GLSLC_PATH + "${ANDROID_NDK}/shader-tools/${ANDROID_NDK_HOST_SYSTEM_NAME}/glslc") +else() + find_program( + GLSLC_PATH glslc + PATHS ENV VULKAN_SDK + PATHS "$ENV{VULKAN_SDK}/${CMAKE_HOST_SYSTEM_PROCESSOR}/bin" + PATHS "$ENV{VULKAN_SDK}/bin") + + if(NOT GLSLC_PATH) + message(FATAL_ERROR "USE_VULKAN glslc not found") + endif() +endif() + +# Required to enable linking with --whole-archive +include(${EXECUTORCH_ROOT}/build/Utils.cmake) -file(GLOB VULKAN_IMPL_CPP ${ATEN_VULKAN_PATH}/impl/*.cpp) +# Convenience macro to create a shader library -add_library(vulkan_shader_lib STATIC ${VULKAN_IMPL_CPP} ${vulkan_generated_cpp}) +macro(vulkan_shader_library SHADERS_PATH LIBRARY_NAME) + set(VULKAN_SHADERGEN_ENV "") + set(VULKAN_SHADERGEN_OUT_PATH ${CMAKE_BINARY_DIR}/${LIBRARY_NAME}) -list(APPEND VULKAN_API_HEADERS ${CMAKE_BINARY_DIR}/vulkan) + execute_process( + COMMAND + "${PYTHON_EXECUTABLE}" ${PYTORCH_PATH}/tools/gen_vulkan_spv.py --glsl-path + ${SHADERS_PATH} --output-path ${VULKAN_SHADERGEN_OUT_PATH} + --glslc-path=${GLSLC_PATH} --tmp-dir-path=${VULKAN_SHADERGEN_OUT_PATH} + --env ${VULKAN_GEN_ARG_ENV} + RESULT_VARIABLE error_code) + set(ENV{PYTHONPATH} ${PYTHONPATH}) -target_include_directories(vulkan_shader_lib PRIVATE ${VULKAN_API_HEADERS}) + set(vulkan_generated_cpp ${VULKAN_SHADERGEN_OUT_PATH}/spv.cpp) -target_link_libraries(vulkan_shader_lib vulkan_api_lib) + add_library(${LIBRARY_NAME} STATIC ${vulkan_generated_cpp}) + target_include_directories(${LIBRARY_NAME} PRIVATE ${COMMON_INCLUDES}) + target_link_libraries(${LIBRARY_NAME} vulkan_api_lib) + target_compile_options(${LIBRARY_NAME} PRIVATE ${VULKAN_CXX_FLAGS}) + # Link this library with --whole-archive due to dynamic shader registrations + target_link_options_shared_lib(${LIBRARY_NAME}) -target_compile_options(vulkan_shader_lib PRIVATE ${VULKAN_CXX_FLAGS}) + unset(VULKAN_SHADERGEN_ENV) + unset(VULKAN_SHADERGEN_OUT_PATH) +endmacro() diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 84c9a132e2..12c8696c21 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import operator from typing import final, List, Optional import torch @@ -30,6 +31,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.pow.Tensor_Tensor, + operator.getitem, ] return supported diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 9c554a232c..62555adc73 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -23,12 +24,15 @@ #include /* strtol */ #include #include +#include namespace torch { namespace executor { namespace vulkan { namespace { +using namespace at::native::vulkan; + // Flatbuffer types using VkGraphPtr = const vkgraph::VkGraph*; using OpCallPtr = const vkgraph::OperatorCall*; @@ -51,102 +55,274 @@ const uint8_t* getConstantDataPtr( return constant_data + constant_bytes->offset(); } -using namespace at::native::vulkan; +api::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) { + switch (vk_datatype) { + case vkgraph::VkDataType::BOOL: + return api::kBool; + case vkgraph::VkDataType::UINT8: + return api::kByte; + case vkgraph::VkDataType::INT8: + return api::kChar; + case vkgraph::VkDataType::INT32: + return api::kInt; + case vkgraph::VkDataType::FLOAT16: + return api::kHalf; + case vkgraph::VkDataType::FLOAT32: + return api::kFloat; + } +} + +class GraphBuilder { + ComputeGraph* compute_graph_; + VkGraphPtr flatbuffer_; + const uint8_t* constant_data_; + + std::unordered_map ref_mapping_; -class VulkanBackend final : public PyTorchBackendInterface { public: - ~VulkanBackend() override = default; + explicit GraphBuilder( + ComputeGraph* compute_graph, + VkGraphPtr flatbuffer, + const uint8_t* constant_data) + : compute_graph_(compute_graph), + flatbuffer_(flatbuffer), + constant_data_(constant_data), + ref_mapping_() {} + + bool fb_id_exists(const uint32_t fb_id) { + const std::unordered_map::iterator found_ref = + ref_mapping_.find(fb_id); - bool is_available() const override { - return true; + return found_ref != ref_mapping_.end(); } - api::ScalarType get_scalar_type( - const vkgraph::VkDataType& vk_datatype) const { - switch (vk_datatype) { - case (vkgraph::VkDataType::fp32): { - return api::kFloat; - } - } + ValueRef get_fb_id_valueref(const uint32_t fb_id) { + const std::unordered_map::iterator found_ref = + ref_mapping_.find(fb_id); + + ET_CHECK_MSG( + found_ref != ref_mapping_.end(), + "Trying to extract a value that hasn't yet been added to the graph."); + + return found_ref->second; } - ValueRef get_value_ref( - const uint32_t value_id, - VkGraphPtr flatbuffer_graph, - ComputeGraph* compute_graph, - std::unordered_map& ref_mapping, - VkValuesVector value_mapping, - const uint8_t* constant_data) const { - const std::unordered_map::iterator found_ref = - ref_mapping.find(value_id); + void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) { + const api::ScalarType& dtype = get_scalar_type(tensor_fb->datatype()); + + UIntVector dims_fb = tensor_fb->dims(); + const std::vector dims_vector(dims_fb->cbegin(), dims_fb->cend()); + + ValueRef ref; + if (tensor_fb->constant_id() >= 0) { + const uint8_t* tensor_data = getConstantDataPtr( + flatbuffer_, tensor_fb->constant_id(), constant_data_); - if (found_ref != ref_mapping.end()) { - return found_ref->second; + ref = compute_graph_->add_tensorref(dims_vector, dtype, tensor_data); + } else { + ref = compute_graph_->add_tensor( + dims_vector, dtype, tensor_fb->mem_obj_id()); } - VkValuePtr vk_value = value_mapping->Get(value_id); - VkTensorPtr vk_tensor = vk_value->value(); + ref_mapping_[fb_id] = ref; + } + + template + typename std::enable_if::value, void>::type + add_scalar_to_graph(const uint32_t fb_id, T value) { + ValueRef ref = compute_graph_->add_scalar(value); + ref_mapping_[fb_id] = ref; + } + + template + typename std::enable_if::value, void>::type + add_scalar_list_to_graph(const uint32_t fb_id, std::vector&& value) { + ValueRef ref = compute_graph_->add_scalar_list(std::move(value)); + ref_mapping_[fb_id] = ref; + } + + void add_value_list_to_graph( + const uint32_t fb_id, + std::vector&& value) { + ValueRef ref = compute_graph_->add_value_list(std::move(value)); + ref_mapping_[fb_id] = ref; + } + + void add_string_to_graph(const uint32_t fb_id, VkValuePtr value) { + const auto fb_str = value->value_as_String()->string_val(); + std::string string(fb_str->cbegin(), fb_str->cend()); + ValueRef ref = compute_graph_->add_string(std::move(string)); + ref_mapping_[fb_id] = ref; + } + void add_value_to_graph(const uint32_t fb_id, VkValuePtr value) { ET_CHECK_MSG( - vk_tensor->constant_id() >= 0, - "Only constant buffers are supported when adding tensors to compute graph (indicated by constant_id < 0), but got constant_id of %d", - vk_tensor->constant_id()); + !fb_id_exists(fb_id), + "Trying to add a value that has already been added to the graph."); + + switch (value->value_type()) { + case vkgraph::GraphTypes::Int: + add_scalar_to_graph(fb_id, value->value_as_Int()->int_val()); + break; + case vkgraph::GraphTypes::Double: + add_scalar_to_graph(fb_id, value->value_as_Double()->double_val()); + break; + case vkgraph::GraphTypes::Bool: + add_scalar_to_graph(fb_id, value->value_as_Bool()->bool_val()); + break; + case vkgraph::GraphTypes::VkTensor: + add_tensor_to_graph(fb_id, value->value_as_VkTensor()); + break; + case vkgraph::GraphTypes::IntList: + add_scalar_list_to_graph( + fb_id, + std::vector( + value->value_as_IntList()->items()->cbegin(), + value->value_as_IntList()->items()->cend())); + break; + case vkgraph::GraphTypes::DoubleList: + add_scalar_list_to_graph( + fb_id, + std::vector( + value->value_as_DoubleList()->items()->cbegin(), + value->value_as_DoubleList()->items()->cend())); + break; + case vkgraph::GraphTypes::BoolList: + add_scalar_list_to_graph( + fb_id, + std::vector( + value->value_as_BoolList()->items()->cbegin(), + value->value_as_BoolList()->items()->cend())); + break; + case vkgraph::GraphTypes::ValueList: + add_value_list_to_graph( + fb_id, + std::vector( + value->value_as_ValueList()->items()->cbegin(), + value->value_as_ValueList()->items()->cend())); + break; + case vkgraph::GraphTypes::String: + add_string_to_graph(fb_id, value); + break; + default: + ET_CHECK_MSG(false, "Unsupported value type."); + } + } - const api::ScalarType& tensor_dtype = - get_scalar_type(vk_tensor->datatype()); + void build_graph() { + // First, add all values to the graph + for (uint32_t fb_id = 0; fb_id < flatbuffer_->values()->size(); ++fb_id) { + VkValuePtr value = flatbuffer_->values()->Get(fb_id); + add_value_to_graph(fb_id, value); + } - UIntVector tensor_dims_fb = vk_tensor->dims(); - const std::vector tensor_dims_vector( - tensor_dims_fb->cbegin(), tensor_dims_fb->cend()); + // Parse the inputs + for (const uint32_t fb_id : *flatbuffer_->input_ids()) { + const ValueRef ref = get_fb_id_valueref(fb_id); + compute_graph_->set_input_tensor(ref); + } - const uint8_t* tensor_data = getConstantDataPtr( - flatbuffer_graph, vk_tensor->constant_id(), constant_data); + // Parse the operators + for (OpCallPtr op_call : *(flatbuffer_->chain())) { + std::string op_name = op_call->name()->str(); + ET_CHECK_MSG(VK_HAS_OP(op_name), "Missing operator: %s", op_name.c_str()); - const ValueRef value_ref = compute_graph->add_tensorref( - tensor_dims_vector, tensor_dtype, tensor_data); + const std::vector arg_fb_ids( + op_call->args()->cbegin(), op_call->args()->cend()); - ref_mapping[value_id] = value_ref; + std::vector args; + for (const int arg_fb_id : arg_fb_ids) { + args.push_back(get_fb_id_valueref(arg_fb_id)); + } - return value_ref; + auto vkFn = VK_GET_OP_FN(op_name); + vkFn(*compute_graph_, args); + } + + // Parse the outputs + for (const uint32_t fb_id : *flatbuffer_->output_ids()) { + const ValueRef ref = get_fb_id_valueref(fb_id); + compute_graph_->set_output_tensor(ref); + } + } +}; + +// +// Execution tools +// + +bool maybe_resize_input( + ComputeGraph* graph, + const size_t input_i, + exec_aten::Tensor& et_tensor) { + ValueRef in_tensor_ref = graph->inputs()[input_i].value; + vTensor& in_tensor = graph->get_val(in_tensor_ref).toTensor(); + + ET_CHECK_MSG( + et_tensor.dim() == in_tensor.sizes().size(), + "Cannot resize input tensor: old ndim %zu does not match new ndim %zu", + static_cast(in_tensor.sizes().size()), + static_cast(et_tensor.dim())); + + bool should_resize = false; + std::vector new_sizes(et_tensor.dim()); + for (size_t i = 0; i < et_tensor.dim(); i++) { + if (in_tensor.sizes()[i] != et_tensor.sizes()[i]) { + should_resize = true; + } + new_sizes.at(i) = et_tensor.sizes()[i]; + } + + if (should_resize) { + graph->resize_input(input_i, new_sizes); } - GraphConfig generate_config() const { - const uint32_t submit_frequency = UINT32_MAX; - - const api::CommandPoolConfig cmd_config{ - 4u, // cmdPoolInitialSize - 2u, // cmdPoolBatchSize - }; - - const api::DescriptorPoolConfig descriptor_pool_config{ - 1024u, // descriptorPoolMaxSets - 1024u, // descriptorUniformBufferCount - 1024u, // descriptorStorageBufferCount - 1024u, // descriptorCombinedSamplerCount - 1024u, // descriptorStorageImageCount - 32u, // descriptorPileSizes - }; - - const api::QueryPoolConfig query_pool_config{}; - - const api::ContextConfig context_config{ - submit_frequency, // cmdSubmitFrequency - cmd_config, // cmdPoolConfig - descriptor_pool_config, // descriptorPoolConfig - query_pool_config, // queryPoolConfig - }; - - const GraphConfig graph_config{ - context_config, - }; - - return graph_config; + ET_CHECK_MSG( + in_tensor.numel() == et_tensor.numel(), + "Vulkan tensor numel %zu does not match ET tensor numel %zu", + static_cast(in_tensor.numel()), + static_cast(et_tensor.numel())); + + return should_resize; +} + +void maybe_resize_output( + ComputeGraph* graph, + const size_t output_i, + exec_aten::Tensor& et_tensor) { + ValueRef out_tensor_ref = graph->outputs()[output_i].value; + vTensor& out_tensor = graph->get_val(out_tensor_ref).toTensor(); + + exec_aten::SizesType new_output_size[kTensorDimensionLimit]; + size_t ndim = out_tensor.sizes().size(); + for (int i = 0; i < ndim; ++i) { + new_output_size[i] = out_tensor.sizes()[i]; + } + + exec_aten::ArrayRef output_size{new_output_size, ndim}; + Error err = resize_tensor(et_tensor, output_size); + + ET_CHECK_MSG(err == Error::Ok, "Failed to resize output tensor."); +} + +// +// VulkanBackend class +// + +class VulkanBackend final : public PyTorchBackendInterface { + public: + ~VulkanBackend() override = default; + + bool is_available() const override { + // TODO(ssjia): replace with an actual Vulkan runtime availability check + return true; } __ET_NODISCARD Error compileModel(const void* buffer_pointer, ComputeGraph* compute_graph) const { Result header = VulkanDelegateHeader::Parse(buffer_pointer); + const uint8_t* flatbuffer_data = nullptr; const uint8_t* constant_data = nullptr; @@ -169,92 +345,12 @@ class VulkanBackend final : public PyTorchBackendInterface { VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data); - // Mapping from serialized VkValue ids to compute graph ValueRefs - // This will be populated as the compute graph is built - std::unordered_map ref_mapping; - - // A vector which acts as a mapping from VkValue ids (vector indices) to - // VkValues - VkValuesVector value_mapping = flatbuffer_graph->values(); - - // 1. Add all inputs (and corresponding tensors) to the compute graph - UIntVector input_ids = flatbuffer_graph->input_ids(); - - for (size_t input_index = 0; input_index < input_ids->size(); - ++input_index) { - const uint32_t input_id = input_ids->Get(input_index); - VkValuePtr input_vk_value = value_mapping->Get(input_id); - - VkTensorPtr input_vk_tensor = input_vk_value->value(); - - ET_CHECK_MSG( - input_vk_tensor->constant_id() < 0, - "Expected constant buffer index for input at index %zu with id %d to be < 0 (since it is non-constant), but got: %d", - input_index, - input_id, - input_vk_tensor->constant_id()); - - const api::ScalarType& input_dtype = - get_scalar_type(input_vk_tensor->datatype()); + GraphBuilder builder = + GraphBuilder(compute_graph, flatbuffer_graph, constant_data); - UIntVector input_dims_fb = input_vk_tensor->dims(); - const std::vector input_dims_vector( - input_dims_fb->cbegin(), input_dims_fb->cend()); + builder.build_graph(); - const ValueRef input_ref = compute_graph->add_tensor( - input_dims_vector, input_dtype, input_vk_tensor->mem_obj_id()); - - ref_mapping[input_id] = input_ref; - compute_graph->set_input_tensor(input_ref); - } - - // 2. Add all ops to the graph - // TODO: Generalize for ops that don't have 2 inputs and 1 output. - for (OpCallPtr op_call : *(flatbuffer_graph->chain())) { - std::string op_name = op_call->name()->str(); - - ET_CHECK_MSG( - op_call->args() != nullptr && op_call->args()->size() == 3, - "Vulkan currently only supports OperatorCall with 3 args"); - const auto arg_ids = op_call->args()->data(); - - const uint32_t input1_id = arg_ids[0]; - const uint32_t input2_id = arg_ids[1]; - const uint32_t output_id = arg_ids[2]; - - const ValueRef input1_ref = get_value_ref( - input1_id, - flatbuffer_graph, - compute_graph, - ref_mapping, - value_mapping, - constant_data); - - const ValueRef input2_ref = get_value_ref( - input2_id, - flatbuffer_graph, - compute_graph, - ref_mapping, - value_mapping, - constant_data); - - ET_CHECK_MSG(hasOpsFn(op_name), "Missing operator: %s", op_name.c_str()); - auto vkFn = getOpsFn(op_name); - const at::native::vulkan::ValueRef output_ref = vkFn( - *compute_graph, - {input1_ref, - input2_ref, - 1, - value_mapping->Get(output_id)->value()->mem_obj_id()}); - - ref_mapping[output_id] = output_ref; - } - - // 3. Add all outputs to the compute graph - for (const uint32_t output_id : *flatbuffer_graph->output_ids()) { - const ValueRef output_ref = ref_mapping[output_id]; - compute_graph->set_output_tensor(output_ref); - } + compute_graph->prepare(); compute_graph->encode_prepack(); compute_graph->prepack(); @@ -271,7 +367,7 @@ class VulkanBackend final : public PyTorchBackendInterface { ComputeGraph* compute_graph = ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR( context.get_runtime_allocator(), ComputeGraph); - new (compute_graph) ComputeGraph(generate_config()); + new (compute_graph) ComputeGraph(GraphConfig()); Error err = compileModel(processed->data(), compute_graph); @@ -291,20 +387,28 @@ class VulkanBackend final : public PyTorchBackendInterface { ComputeGraph* compute_graph = static_cast(handle); const size_t num_inputs = compute_graph->inputs().size(); + bool should_propagate_resize = false; for (size_t i = 0; i < num_inputs; i++) { + bool was_resized = + maybe_resize_input(compute_graph, i, args[i]->toTensor()); + should_propagate_resize = should_propagate_resize || was_resized; compute_graph->copy_into_staging( - compute_graph->inputs()[i], + compute_graph->inputs()[i].staging, args[i]->toTensor().const_data_ptr(), args[i]->toTensor().numel()); } + if (should_propagate_resize) { + compute_graph->propagate_resize(); + } compute_graph->execute(); for (size_t i = 0; i < compute_graph->outputs().size(); i++) { + maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor()); // args holds inputs directly followed by outputs, so the i'th output // for compute_graph corresponds to the (i + num_inputs)'th arg compute_graph->copy_from_staging( - compute_graph->outputs()[i], + compute_graph->outputs()[i].staging, args[num_inputs + i]->toTensor().mutable_data_ptr(), args[num_inputs + i]->toTensor().numel()); } diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 5adb5691e3..d51dc99689 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -6,16 +6,23 @@ * LICENSE file in the root directory of this source tree. */ +// @lint-ignore-every CLANGTIDY +// facebook-security-vulnerable-integer-sign-conversion + #include #include +#include + namespace at { namespace native { namespace vulkan { ComputeGraph::ComputeGraph(GraphConfig config) : config_{config}, + prepack_descriptor_counts_{}, + execute_descriptor_counts_{}, context_{new api::Context( api::runtime()->default_adapter_i(), config_.contextConfig)}, @@ -25,6 +32,19 @@ ComputeGraph::ComputeGraph(GraphConfig config) execute_nodes_{}, inputs_{}, outputs_{} { + // Ensure that descriptor counts are initialized to 0 + prepack_descriptor_counts_.descriptorPoolMaxSets = 0; + prepack_descriptor_counts_.descriptorUniformBufferCount = 0; + prepack_descriptor_counts_.descriptorStorageBufferCount = 0; + prepack_descriptor_counts_.descriptorCombinedSamplerCount = 0; + prepack_descriptor_counts_.descriptorStorageImageCount = 0; + + execute_descriptor_counts_.descriptorPoolMaxSets = 0; + execute_descriptor_counts_.descriptorUniformBufferCount = 0; + execute_descriptor_counts_.descriptorStorageBufferCount = 0; + execute_descriptor_counts_.descriptorCombinedSamplerCount = 0; + execute_descriptor_counts_.descriptorStorageImageCount = 0; + context_->set_cmd(/*reusable = */ true); } @@ -37,6 +57,33 @@ ComputeGraph::~ComputeGraph() { context_->flush(); } +void ComputeGraph::update_descriptor_counts( + const api::ShaderInfo& shader_info, + bool execute) { + api::DescriptorPoolConfig* config = + execute ? &execute_descriptor_counts_ : &prepack_descriptor_counts_; + + config->descriptorPoolMaxSets += 1; + for (const VkDescriptorType arg_type : shader_info.kernel_layout) { + switch (arg_type) { + case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER: + config->descriptorUniformBufferCount += 1; + break; + case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER: + config->descriptorStorageBufferCount += 1; + break; + case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER: + config->descriptorCombinedSamplerCount += 1; + break; + case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE: + config->descriptorStorageImageCount += 1; + break; + default: + VK_THROW("Unsupported descriptor type!"); + } + } +} + ValueRef ComputeGraph::add_tensor( const std::vector& sizes, const api::ScalarType dtype, @@ -75,17 +122,29 @@ ValueRef ComputeGraph::add_staging( return idx; } +ValueRef ComputeGraph::add_value_list(std::vector&& value) { + ValueRef idx(static_cast(values_.size())); + values_.emplace_back(std::move(value)); + return idx; +} + +ValueRef ComputeGraph::add_string(std::string&& str) { + ValueRef idx(static_cast(values_.size())); + values_.emplace_back(std::move(str)); + return idx; +} + ValueRef ComputeGraph::set_input_tensor( const ValueRef idx, const bool use_staging) { if (use_staging) { vTensor& tensor = get_val(idx).toTensor(); ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel()); - execute_nodes_.emplace_back(new StagingNode(staging_idx, idx)); - inputs_.push_back(staging_idx); + add_staging_to_tensor_node(*this, staging_idx, idx); + inputs_.push_back({idx, staging_idx}); return staging_idx; } - inputs_.push_back(idx); + inputs_.push_back({idx, kDummyValueRef}); return idx; } @@ -95,11 +154,11 @@ ValueRef ComputeGraph::set_output_tensor( if (use_staging) { vTensor& tensor = get_val(idx).toTensor(); ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel()); - execute_nodes_.emplace_back(new StagingNode(idx, staging_idx)); - outputs_.push_back(staging_idx); + add_tensor_to_staging_node(*this, idx, staging_idx); + outputs_.push_back({idx, staging_idx}); return staging_idx; } - outputs_.push_back(idx); + outputs_.push_back({idx, kDummyValueRef}); return idx; } @@ -130,6 +189,30 @@ void ComputeGraph::copy_from_staging( copy_staging_to_ptr(staging, data, nbytes); } +void ComputeGraph::prepare() { +#define MERGE_FIELD(field) \ + static_cast(std::ceil( \ + std::max( \ + execute_descriptor_counts_.field, \ + prepack_descriptor_counts_.field) * \ + config_.descriptorPoolSafetyFactor)) + + uint32_t max_sets = MERGE_FIELD(descriptorPoolMaxSets); + api::DescriptorPoolConfig config{ + max_sets, + std::max(MERGE_FIELD(descriptorUniformBufferCount), max_sets), + std::max(MERGE_FIELD(descriptorStorageBufferCount), max_sets), + std::max(MERGE_FIELD(descriptorCombinedSamplerCount), max_sets), + std::max(MERGE_FIELD(descriptorStorageImageCount), max_sets), + 1u, + }; + + if (!context_->descriptor_pool()) { + context_->descriptor_pool().init(config); + } +#undef MERGE_FIELD +} + void ComputeGraph::encode_prepack() { for (std::unique_ptr& node : prepack_nodes_) { node->encode(this); @@ -165,6 +248,19 @@ void ComputeGraph::execute() const { fence.wait(); } +void ComputeGraph::resize_input( + const int64_t idx, + const std::vector& new_sizes) { + IOValueRef io_val = inputs_.at(idx); + get_val(io_val.value).toTensor().virtual_resize(new_sizes); +} + +void ComputeGraph::propagate_resize() { + for (std::unique_ptr& node : execute_nodes_) { + node->trigger_resize(this); + } +} + } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index ec8d3ba1db..b5b2749dfb 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -12,9 +12,7 @@ #ifdef USE_VULKAN_API -#include -#include -#include +#include #include @@ -28,6 +26,19 @@ namespace at { namespace native { namespace vulkan { +// Define valid scalar types that the Value class can accept +template +struct is_valid_scalar_type : std::false_type {}; + +template <> +struct is_valid_scalar_type : std::true_type {}; + +template <> +struct is_valid_scalar_type : std::true_type {}; + +template <> +struct is_valid_scalar_type : std::true_type {}; + /* * This is the core data structure used to execute Vulkan models in graph mode. * As opposed to ATen/eager mode where a command buffer is encoded every @@ -47,6 +58,9 @@ class ComputeGraph final { private: GraphConfig config_; + api::DescriptorPoolConfig prepack_descriptor_counts_; + api::DescriptorPoolConfig execute_descriptor_counts_; + std::unique_ptr context_; std::vector shared_objects_; std::vector values_; @@ -54,8 +68,8 @@ class ComputeGraph final { std::vector> prepack_nodes_; std::vector> execute_nodes_; - std::vector inputs_; - std::vector outputs_; + std::vector inputs_; + std::vector outputs_; public: // @@ -66,14 +80,18 @@ class ComputeGraph final { return context_.get(); } - inline std::vector& inputs() { + inline std::vector& inputs() { return inputs_; } - inline std::vector& outputs() { + inline std::vector& outputs() { return outputs_; } + void update_descriptor_counts( + const api::ShaderInfo& shader_info, + bool execute); + /* * Returns the value at a particular reference */ @@ -123,9 +141,27 @@ class ComputeGraph final { const void* const data); ValueRef add_staging(const api::ScalarType dtype, const size_t numel); + template + typename std::enable_if::value, ValueRef>::type + add_scalar(T value); + + template + typename std::enable_if::value, ValueRef>::type + add_scalar_list(std::vector&& value); + + ValueRef add_value_list(std::vector&& value); + + ValueRef add_string(std::string&& str); + ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true); ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true); + template + inline std::shared_ptr create_params_buffer( + const Block& data) { + return std::make_shared(context_.get(), data); + } + /* * Convenience function to add an input tensor along with its staging buffer */ @@ -140,6 +176,12 @@ class ComputeGraph final { SharedObject& get_shared_object(const int64_t idx); + // + // Graph Preparation + // + + void prepare(); + // // Input/Output // @@ -161,8 +203,31 @@ class ComputeGraph final { void encode_execute(); void execute() const; + + // + // Dynamic Shape support + // + + void resize_input(const int64_t idx, const std::vector& new_sizes); + void propagate_resize(); }; +template +inline typename std::enable_if::value, ValueRef>::type +ComputeGraph::add_scalar(T value) { + ValueRef idx(static_cast(values_.size())); + values_.emplace_back(value); + return idx; +} + +template +inline typename std::enable_if::value, ValueRef>::type +ComputeGraph::add_scalar_list(std::vector&& value) { + ValueRef idx(static_cast(values_.size())); + values_.emplace_back(std::move(value)); + return idx; +} + } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/GraphConfig.cpp b/backends/vulkan/runtime/graph/GraphConfig.cpp new file mode 100644 index 0000000000..8cda518dae --- /dev/null +++ b/backends/vulkan/runtime/graph/GraphConfig.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace at { +namespace native { +namespace vulkan { + +GraphConfig::GraphConfig() { + // No automatic submissions + const uint32_t submit_frequency = UINT32_MAX; + + // Only one command buffer will be encoded at a time + const api::CommandPoolConfig cmd_config{ + 1u, // cmdPoolInitialSize + 1u, // cmdPoolBatchSize + }; + + // Use lazy descriptor pool initialization by default; the graph runtime will + // tally up the number of descriptor sets needed while building the graph and + // trigger descriptor pool initialization with exact sizes before encoding the + // command buffer. + const api::DescriptorPoolConfig descriptor_pool_config{ + 0u, // descriptorPoolMaxSets + 0u, // descriptorUniformBufferCount + 0u, // descriptorStorageBufferCount + 0u, // descriptorCombinedSamplerCount + 0u, // descriptorStorageImageCount + 0u, // descriptorPileSizes + }; + + const api::QueryPoolConfig query_pool_config{}; + + const api::ContextConfig context_config{ + submit_frequency, // cmdSubmitFrequency + cmd_config, // cmdPoolConfig + descriptor_pool_config, // descriptorPoolConfig + query_pool_config, // queryPoolConfig + }; + + contextConfig = context_config; + + // Empirically selected safety factor. If descriptor pools start running out + // of memory, increase this safety factor. + descriptorPoolSafetyFactor = 1.25; +} + +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/backends/vulkan/runtime/graph/GraphConfig.h b/backends/vulkan/runtime/graph/GraphConfig.h index 0cb9bb6f53..ce0f0839a9 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.h +++ b/backends/vulkan/runtime/graph/GraphConfig.h @@ -10,7 +10,7 @@ #ifdef USE_VULKAN_API -#include +#include namespace at { namespace native { @@ -18,6 +18,16 @@ namespace vulkan { struct GraphConfig final { api::ContextConfig contextConfig; + + // Creating a descriptor pool with exactly the number of descriptors tallied + // by iterating through the shader layouts of shaders used in the graph risks + // the descriptor pool running out of memory, therefore apply a safety factor + // to descriptor counts when creating the descriptor pool to mitigate this + // risk. + float descriptorPoolSafetyFactor; + + // Generate a default graph config with pre-configured settings + explicit GraphConfig(); }; } // namespace vulkan diff --git a/backends/vulkan/runtime/graph/containers/Types.cpp b/backends/vulkan/runtime/graph/containers/Types.cpp index bbfde572b0..0779ed8716 100644 --- a/backends/vulkan/runtime/graph/containers/Types.cpp +++ b/backends/vulkan/runtime/graph/containers/Types.cpp @@ -12,20 +12,25 @@ namespace at { namespace native { namespace vulkan { +#define PRINT_CASE(name) \ + case TypeTag::name: \ + out << #name; \ + break; + std::ostream& operator<<(std::ostream& out, const TypeTag& tag) { switch (tag) { - case TypeTag::NONE: - out << "NONE"; - break; - case TypeTag::TENSOR: - out << "TENSOR"; - break; - case TypeTag::STAGING: - out << "STAGING"; - break; - default: - out << "UNKNOWN"; - break; + PRINT_CASE(NONE) + PRINT_CASE(INT) + PRINT_CASE(DOUBLE) + PRINT_CASE(BOOL) + PRINT_CASE(TENSOR) + PRINT_CASE(STAGING) + PRINT_CASE(TENSORREF) + PRINT_CASE(INTLIST) + PRINT_CASE(DOUBLELIST) + PRINT_CASE(BOOLLIST) + PRINT_CASE(VALUELIST) + PRINT_CASE(STRING) } return out; } diff --git a/backends/vulkan/runtime/graph/containers/Types.h b/backends/vulkan/runtime/graph/containers/Types.h index a7162d777a..d5dee7ea0d 100644 --- a/backends/vulkan/runtime/graph/containers/Types.h +++ b/backends/vulkan/runtime/graph/containers/Types.h @@ -23,12 +23,21 @@ namespace vulkan { */ enum class TypeTag : uint32_t { NONE, - TENSOR, - STAGING, - TENSORREF, + // Scalar types INT, DOUBLE, BOOL, + // Tensor and tensor adjacent types + TENSOR, + STAGING, + TENSORREF, + // Scalar lists + INTLIST, + DOUBLELIST, + BOOLLIST, + // Special Type + VALUELIST, + STRING, }; std::ostream& operator<<(std::ostream& out, const TypeTag& tag); diff --git a/backends/vulkan/runtime/graph/containers/Value.h b/backends/vulkan/runtime/graph/containers/Value.h index d56791b4fa..82ba941713 100644 --- a/backends/vulkan/runtime/graph/containers/Value.h +++ b/backends/vulkan/runtime/graph/containers/Value.h @@ -22,6 +22,19 @@ namespace at { namespace native { namespace vulkan { +using ValueRef = int32_t; + +constexpr ValueRef kDummyValueRef = -1; + +inline bool is_valid(ValueRef value_ref) { + return value_ref >= 0; +} + +struct IOValueRef { + ValueRef value; + ValueRef staging; +}; + /* * This class is modelled after c10::IValue; however, it is simplified and does * not support as many types. However, the core design is the same; it is a @@ -48,6 +61,17 @@ struct Value final { api::StorageBuffer as_staging; TensorRef as_tensorref; + std::vector as_int_list; + std::vector as_double_list; + std::vector as_bool_list; + + // The below is a special type that is used to represent a list of other + // values stored in the graph. One application of the type is to represent + // a list of tensors or a list of optional tensors. + std::vector as_value_list; + + std::string as_string; + Payload() : u() {} // NOLINTNEXTLINE ~Payload(){}; @@ -68,21 +92,48 @@ struct Value final { Value& operator=(Value&&) = delete; +#define CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(type_tag, member_name) \ + case type_tag: \ + payload.u.member_name = rhs.payload.u.member_name; \ + break; + +#define CASE_MOVE_MOVEABLE_TYPE(type_tag, type, member_name) \ + case type_tag: \ + new (&payload.member_name) type(std::move(rhs.payload.member_name)); \ + break; + Value(Value&& rhs) noexcept : tag(rhs.tag) { - if (rhs.isTensor()) { - new (&payload.as_tensor) vTensor(std::move(rhs.payload.as_tensor)); - } else if (rhs.isStaging()) { - new (&payload.as_staging) - api::StorageBuffer(std::move(rhs.payload.as_staging)); - } else if (rhs.isTensorRef()) { - payload.as_tensorref = std::move(rhs.payload.as_tensorref); - } else { - payload.u = rhs.payload.u; + switch (tag) { + // Scalar types + CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::INT, as_int); + CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::DOUBLE, as_double); + CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::BOOL, as_bool); + // Tensor and tensor adjacent types + CASE_MOVE_MOVEABLE_TYPE(TypeTag::TENSOR, vTensor, as_tensor); + CASE_MOVE_MOVEABLE_TYPE(TypeTag::STAGING, api::StorageBuffer, as_staging); + CASE_MOVE_MOVEABLE_TYPE(TypeTag::TENSORREF, TensorRef, as_tensorref); + // Scalar lists + CASE_MOVE_MOVEABLE_TYPE( + TypeTag::INTLIST, std::vector, as_int_list); + CASE_MOVE_MOVEABLE_TYPE( + TypeTag::DOUBLELIST, std::vector, as_double_list); + CASE_MOVE_MOVEABLE_TYPE( + TypeTag::BOOLLIST, std::vector, as_bool_list); + // Special types + CASE_MOVE_MOVEABLE_TYPE( + TypeTag::VALUELIST, std::vector, as_value_list); + CASE_MOVE_MOVEABLE_TYPE(TypeTag::STRING, std::string, as_string); + + case TypeTag::NONE: + clearToNone(); + break; } - tag = rhs.tag; rhs.clearToNone(); } +#undef CASE_MOVE_TRIVIALLY_COPYABLE_TYPE +#undef CASE_MOVE_MOVEABLE_TYPE + // // Accessors // @@ -96,77 +147,127 @@ struct Value final { // ~Value() { - if (this->isTensor()) { - payload.as_tensor.~vTensor(); - } else if (this->isStaging()) { - payload.as_staging.~StorageBuffer(); - } else if (this->isTensorRef()) { - payload.as_tensorref.~TensorRef(); + switch (tag) { + case TypeTag::TENSOR: + payload.as_tensor.~vTensor(); + break; + case TypeTag::STAGING: + payload.as_staging.~StorageBuffer(); + break; + case TypeTag::TENSORREF: + payload.as_tensorref.~TensorRef(); + break; + case TypeTag::INTLIST: + payload.as_int_list.~vector(); + break; + case TypeTag::DOUBLELIST: + payload.as_double_list.~vector(); + break; + case TypeTag::BOOLLIST: + payload.as_bool_list.~vector(); + break; + case TypeTag::VALUELIST: + payload.as_value_list.~vector(); + break; + case TypeTag::STRING: + payload.as_string.~basic_string(); + break; + // Manually list out the types so that if a type here is added later and + // not handled the compiler can catch it. + case TypeTag::NONE: + case TypeTag::INT: + case TypeTag::DOUBLE: + case TypeTag::BOOL: + break; } } - // - // Tensor - // - - explicit Value(vTensor&& t) : tag(TypeTag::TENSOR) { - new (&payload.as_tensor) vTensor(std::move(t)); - } - - inline bool isTensor() const { - return TypeTag::TENSOR == tag; - } - - inline vTensor& toTensor() { - VK_CHECK_COND( - isTensor(), - "Expected value to have type TENSOR, got ", - tag, - " instead."); - return payload.as_tensor; +#define SUPPORT_TRIVIALLY_COPYABLE_TYPE( \ + type, type_name, type_tag, member_name) \ + explicit Value(type t) : tag(type_tag) { \ + payload.u.member_name = t; \ + } \ + inline bool is##type_name() const { \ + return tag == type_tag; \ + } \ + inline const type& to##type_name() const { \ + VK_CHECK_COND( \ + is##type_name(), \ + "Expected value to have type " #type_name ", got ", \ + tag, \ + " instead."); \ + return payload.u.member_name; \ } - // - // Staging - // - - explicit Value(api::StorageBuffer&& t) : tag(TypeTag::STAGING) { - new (&payload.as_staging) api::StorageBuffer(std::move(t)); - } - - inline bool isStaging() const { - return TypeTag::STAGING == tag; - } - - inline api::StorageBuffer& toStaging() { - VK_CHECK_COND( - isStaging(), - "Expected value to have type STAGING, got ", - tag, - " instead."); - return payload.as_staging; - } - - // - // TensorRef - // - - explicit Value(TensorRef&& t) : tag(TypeTag::TENSORREF) { - payload.as_tensorref = std::move(t); - } - - inline bool isTensorRef() const { - return TypeTag::TENSORREF == tag; + SUPPORT_TRIVIALLY_COPYABLE_TYPE(int64_t, Int, TypeTag::INT, as_int); + SUPPORT_TRIVIALLY_COPYABLE_TYPE(double, Double, TypeTag::DOUBLE, as_double); + SUPPORT_TRIVIALLY_COPYABLE_TYPE(bool, Bool, TypeTag::BOOL, as_bool); + +#undef SUPPORT_TRIVIALLY_COPYABLE_TYPE + +#define SUPPORT_TRIVIALLY_MOVEABLE_TYPE( \ + type, type_name, type_tag, member_name) \ + explicit Value(type&& t) : tag(type_tag) { \ + new (&payload.member_name) type(std::move(t)); \ + } \ + inline bool is##type_name() const { \ + return tag == type_tag; \ + } \ + inline type& to##type_name() { \ + VK_CHECK_COND( \ + is##type_name(), \ + "Expected value to have type " #type_name ", got ", \ + tag, \ + " instead."); \ + return payload.member_name; \ } - inline TensorRef& toTensorRef() { - VK_CHECK_COND( - isTensorRef(), - "Expected value to have type TENSORREF, got ", - tag, - " instead."); - return payload.as_tensorref; - } + SUPPORT_TRIVIALLY_MOVEABLE_TYPE(vTensor, Tensor, TypeTag::TENSOR, as_tensor); + + SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + api::StorageBuffer, + Staging, + TypeTag::STAGING, + as_staging); + + SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + TensorRef, + TensorRef, + TypeTag::TENSORREF, + as_tensorref); + + SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + std::vector, + IntList, + TypeTag::INTLIST, + as_int_list); + + SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + std::vector, + DoubleList, + TypeTag::DOUBLELIST, + as_double_list); + + SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + std::vector, + BoolList, + TypeTag::BOOLLIST, + as_bool_list); + + SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + std::vector, + ValueList, + TypeTag::VALUELIST, + as_value_list); + + SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + std::string, + String, + TypeTag::STRING, + as_string); + +#undef SUPPORT_TRIVIALLY_COPYABLE_TYPE +#undef SUPPORT_TRIVIALLY_MOVEABLE_TYPE private: Payload payload; @@ -177,18 +278,11 @@ struct Value final { // inline void clearToNone() noexcept { - payload.u.as_int = 0; + payload.u.as_int = -1; tag = TypeTag::NONE; } }; -using ValueRef = int32_t; - -struct IOValueRef { - ValueRef value; - ValueRef staging; -}; - } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp index 6bdb07e719..e9d5ab18b4 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp @@ -10,12 +10,31 @@ #include -#include +#include namespace at { namespace native { namespace vulkan { +ExecuteNode::ExecuteNode( + ComputeGraph& graph, + const api::ShaderInfo& shader, + const api::utils::uvec3& global_workgroup_size, + const api::utils::uvec3& local_workgroup_size, + const std::vector& args, + const std::vector>& params, + const ResizeFunction& resize_fn, + const std::vector& resize_args) + : shader_(shader), + global_workgroup_size_(global_workgroup_size), + local_workgroup_size_(local_workgroup_size), + args_(args), + params_(params), + resize_fn_(resize_fn), + resize_args_(resize_args) { + graph.update_descriptor_counts(shader, /*execute = */ true); +} + void ExecuteNode::encode(ComputeGraph* graph) { api::Context* const context = graph->context(); api::PipelineBarrier pipeline_barrier{}; @@ -27,20 +46,8 @@ void ExecuteNode::encode(ComputeGraph* graph) { uint32_t idx = 0; idx = bind_values_to_descriptor_set( - graph, - outputs_, - pipeline_barrier, - api::MemoryAccessType::WRITE, - descriptor_set, - idx); - idx = bind_values_to_descriptor_set( - graph, - inputs_, - pipeline_barrier, - api::MemoryAccessType::READ, - descriptor_set, - idx); - descriptor_set.bind(idx, params_.buffer()); + graph, args_, pipeline_barrier, descriptor_set, idx); + bind_params_to_descriptor_set(params_, descriptor_set, idx); context->register_shader_dispatch( descriptor_set, pipeline_barrier, shader_, global_workgroup_size_); diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h index 1b726e73d4..9d9beeab65 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h @@ -10,9 +10,7 @@ #ifdef USE_VULKAN_API -#include -#include -#include +#include #include @@ -22,48 +20,67 @@ namespace vulkan { class ComputeGraph; +/* + * Represents a group of shader arguments (images and/or buffers), with a common + * access permission. + */ +struct ArgGroup { + ArgGroup(const ValueRef ref, const api::MemoryAccessType access) + : refs{ref}, access(access) {} + + ArgGroup( + const std::vector& refs, + const api::MemoryAccessType access) + : refs(refs), access(access) {} + + const std::vector refs; + const api::MemoryAccessType access; +}; + /* * Represents a single execution op in a ML model. In graph mode, ops will be * implemented in a derived class that implements encode, which will implement * encoding of the shader corresponding to the op into the command buffer of a * ComputeGraph. */ -class ExecuteNode { +class ExecuteNode final { friend class ComputeGraph; public: - ExecuteNode(ValueRef input, ValueRef output) - : outputs_{output}, inputs_{input} {} + using ResizeFunction = const std::function&, + const std::vector&)>; ExecuteNode( + ComputeGraph& graph, const api::ShaderInfo& shader, const api::utils::uvec3& global_workgroup_size, const api::utils::uvec3& local_workgroup_size, - const std::vector& outputs, - const std::vector& inputs, - api::UniformParamsBuffer&& params) - : shader_(shader), - global_workgroup_size_(global_workgroup_size), - local_workgroup_size_(local_workgroup_size), - outputs_(outputs), - inputs_(inputs), - params_(std::move(params)) {} - - virtual ~ExecuteNode() = default; + const std::vector& args, + const std::vector>& params, + const ResizeFunction& resize_fn = nullptr, + const std::vector& resize_args = {}); + + ~ExecuteNode() = default; + + void encode(ComputeGraph* graph); + + inline void trigger_resize(ComputeGraph* graph) { + if (resize_fn_ != nullptr) { + resize_fn_(graph, args_, resize_args_); + } + } protected: - // TODO: Consider making members const after we remove StagingNode. - api::ShaderInfo shader_; - api::utils::uvec3 global_workgroup_size_; - api::utils::uvec3 local_workgroup_size_; - std::vector outputs_; - std::vector inputs_; - // TODO(T180906086): pass multiple buffers and index with ValueRef. + const api::ShaderInfo shader_; + const api::utils::uvec3 global_workgroup_size_; + const api::utils::uvec3 local_workgroup_size_; + const std::vector args_; // TODO(T180906457): allow re-computing param buffers. - api::UniformParamsBuffer params_; - - public: - virtual void encode(ComputeGraph* graph); + std::vector> params_; + const ResizeFunction resize_fn_; + const std::vector resize_args_; }; } // namespace vulkan diff --git a/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp b/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp index 0d46e5b351..9f489e1c3f 100644 --- a/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp +++ b/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp @@ -8,48 +8,28 @@ #include -#include - namespace at { namespace native { namespace vulkan { -bool hasOpsFn(const std::string& name) { - return OperatorRegistry::getInstance().hasOpsFn(name); -} - -OpFunction& getOpsFn(const std::string& name) { - return OperatorRegistry::getInstance().getOpsFn(name); +bool OperatorRegistry::has_op(const std::string& name) { + return table_.count(name) > 0; } -OperatorRegistry& OperatorRegistry::getInstance() { - static OperatorRegistry instance; - return instance; +OperatorRegistry::OpFunction& OperatorRegistry::get_op_fn( + const std::string& name) { + return table_.find(name)->second; } -bool OperatorRegistry::hasOpsFn(const std::string& name) { - return OperatorRegistry::kTable.count(name) > 0; +void OperatorRegistry::register_op(const std::string& name, OpFunction& fn) { + table_.insert(std::make_pair(name, fn)); } -OpFunction& OperatorRegistry::getOpsFn(const std::string& name) { - return OperatorRegistry::kTable.find(name)->second; +OperatorRegistry& operator_registry() { + static OperatorRegistry registry; + return registry; } -// @lint-ignore-every CLANGTIDY modernize-avoid-bind -// clang-format off -#define OPERATOR_ENTRY(name, function) \ - { #name, std::bind(&at::native::vulkan::function, std::placeholders::_1, std::placeholders::_2) } -// clang-format on - -const OperatorRegistry::OpTable OperatorRegistry::kTable = { - OPERATOR_ENTRY(aten.add.Tensor, add), - OPERATOR_ENTRY(aten.sub.Tensor, sub), - OPERATOR_ENTRY(aten.mul.Tensor, mul), - OPERATOR_ENTRY(aten.div.Tensor, div), - OPERATOR_ENTRY(aten.div.Tensor_mode, floor_div), - OPERATOR_ENTRY(aten.pow.Tensor_Tensor, pow), -}; - } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ops/OperatorRegistry.h b/backends/vulkan/runtime/graph/ops/OperatorRegistry.h index c11aa0168e..1088ab2e44 100644 --- a/backends/vulkan/runtime/graph/ops/OperatorRegistry.h +++ b/backends/vulkan/runtime/graph/ops/OperatorRegistry.h @@ -15,45 +15,68 @@ #include #include +#define VK_HAS_OP(name) ::at::native::vulkan::operator_registry().has_op(name) + +#define VK_GET_OP_FN(name) \ + ::at::native::vulkan::operator_registry().get_op_fn(name) + +#define VK_REGISTER_OP(name, function) \ + ::at::native::vulkan::operator_registry().register_op( \ + #name, \ + std::bind(&function, std::placeholders::_1, std::placeholders::_2)) + +#define REGISTER_OPERATORS \ + static void register_ops(); \ + static const OperatorRegisterInit reg(®ister_ops); \ + static void register_ops() + namespace at { namespace native { namespace vulkan { -using OpFunction = const std::function&)>; // TODO: Generalize to - // support float, - // int64_t. - -bool hasOpsFn(const std::string& name); +/* + * The Vulkan operator registry maps ATen operator names to their Vulkan + * delegate function implementation. It is a simplified version of + * executorch/runtime/kernel/operator_registry.h that uses the C++ Standard + * Library. + */ +class OperatorRegistry final { + using OpFunction = + const std::function&)>; + using OpTable = std::unordered_map; -OpFunction& getOpsFn(const std::string& name); + OpTable table_; -// The Vulkan operator registry is a simplified version of -// fbcode/executorch/runtime/kernel/operator_registry.h -// that uses the C++ Standard Library. -class OperatorRegistry { public: - static OperatorRegistry& getInstance(); + /* + * Check if the registry has an operator registered under the given name + */ + bool has_op(const std::string& name); - bool hasOpsFn(const std::string& name); - OpFunction& getOpsFn(const std::string& name); + /* + * Given an operator name, return the Vulkan delegate function + */ + OpFunction& get_op_fn(const std::string& name); - OperatorRegistry(const OperatorRegistry&) = delete; - OperatorRegistry(OperatorRegistry&&) = delete; - OperatorRegistry& operator=(const OperatorRegistry&) = delete; - OperatorRegistry& operator=(OperatorRegistry&&) = delete; + /* + * Register a function to a given operator name + */ + void register_op(const std::string& name, OpFunction& fn); +}; - private: - // TODO: Input string corresponds to target_name. We may need to pass kwargs. - using OpTable = std::unordered_map; - // @lint-ignore CLANGTIDY facebook-hte-NonPodStaticDeclaration - static const OpTable kTable; +class OperatorRegisterInit final { + using InitFn = void(); - OperatorRegistry() = default; - ~OperatorRegistry() = default; + public: + explicit OperatorRegisterInit(InitFn* init_fn) { + init_fn(); + } }; +// The Vulkan operator registry is global. It is retrieved using this function, +// where it is declared as a static local variable. +OperatorRegistry& operator_registry(); + } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp new file mode 100644 index 0000000000..c21c1447d9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include + +namespace at { +namespace native { +namespace vulkan { + +PrepackNode::PrepackNode( + ComputeGraph& graph, + const api::ShaderInfo& shader, + const api::utils::uvec3& global_workgroup_size, + const api::utils::uvec3& local_workgroup_size, + const ValueRef tref, + const ValueRef packed, + const std::vector>& params) + : shader_(shader), + global_workgroup_size_(global_workgroup_size), + local_workgroup_size_(local_workgroup_size), + tref_(tref), + packed_(packed), + params_(params) { + graph.update_descriptor_counts(shader, /*execute = */ false); +} + +void PrepackNode::encode(ComputeGraph* graph) { + api::Context* const context = graph->context(); + api::PipelineBarrier pipeline_barrier{}; + + TensorRef tref = graph->get_val(tref_).toTensorRef(); + vTensor packed = graph->get_val(packed_).toTensor(); + + // TODO: Extract to standalone function, to support other types of prepacking. + api::StorageBuffer staging( + graph->context(), packed.dtype(), packed.gpu_nbytes()); + size_t numel = api::utils::multiply_integers(tref.sizes); + size_t nbytes = numel * api::element_size(tref.dtype); + copy_ptr_to_staging(tref.data, staging, nbytes); + + std::unique_lock cmd_lock = context->dispatch_lock(); + + api::DescriptorSet descriptor_set = + context->get_descriptor_set(shader_, local_workgroup_size_); + + uint32_t idx = 0; + bind_tensor_to_descriptor_set( + packed, + pipeline_barrier, + api::MemoryAccessType::WRITE, + descriptor_set, + idx++); + bind_staging_to_descriptor_set(staging, descriptor_set, idx++); + bind_params_to_descriptor_set(params_, descriptor_set, idx); + + context->register_shader_dispatch( + descriptor_set, pipeline_barrier, shader_, global_workgroup_size_); +} + +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.h b/backends/vulkan/runtime/graph/ops/PrepackNode.h index 6f581eb931..7d8a8b4ce3 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.h +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.h @@ -10,9 +10,7 @@ #ifdef USE_VULKAN_API -#include -#include -#include +#include #include @@ -28,20 +26,31 @@ class ComputeGraph; * encoding of shaders transferring necessary data (such as weights and biases) * to the GPU. */ -class PrepackNode { +class PrepackNode final { friend class ComputeGraph; public: - PrepackNode(ValueRef tref, ValueRef packed) : tref_{tref}, packed_{packed} {} + PrepackNode( + ComputeGraph& graph, + const api::ShaderInfo& shader, + const api::utils::uvec3& global_workgroup_size, + const api::utils::uvec3& local_workgroup_size, + const ValueRef tref, + const ValueRef packed, + const std::vector>& params); - virtual ~PrepackNode() = default; + ~PrepackNode() = default; - protected: - ValueRef tref_; - ValueRef packed_; + void encode(ComputeGraph* graph); - public: - virtual void encode(ComputeGraph* graph) const = 0; + protected: + const api::ShaderInfo shader_; + const api::utils::uvec3 global_workgroup_size_; + const api::utils::uvec3 local_workgroup_size_; + const ValueRef tref_; + const ValueRef packed_; + // TODO(T180906457): allow re-computing param buffers. + std::vector> params_; }; } // namespace vulkan diff --git a/backends/vulkan/runtime/graph/ops/glsl/all_shaders.yaml b/backends/vulkan/runtime/graph/ops/glsl/all_shaders.yaml new file mode 100644 index 0000000000..a1abc6a745 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/all_shaders.yaml @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +binary_op: + parameter_names_with_default_values: + OPERATOR: X + A * Y + NDIM: 3 + DTYPE: float + PACKING: CHANNELS_PACKED + generate_variant_forall: + DTYPE: + - VALUE: "half" + SUFFIX: "half" + - VALUE: "float" + SUFFIX: "float" + shader_variants: + - NAME: binary_add + - NAME: binary_sub + OPERATOR: X - A * Y + - NAME: binary_mul + OPERATOR: X * Y + - NAME: binary_div + OPERATOR: X / Y + - NAME: binary_pow + OPERATOR: pow(X, Y) + - NAME: binary_floor_divide + OPERATOR: floor(X / Y) + +image_to_nchw: + parameter_names_with_default_values: + NDIM: 3 + DTYPE: float + PACKING: CHANNELS_PACKED + generate_variant_forall: + DTYPE: + - VALUE: "half" + SUFFIX: "half" + - VALUE: "float" + SUFFIX: "float" + shader_variants: + - NAME: image3d_to_nchw_C_packed + - NAME: image2d_to_nchw_C_packed + NDIM: 2 + +nchw_to_image: + parameter_names_with_default_values: + NDIM: 3 + DTYPE: float + PACKING: CHANNELS_PACKED + generate_variant_forall: + DTYPE: + - VALUE: "half" + SUFFIX: "half" + - VALUE: "float" + SUFFIX: "float" + shader_variants: + - NAME: nchw_to_image3d_C_packed + - NAME: nchw_to_image2d_C_packed + NDIM: 2 diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl new file mode 100644 index 0000000000..f7bcdaa232 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl @@ -0,0 +1,67 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#include "broadcasting_utils.h" +#include "indexing_utils.h" + +#define PRECISION ${PRECISION} + +#define OP(X, Y, A) ${OPERATOR} + +layout(std430) buffer; + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; +layout(set = 0, binding = 2) uniform PRECISION sampler3D image_other; + +layout(set = 0, binding = 3) uniform PRECISION restrict OutSizes { + ivec4 data; +} +out_sizes; + +layout(set = 0, binding = 4) uniform PRECISION restrict InSizes { + ivec4 data; +} +in_sizes; + +layout(set = 0, binding = 5) uniform PRECISION restrict OtherSizes { + ivec4 data; +} +other_sizes; + +layout(set = 0, binding = 6) uniform PRECISION restrict Alpha { + float data; +} +alpha; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 coord = POS_TO_COORD_${PACKING}(pos, out_sizes.data); + + if (any(greaterThanEqual(coord, out_sizes.data))) { + return; + } + + ivec4 in_coord = out_coord_to_in_coord(coord, in_sizes.data); + vec4 in_texel = texelFetch( + image_in, + COORD_TO_POS_${PACKING}(in_coord, in_sizes.data), + 0); + + ivec4 other_coord = out_coord_to_in_coord(coord, other_sizes.data); + vec4 other_texel = texelFetch( + image_other, + COORD_TO_POS_${PACKING}(other_coord, other_sizes.data), + 0); + + imageStore(image_out, pos, OP(in_texel, other_texel, alpha.data)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/broadcasting_utils.h b/backends/vulkan/runtime/graph/ops/glsl/broadcasting_utils.h new file mode 100644 index 0000000000..dc8635b881 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/broadcasting_utils.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +ivec4 out_coord_to_in_coord(const ivec4 out_coord, const ivec4 in_sizes) { + ivec4 in_coord = out_coord; + for (int i = 0; i < 4; ++i) { + if (in_sizes[i] == 1) { + in_coord[i] = 0; + } + } + return in_coord; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl new file mode 100644 index 0000000000..f966f7584b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in; +layout(set = 0, binding = 1) buffer PRECISION restrict writeonly Buffer { + ${T[DTYPE]} data[]; +} +buffer_out; + +layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { + ivec4 data; +} +gpu_sizes; + +layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes { + ivec4 data; +} +cpu_sizes; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 coord = POS_TO_COORD_${PACKING}(pos, gpu_sizes.data); + + if (any(greaterThanEqual(coord, gpu_sizes.data))) { + return; + } + + const ${VEC4_T[DTYPE]} intex = texelFetch(image_in, ${GET_POS[NDIM]("pos")}, 0); + + const int base_index = COORD_TO_BUFFER_IDX(coord, cpu_sizes.data); + const ivec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * (gpu_sizes.data.x * gpu_sizes.data.y); + + if (coord.z < cpu_sizes.data.z) { + buffer_out.data[buf_indices.x] = intex.x; + } + if (coord.z + 1 < cpu_sizes.data.z) { + buffer_out.data[buf_indices.y] = intex.y; + } + if (coord.z + 2 < cpu_sizes.data.z) { + buffer_out.data[buf_indices.z] = intex.z; + } + if (coord.z + 3 < cpu_sizes.data.z) { + buffer_out.data[buf_indices.w] = intex.w; + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h new file mode 100644 index 0000000000..7bac6b5116 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#define POS_TO_COORD_CHANNELS_PACKED(pos, sizes) \ + ivec4(pos.x, pos.y, (pos.z * 4) % sizes.z, (pos.z * 4) / sizes.z) + +#define COORD_TO_POS_CHANNELS_PACKED(coord, sizes) \ + ivec3(coord.x, coord.y, (coord.z + coord.w * sizes.z) / 4) + +#define COORD_TO_BUFFER_IDX(coord, sizes) \ + coord.x + coord.y* sizes.x + coord.z* sizes.y* sizes.x + \ + coord.w* sizes.z* sizes.y* sizes.x; diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl new file mode 100644 index 0000000000..00ed3fe5e4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { + ${T[DTYPE]} data[]; +} +buffer_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { + ivec4 data; +} +gpu_sizes; + +layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes { + ivec4 data; +} +cpu_sizes; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 coord = POS_TO_COORD_${PACKING}(pos, gpu_sizes.data); + + if (any(greaterThanEqual(coord, gpu_sizes.data))) { + return; + } + + const int base_index = COORD_TO_BUFFER_IDX(coord, cpu_sizes.data); + const ivec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * (gpu_sizes.data.x * gpu_sizes.data.y); + + ${T[DTYPE]} val_x = buffer_in.data[buf_indices.x]; + ${T[DTYPE]} val_y = buffer_in.data[buf_indices.y]; + ${T[DTYPE]} val_z = buffer_in.data[buf_indices.z]; + ${T[DTYPE]} val_w = buffer_in.data[buf_indices.w]; + + ${VEC4_T[DTYPE]} texel = ${VEC4_T[DTYPE]}(val_x, val_y, val_z, val_w); + + if (coord.z + 3 >= cpu_sizes.data.z) { + ivec4 c_ind = ivec4(coord.z) + ivec4(0, 1, 2, 3); + vec4 valid_c = vec4(lessThan(c_ind, ivec4(cpu_sizes.data.z))); + texel = texel * valid_c; + } + + imageStore(image_out, ${GET_POS[NDIM]("pos")}, texel); +} diff --git a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp b/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp deleted file mode 100644 index ce43005384..0000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include - -namespace at { -namespace native { -namespace vulkan { - -#define DEFINE_ARITHMETIC_FN(function, shader) \ - ValueRef function(ComputeGraph& graph, const std::vector& args) { \ - return add_arithmetic_node( \ - graph, args[0], args[1], args[2], VK_KERNEL(shader), args[3]); \ - } - -DEFINE_ARITHMETIC_FN(add, add); -DEFINE_ARITHMETIC_FN(sub, sub); -DEFINE_ARITHMETIC_FN(mul, mul); -DEFINE_ARITHMETIC_FN(div, div); -DEFINE_ARITHMETIC_FN(floor_div, floor_divide); -DEFINE_ARITHMETIC_FN(pow, pow); - -// TODO(T180908843): Bypass this entrypoint function by creating `ValueRef out` -// ahead of time. -ValueRef add_arithmetic_node( - ComputeGraph& graph, - const ValueRef in1, - const ValueRef in2, - const float alpha, - const api::ShaderInfo& shader, - const int64_t shared_object_idx) { - std::vector in1_sizes = graph.get_val_sizes(in1); - api::ScalarType in1_dtype = graph.get_val_dtype(in1); - - ValueRef out = graph.add_tensor(in1_sizes, in1_dtype, shared_object_idx); - add_arithmetic_node(graph, in1, in2, out, alpha, shader); - return out; -} - -// TODO(T181006464): Move to Utils when we remove ArithmeticPrepack. -ValueRef prepack_if_tensor_ref(ComputeGraph& graph, const ValueRef v) { - if (graph.get_val(v).isTensor()) { - return v; - } else { - TensorRef& tRef = graph.get_val(v).toTensorRef(); - ValueRef vTen = graph.add_tensor(tRef.sizes, tRef.dtype); - graph.prepack_nodes().emplace_back(new ArithmeticPrepack(v, vTen)); - return vTen; - } -} - -void add_arithmetic_node( - ComputeGraph& graph, - const ValueRef in1, - const ValueRef in2, - const ValueRef out, - const float alpha, - const api::ShaderInfo& shader) { - ValueRef arg1 = prepack_if_tensor_ref(graph, in1); - ValueRef arg2 = prepack_if_tensor_ref(graph, in2); - - vTensor& t_in1 = graph.get_val(arg1).toTensor(); - vTensor& t_in2 = graph.get_val(arg2).toTensor(); - vTensor& t_out = graph.get_val(out).toTensor(); - - api::utils::uvec3 global_size = t_out.extents(); - api::utils::uvec3 local_size = adaptive_work_group_size(global_size); - - ArithmeticParams block{ - get_size_as_ivec4(t_out), - get_size_as_ivec4(t_in1), - get_size_as_ivec4(t_in2), - 1.0, - }; - api::UniformParamsBuffer params(graph.context(), block); - - graph.execute_nodes().emplace_back(new ExecuteNode( - shader, global_size, local_size, {out}, {arg1, arg2}, std::move(params))); -} - -ArithmeticPrepack::ArithmeticPrepack(const ValueRef tref, const ValueRef packed) - : PrepackNode(tref, packed) {} - -void ArithmeticPrepack::encode(ComputeGraph* graph) const { - TensorRef tref = graph->get_val(tref_).toTensorRef(); - vTensor packed = graph->get_val(packed_).toTensor(); - - api::StorageBuffer staging( - graph->context(), packed.dtype(), packed.gpu_nbytes()); - - size_t numel = api::utils::multiply_integers(tref.sizes); - size_t nbytes = numel * api::element_size(tref.dtype); - copy_ptr_to_staging(tref.data, staging, nbytes); - - encode_copy_to_vtensor(graph->context(), staging, packed); -} - -} // namespace vulkan -} // namespace native -} // namespace at diff --git a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h b/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h deleted file mode 100644 index 82e2aa2cdf..0000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#ifdef USE_VULKAN_API - -#include - -#include - -#include - -namespace at { -namespace native { -namespace vulkan { - -DECLARE_OP_FN(add); -DECLARE_OP_FN(sub); -DECLARE_OP_FN(mul); -DECLARE_OP_FN(div); -DECLARE_OP_FN(floor_div); -DECLARE_OP_FN(pow); - -ValueRef add_arithmetic_node( - ComputeGraph& graph, - const ValueRef in1, - const ValueRef in2, - const float alpha, - const api::ShaderInfo& shader, - const int64_t shared_object_idx = -1); - -void add_arithmetic_node( - ComputeGraph& graph, - const ValueRef in1, - const ValueRef in2, - const ValueRef out, - const float alpha, - const api::ShaderInfo& shader); - -struct ArithmeticParams final { - api::utils::ivec4 outputSizes; - api::utils::ivec4 input1Sizes; - api::utils::ivec4 input2Sizes; - float alpha; -}; - -class ArithmeticPrepack : public virtual PrepackNode { - public: - explicit ArithmeticPrepack(const ValueRef tref, const ValueRef packed); - - void encode(ComputeGraph* graph) const override; -}; - -} // namespace vulkan -} // namespace native -} // namespace at - -#endif /* USE_VULKAN_API */ diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp new file mode 100644 index 0000000000..1d637ecb34 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include + +#include + +namespace at { +namespace native { +namespace vulkan { + +void resize_binary_op_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + vTensor& out = graph->get_val(args[0].refs[0]).toTensor(); + vTensor& self = graph->get_val(args[1].refs[0]).toTensor(); + vTensor& other = graph->get_val(args[1].refs[1]).toTensor(); + + std::vector new_out_sizes( + std::max(self.sizes().size(), other.sizes().size())); + + // Match the sizes in reverse because sizes are in NCHW order + for (int i = -1; i >= -new_out_sizes.size(); --i) { + new_out_sizes.at(new_out_sizes.size() + i) = std::max( + api::utils::val_at(i, self.sizes()), + api::utils::val_at(i, other.sizes())); + } + + out.virtual_resize(new_out_sizes); +} + +void add_binary_op_node( + ComputeGraph& graph, + const ValueRef in1, + const ValueRef in2, + const ValueRef alpha, + const ValueRef out, + const std::string& op_name) { + ValueRef arg1 = prepack_if_tensor_ref(graph, in1); + ValueRef arg2 = prepack_if_tensor_ref(graph, in2); + + vTensor& t_in1 = graph.get_val(arg1).toTensor(); + vTensor& t_in2 = graph.get_val(arg2).toTensor(); + vTensor& t_out = graph.get_val(out).toTensor(); + + api::utils::uvec3 global_size = t_out.virtual_extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + float alpha_val = 1.0f; + // String is checked since floor_div passes in an unused string argument in + // place of alpha + if (is_valid(alpha) && !graph.get_val(alpha).isString()) { + alpha_val = extract_scalar(graph.get_val(alpha)); + } + + std::stringstream kernel_name; + kernel_name << "binary_" << op_name; + apply_dtype_suffix(kernel_name, t_out); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name.str()), + global_size, + local_size, + // Inputs and Outputs + {{out, api::MemoryAccessType::WRITE}, + {{arg1, arg2}, api::MemoryAccessType::READ}}, + // Shader params buffers + {t_out.gpu_sizes_ubo(), + t_in1.gpu_sizes_ubo(), + t_in2.gpu_sizes_ubo(), + graph.create_params_buffer(alpha_val)}, + // Resizing + resize_binary_op_node)); +} + +#define DEFINE_BINARY_OP_WITH_ALPHA_FN(op_name) \ + void op_name(ComputeGraph& graph, const std::vector& args) { \ + return add_binary_op_node( \ + graph, args[0], args[1], args[2], args[3], #op_name); \ + } + +#define DEFINE_BINARY_OP_FN(op_name) \ + void op_name(ComputeGraph& graph, const std::vector& args) { \ + return add_binary_op_node( \ + graph, args[0], args[1], kDummyValueRef, args[2], #op_name); \ + } + +DEFINE_BINARY_OP_WITH_ALPHA_FN(add); +DEFINE_BINARY_OP_WITH_ALPHA_FN(sub); + +// Floor div does not have an alpha, but a string argument (which is unused) is +// passed in at the same location as the alpha argument in other op. +DEFINE_BINARY_OP_WITH_ALPHA_FN(floor_divide); + +DEFINE_BINARY_OP_FN(mul); +DEFINE_BINARY_OP_FN(div); +DEFINE_BINARY_OP_FN(pow); + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.add.Tensor, add); + VK_REGISTER_OP(aten.sub.Tensor, sub); + VK_REGISTER_OP(aten.mul.Tensor, mul); + VK_REGISTER_OP(aten.div.Tensor, div); + VK_REGISTER_OP(aten.div.Tensor_mode, floor_divide); + VK_REGISTER_OP(aten.pow.Tensor_Tensor, pow); +} + +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 5b16780777..b3319e6dac 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -8,116 +8,86 @@ #include -#include +#include + +#include +#include namespace at { namespace native { namespace vulkan { -void memcpy_to_mapping( - const void* src, - api::MemoryMap& dst_mapping, - const size_t nbytes, - const api::ScalarType dtype) { -#define DTYPE_CASE(ctype, vkformat, name) \ - case api::ScalarType::name: \ - memcpy_to_mapping_impl(src, dst_mapping, nbytes); \ - break; - - switch (dtype) { - VK_FORALL_SCALAR_TYPES(DTYPE_CASE) - default: - VK_THROW("Unrecognized dtype!"); - } -#undef DTYPE_CASE -} - -void memcpy_from_mapping( - api::MemoryMap& src_mapping, - void* dst, - const size_t nbytes, - const api::ScalarType dtype) { -#define DTYPE_CASE(ctype, vkformat, name) \ - case api::ScalarType::name: \ - memcpy_from_mapping_impl(src_mapping, dst, nbytes); \ - break; - - switch (dtype) { - VK_FORALL_SCALAR_TYPES(DTYPE_CASE) - default: - VK_THROW("Unrecognized dtype!"); - } -#undef DTYPE_CASE -} +void add_staging_to_tensor_node( + ComputeGraph& graph, + const ValueRef in_staging, + const ValueRef out_tensor) { + vTensor& t_out = graph.get_val(out_tensor).toTensor(); + VK_CHECK_COND(graph.get_val(in_staging).isStaging()); -void copy_ptr_to_staging( - const void* src, - api::StorageBuffer& staging, - const size_t nbytes) { - api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); - mapping.invalidate(); - memcpy_to_mapping(src, mapping, nbytes, staging.dtype()); -} + api::ShaderInfo shader = get_nchw_to_image_shader(t_out); -void copy_staging_to_ptr( - api::StorageBuffer& staging, - void* dst, - const size_t nbytes) { - api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::READ); - mapping.invalidate(); - memcpy_from_mapping(mapping, dst, nbytes, staging.dtype()); -} + api::utils::uvec3 global_size = t_out.extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); -void encode_copy_to_vtensor( - api::Context* context, - api::StorageBuffer& staging, - vTensor& tensor) { - api::ShaderInfo shader = packing::get_nchw_to_image_shader(tensor); - api::PipelineBarrier pipeline_barrier{}; - packing::record_nchw_to_image_op( - context, + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, shader, - staging.buffer(), - tensor, - pipeline_barrier, - VK_NULL_HANDLE); + global_size, + local_size, + {{out_tensor, api::MemoryAccessType::WRITE}, + {in_staging, api::MemoryAccessType::READ}}, + {t_out.gpu_sizes_ubo(), t_out.cpu_sizes_ubo()})); } -void encode_copy_from_vtensor( - api::Context* context, - vTensor& tensor, - api::StorageBuffer& staging) { - api::ShaderInfo shader = packing::get_image_to_nchw_shader(tensor); - api::PipelineBarrier pipeline_barrier{}; - packing::record_image_to_nchw_op( - context, +void add_tensor_to_staging_node( + ComputeGraph& graph, + const ValueRef in_tensor, + const ValueRef out_staging) { + vTensor& t_in = graph.get_val(in_tensor).toTensor(); + VK_CHECK_COND(graph.get_val(out_staging).isStaging()); + + api::ShaderInfo shader = get_image_to_nchw_shader(t_in); + + api::utils::uvec3 global_size = t_in.extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, shader, - tensor, - staging.buffer(), - pipeline_barrier, - VK_NULL_HANDLE); + global_size, + local_size, + {{in_tensor, api::MemoryAccessType::READ}, + {out_staging, api::MemoryAccessType::WRITE}}, + {t_in.gpu_sizes_ubo(), t_in.cpu_sizes_ubo()})); } -StagingNode::StagingNode(ValueRef from, ValueRef to) : ExecuteNode(from, to) {} +ValueRef prepack(ComputeGraph& graph, const ValueRef vref) { + TensorRef& tref = graph.get_val(vref).toTensorRef(); + ValueRef v = graph.add_tensor(tref.sizes, tref.dtype); + vTensor t = graph.get_val(v).toTensor(); + + api::ShaderInfo shader = get_nchw_to_image_shader(t); -void StagingNode::encode(ComputeGraph* graph) { - Value& in_val = graph->get_val(inputs_[0]); - Value& out_val = graph->get_val(outputs_[0]); + api::utils::uvec3 global_size = t.extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + shader, + global_size, + local_size, + vref, + v, + {t.gpu_sizes_ubo(), t.cpu_sizes_ubo()})); + + return v; +} - if (in_val.isStaging() && out_val.isTensor()) { - api::StorageBuffer& from_staging = graph->get_val(inputs_[0]).toStaging(); - vTensor& to_tensor = graph->get_val(outputs_[0]).toTensor(); - encode_copy_to_vtensor(graph->context(), from_staging, to_tensor); - } else if (in_val.isTensor() && out_val.isStaging()) { - vTensor& from_tensor = graph->get_val(inputs_[0]).toTensor(); - api::StorageBuffer& to_staging = graph->get_val(outputs_[0]).toStaging(); - encode_copy_from_vtensor(graph->context(), from_tensor, to_staging); +ValueRef prepack_if_tensor_ref(ComputeGraph& graph, const ValueRef v) { + if (graph.get_val(v).isTensorRef()) { + return prepack(graph, v); } else { - VK_THROW( - "Unexpected input value type ", - in_val.type(), - " and output value type ", - out_val.type()); + return v; } } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.h b/backends/vulkan/runtime/graph/ops/impl/Staging.h index be57a9817f..425d77489f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.h +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.h @@ -12,84 +12,23 @@ #include -#include +#include namespace at { namespace native { namespace vulkan { -// -// Functions to memcpy data into staging buffer -// +void add_staging_to_tensor_node( + ComputeGraph& graph, + const ValueRef in_staging, + const ValueRef out_tensor); -void memcpy_to_mapping( - const void* src, - api::MemoryMap& dst_mapping, - const size_t nbytes, - const api::ScalarType dtype); -void memcpy_from_mapping( - const api::MemoryMap& src_mapping, - void* dst, - const size_t nbytes, - const api::ScalarType dtype); +void add_tensor_to_staging_node( + ComputeGraph& graph, + const ValueRef in_tensor, + const ValueRef out_staging); -// -// Utility functions for memcpy -// - -template -void memcpy_to_mapping_impl( - const void* src, - api::MemoryMap& dst_mapping, - const size_t nbytes) { - T* data_ptr = dst_mapping.template data(); - memcpy(data_ptr, reinterpret_cast(src), nbytes); -} - -template -void memcpy_from_mapping_impl( - api::MemoryMap& src_mapping, - void* dst, - const size_t nbytes) { - T* data_ptr = src_mapping.template data(); - memcpy(reinterpret_cast(dst), data_ptr, nbytes); -} - -// -// Functions to copy data into and out of a staging buffer -// - -void copy_ptr_to_staging( - const void* src, - api::StorageBuffer& staging, - const size_t nbytes); -void copy_staging_to_ptr( - api::StorageBuffer& staging, - void* dst, - const size_t nbytes); - -// -// Functions to record copying data between a staging buffer and a vTensor -// - -void encode_copy_to_vtensor( - api::Context* context, - api::StorageBuffer& staging, - vTensor& tensor); -void encode_copy_from_vtensor( - api::Context* context, - vTensor& tensor, - api::StorageBuffer& staging); - -/* - * OpNode that allows copying data into and out of a staging buffer. - */ -class StagingNode : public virtual ExecuteNode { - public: - explicit StagingNode(ValueRef from, ValueRef to); - - void encode(ComputeGraph* graph) override; -}; +ValueRef prepack_if_tensor_ref(ComputeGraph& graph, const ValueRef v); } // namespace vulkan } // namespace native diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h new file mode 100644 index 0000000000..299c3bb99f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#ifdef USE_VULKAN_API + +#include + +namespace at { +namespace native { +namespace vulkan { + +/* + * Maps a semantic dimension name to an integer that corresponds to its + * innermost ordering in a 4D tensor in NCHW format. Width is the innermost + * dimension, so it corresponds to 1, height is the next innermost, so it + * corresponds to 2, and so on. + */ +struct Dim4D { + static constexpr uint32_t Width = 1u; + static constexpr uint32_t Height = 2u; + static constexpr uint32_t Channel = 3u; + static constexpr uint32_t Batch = 4u; +}; + +/* + * Semantic dimension names for a 1D tensor + */ +struct Dim1D { + static constexpr uint32_t Length = 1u; +}; + +/* + * Semantic dimension names for a 2D Convolution kernel. + */ +struct DimConv2DKernel { + static constexpr uint32_t Width = 1u; + static constexpr uint32_t Height = 2u; + static constexpr uint32_t InChannels = 3u; + static constexpr uint32_t OutChannels = 4u; +}; + +/* + * The same as the above, except for a 2D Transposed Convolution kernel. + */ +struct DimTConv2DKernel { + static constexpr uint32_t Width = 1u; + static constexpr uint32_t Height = 2u; + static constexpr uint32_t OutChannels = 3u; + static constexpr uint32_t InChannels = 4u; +}; + +/* + * The functions below safely return the size of the dimension at the N-th + * innermost index. If the dimensionality of the size array is not sufficient + * then 1 will be returned. The structs above are intended to be used with + * these functions. + */ +template +uint32_t dim_at(const std::vector& sizes) { + const uint32_t dims = sizes.size(); + return dims < N ? 1 : api::utils::safe_downcast(sizes[dims - N]); +} + +template +uint32_t dim_at(const vTensor& v_in) { + return dim_at(v_in.sizes()); +} + +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h new file mode 100644 index 0000000000..38cb8eed3b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#ifdef USE_VULKAN_API + +#include + +#include + +namespace at { +namespace native { +namespace vulkan { + +template +T extract_scalar(const Value& value) { + if (value.isInt()) { + return static_cast(value.toInt()); + } + if (value.isDouble()) { + return static_cast(value.toDouble()); + } + if (value.isBool()) { + return static_cast(value.toBool()); + } + VK_THROW("Cannot extract scalar from Value with type ", value.type()); +} + +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp new file mode 100644 index 0000000000..72e1bc5a0d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace at { +namespace native { +namespace vulkan { + +api::utils::uvec3 adaptive_work_group_size( + const api::utils::uvec3& global_work_group) { + api::utils::uvec3 local_group_size = {4, 4, 4}; + if (global_work_group.data[2u] == 1) { + if (global_work_group.data[1u] < 8) { + local_group_size.data[0u] = 16; + local_group_size.data[1u] = 4; + local_group_size.data[2u] = 1; + } else { + local_group_size.data[0u] = 8; + local_group_size.data[1u] = 8; + local_group_size.data[2u] = 1; + } + } + return local_group_size; +} + +api::utils::ivec4 get_size_as_ivec4(const vTensor& t) { + return api::utils::make_ivec4( + {dim_at(t), + dim_at(t), + dim_at(t), + dim_at(t)}); +} + +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h new file mode 100644 index 0000000000..a01e8c0a4d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#ifdef USE_VULKAN_API + +#include + +namespace at { +namespace native { +namespace vulkan { + +api::utils::uvec3 adaptive_work_group_size( + const api::utils::uvec3& global_work_group); + +api::utils::ivec4 get_size_as_ivec4(const vTensor& t); + +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/backends/vulkan/runtime/graph/ops/Utils.cpp b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp similarity index 52% rename from backends/vulkan/runtime/graph/ops/Utils.cpp rename to backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp index 579eac54e3..6e1d9b3013 100644 --- a/backends/vulkan/runtime/graph/ops/Utils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp @@ -6,20 +6,12 @@ * LICENSE file in the root directory of this source tree. */ -#include +#include namespace at { namespace native { namespace vulkan { -api::utils::ivec4 get_size_as_ivec4(const vTensor& t) { - return api::utils::make_ivec4( - {dim_at(t), - dim_at(t), - dim_at(t), - dim_at(t)}); -} - void bind_tensor_to_descriptor_set( vTensor& tensor, api::PipelineBarrier& pipeline_barrier, @@ -39,25 +31,49 @@ void bind_tensor_to_descriptor_set( uint32_t bind_values_to_descriptor_set( ComputeGraph* graph, - const std::vector& args, + const std::vector& args, api::PipelineBarrier& pipeline_barrier, - const api::MemoryAccessType accessType, api::DescriptorSet& descriptor_set, const uint32_t base_idx) { uint32_t idx = base_idx; for (auto& arg : args) { - Value& val = graph->get_val(arg); - if (val.isTensor()) { - vTensor& tensor = val.toTensor(); - bind_tensor_to_descriptor_set( - tensor, pipeline_barrier, accessType, descriptor_set, idx++); - } else { - VK_THROW("Unsupported type: ", val.type()); + for (auto& ref : arg.refs) { + Value& val = graph->get_val(ref); + if (val.isTensor()) { + bind_tensor_to_descriptor_set( + val.toTensor(), + pipeline_barrier, + arg.access, + descriptor_set, + idx++); + } else if (val.isStaging()) { + bind_staging_to_descriptor_set(val.toStaging(), descriptor_set, idx++); + } else { + VK_THROW("Unsupported type: ", val.type()); + } } } return idx; } +uint32_t bind_params_to_descriptor_set( + std::vector>& params, + api::DescriptorSet& descriptor_set, + const uint32_t base_idx) { + uint32_t idx = base_idx; + for (auto& param : params) { + descriptor_set.bind(idx++, param->buffer()); + } + return idx; +} + +void bind_staging_to_descriptor_set( + api::StorageBuffer& staging, + api::DescriptorSet& descriptor_set, + const uint32_t idx) { + descriptor_set.bind(idx, staging.buffer()); +} + } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ops/Utils.h b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h similarity index 67% rename from backends/vulkan/runtime/graph/ops/Utils.h rename to backends/vulkan/runtime/graph/ops/utils/BindingUtils.h index 9cf214ca87..e8d508b791 100644 --- a/backends/vulkan/runtime/graph/ops/Utils.h +++ b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h @@ -10,18 +10,15 @@ #ifdef USE_VULKAN_API -#include - #include namespace at { namespace native { namespace vulkan { -#define DECLARE_OP_FN(function) \ - ValueRef function(ComputeGraph& graph, const std::vector& args); - -api::utils::ivec4 get_size_as_ivec4(const vTensor& t); +// +// For objects in the graph +// void bind_tensor_to_descriptor_set( vTensor& tensor, @@ -32,12 +29,25 @@ void bind_tensor_to_descriptor_set( uint32_t bind_values_to_descriptor_set( ComputeGraph* graph, - const std::vector& args, + const std::vector& args, api::PipelineBarrier& pipeline_barrier, - const api::MemoryAccessType accessType, api::DescriptorSet& descriptor_set, const uint32_t base_idx); +// +// For objects NOT in the graph +// + +uint32_t bind_params_to_descriptor_set( + std::vector>& params, + api::DescriptorSet& descriptor_set, + const uint32_t base_idx); + +void bind_staging_to_descriptor_set( + api::StorageBuffer& staging, + api::DescriptorSet& descriptor_set, + const uint32_t idx); + } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp new file mode 100644 index 0000000000..e941f32e16 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace at { +namespace native { +namespace vulkan { + +void apply_dtype_suffix(std::stringstream& kernel_name, const vTensor& tensor) { + switch (tensor.image().format()) { + case VK_FORMAT_R32G32B32A32_SFLOAT: + kernel_name << "_float"; + break; + case VK_FORMAT_R16G16B16A16_SFLOAT: + kernel_name << "_half"; + break; + case VK_FORMAT_R32G32B32A32_SINT: + kernel_name << "_int"; + break; + default: + break; + } +} + +void apply_memory_layout_suffix( + std::stringstream& kernel_name, + const vTensor& tensor) { + switch (tensor.gpu_memory_layout()) { + case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED: + kernel_name << "_C_packed"; + break; + case api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED: + kernel_name << "_H_packed"; + break; + case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED: + kernel_name << "_W_packed"; + break; + default: + break; + } +} + +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h new file mode 100644 index 0000000000..b4c6c3a6bc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#ifdef USE_VULKAN_API + +#include + +#include + +namespace at { +namespace native { +namespace vulkan { + +void apply_dtype_suffix(std::stringstream& kernel_name, const vTensor& tensor); + +void apply_memory_layout_suffix( + std::stringstream& kernel_name, + const vTensor& tensor); + +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp new file mode 100644 index 0000000000..45307c8a9d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -0,0 +1,144 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// @lint-ignore-every CLANGTIDY facebook-security-vulnerable-memcpy + +#include +#include + +#include + +#include + +namespace at { +namespace native { +namespace vulkan { + +template +void memcpy_to_mapping_impl( + const void* src, + api::MemoryMap& dst_mapping, + const size_t nbytes) { + T* data_ptr = dst_mapping.template data(); + memcpy(data_ptr, reinterpret_cast(src), nbytes); +} + +template +void memcpy_from_mapping_impl( + api::MemoryMap& src_mapping, + void* dst, + const size_t nbytes) { + T* data_ptr = src_mapping.template data(); + memcpy(reinterpret_cast(dst), data_ptr, nbytes); +} + +void memcpy_to_mapping( + const void* src, + api::MemoryMap& dst_mapping, + const size_t nbytes, + const api::ScalarType dtype) { +#define DTYPE_CASE(ctype, vkformat, name) \ + case api::ScalarType::name: \ + memcpy_to_mapping_impl(src, dst_mapping, nbytes); \ + break; + + switch (dtype) { + VK_FORALL_SCALAR_TYPES(DTYPE_CASE) + default: + VK_THROW("Unrecognized dtype!"); + } +#undef DTYPE_CASE +} + +void memcpy_from_mapping( + api::MemoryMap& src_mapping, + void* dst, + const size_t nbytes, + const api::ScalarType dtype) { +#define DTYPE_CASE(ctype, vkformat, name) \ + case api::ScalarType::name: \ + memcpy_from_mapping_impl(src_mapping, dst, nbytes); \ + break; + + switch (dtype) { + VK_FORALL_SCALAR_TYPES(DTYPE_CASE) + default: + VK_THROW("Unrecognized dtype!"); + } +#undef DTYPE_CASE +} + +void copy_ptr_to_staging( + const void* src, + api::StorageBuffer& staging, + const size_t nbytes) { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); + mapping.invalidate(); + memcpy_to_mapping(src, mapping, nbytes, staging.dtype()); +} + +void copy_staging_to_ptr( + api::StorageBuffer& staging, + void* dst, + const size_t nbytes) { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::READ); + mapping.invalidate(); + memcpy_from_mapping(mapping, dst, nbytes, staging.dtype()); +} + +api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) { + if (v_dst.is_quantized()) { + VK_THROW("Quantized Tensors are currently not supported!"); + } + + std::stringstream kernel_name; + + switch (v_dst.storage_type()) { + case api::StorageType::TEXTURE_3D: + kernel_name << "nchw_to_image3d"; + break; + case api::StorageType::TEXTURE_2D: + kernel_name << "nchw_to_image2d"; + break; + default: + VK_THROW("No kernel available!"); + } + + apply_memory_layout_suffix(kernel_name, v_dst); + apply_dtype_suffix(kernel_name, v_dst); + + return VK_KERNEL_FROM_STR(kernel_name.str()); +} + +api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) { + if (v_src.is_quantized()) { + VK_THROW("Quantized Tensors are currently not supported!"); + } + + std::stringstream kernel_name; + + switch (v_src.storage_type()) { + case api::StorageType::TEXTURE_3D: + kernel_name << "image3d_to_nchw"; + break; + case api::StorageType::TEXTURE_2D: + kernel_name << "image2d_to_nchw"; + break; + default: + VK_THROW("No kernel available!"); + } + + apply_memory_layout_suffix(kernel_name, v_src); + apply_dtype_suffix(kernel_name, v_src); + + return VK_KERNEL_FROM_STR(kernel_name.str()); +} + +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h new file mode 100644 index 0000000000..2e5de6efb0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#ifdef USE_VULKAN_API + +#include + +namespace at { +namespace native { +namespace vulkan { + +// +// Functions to copy data into and out of a staging buffer +// + +void copy_ptr_to_staging( + const void* src, + api::StorageBuffer& staging, + const size_t nbytes); +void copy_staging_to_ptr( + api::StorageBuffer& staging, + void* dst, + const size_t nbytes); + +// +// Functions to get shaders +// + +api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst); +api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src); + +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index 3d8dab9a2f..36f6120025 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -10,9 +10,13 @@ table OperatorCall { args:[int]; } -enum VkDataType : short { - // IEEE754 single-precision floating-point. - fp32 = 0, +enum VkDataType : byte { + BOOL = 0, + UINT8 = 1, + INT8 = 2, + INT32 = 3, + FLOAT16 = 4, + FLOAT32 = 5, } table VkTensor { @@ -26,8 +30,55 @@ table VkTensor { mem_obj_id:int; } +table Null {} + +table Int { + int_val:long; +} + +table Bool { + bool_val:bool; +} + +table Double { + double_val:double; +} + +table String { + string_val:string; +} + +table IntList { + items:[long]; +} + +table DoubleList { + items:[double]; +} + +table BoolList { + items:[bool]; +} + +table ValueList { + items:[int]; +} + +union GraphTypes { + Null, + Int, + Double, + Bool, + VkTensor, + IntList, + DoubleList, + BoolList, + ValueList, + String, +} + table VkValue { - value:VkTensor; + value:GraphTypes; } // Abstraction to represent a region of bytes in a raw data buffer. Useful for referencing raw data diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 68e54c2bc3..f15e155703 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional +import operator +from typing import cast, List, Optional, Union import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema @@ -15,6 +16,9 @@ from torch.export import ExportedProgram from torch.fx import Node +_ScalarType = Union[bool, int, float] +_Argument = Union[Node, List[Node], TensorSpec, _ScalarType, List[_ScalarType], str] + class VkGraphBuilder: def __init__(self, program: ExportedProgram) -> None: @@ -26,28 +30,41 @@ def __init__(self, program: ExportedProgram) -> None: self.output_ids = [] self.const_tensors = [] - # Mapping from torch.fx.Node to VkValue id + # Mapping from Node to VkValue id self.node_to_value_ids = {} @staticmethod def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType: - if torch_dtype == torch.float32: - return vk_graph_schema.VkDataType.fp32 + if torch_dtype == torch.bool: + return vk_graph_schema.VkDataType.BOOL + elif torch_dtype == torch.uint8: + return vk_graph_schema.VkDataType.UINT8 + elif torch_dtype == torch.int8: + return vk_graph_schema.VkDataType.INT8 + elif torch_dtype == torch.int32: + return vk_graph_schema.VkDataType.INT32 + elif torch_dtype == torch.float16: + return vk_graph_schema.VkDataType.FLOAT16 + elif torch_dtype == torch.float32: + return vk_graph_schema.VkDataType.FLOAT32 + # Narrowing conversion for index tensor produced by max_poolNd_with_indices. + elif torch_dtype == torch.int64: + return vk_graph_schema.VkDataType.INT32 else: raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") - def is_constant(self, node: torch.fx.Node): + def is_constant(self, node: Node): return ( node.name in self.program.graph_signature.inputs_to_lifted_tensor_constants ) - def is_get_attr_node(self, node: torch.fx.Node) -> bool: + def is_get_attr_node(self, node: Node) -> bool: """ Returns true if the given node is a get attr node for a tensor of the model """ - return isinstance(node, torch.fx.Node) and node.op == "get_attr" + return isinstance(node, Node) and node.op == "get_attr" - def is_param_node(self, node: torch.fx.Node) -> bool: + def is_param_node(self, node: Node) -> bool: """ Check if the given node is a parameter within the exported program """ @@ -58,7 +75,7 @@ def is_param_node(self, node: torch.fx.Node) -> bool: or self.is_constant(node) ) - def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]: + def get_constant(self, node: Node) -> Optional[torch.Tensor]: """ Returns the constant associated with the given node in the exported program. Returns None if the node is not a constant within the exported program @@ -76,7 +93,7 @@ def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]: return None - def get_param_tensor(self, node: torch.fx.Node) -> torch.Tensor: + def get_param_tensor(self, node: Node) -> torch.Tensor: tensor = None if node is None: raise RuntimeError("node is None") @@ -99,33 +116,20 @@ def get_param_tensor(self, node: torch.fx.Node) -> torch.Tensor: return tensor def maybe_add_constant_tensor(self, node: Node) -> int: - const_buffer_idx = -1 + constant_id = -1 if self.is_param_node(node): - const_buffer_idx = len(self.const_tensors) + constant_id = len(self.const_tensors) self.const_tensors.append(self.get_param_tensor(node)) - return const_buffer_idx - - def create_single_vk_value(self, node: Node) -> int: - constant_id = self.maybe_add_constant_tensor(node) - - spec = node.meta.get("spec") - assert isinstance(spec, TensorSpec) - new_id = len(self.values) - if node not in self.node_to_value_ids: - self.node_to_value_ids[node] = new_id - else: - current_ids = self.node_to_value_ids[node] - if isinstance(current_ids, int): - current_ids = [current_ids, new_id] - else: - current_ids.append(new_id) + return constant_id + def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: # Negative id indicates that this tensor will have its own dedicated memory. mem_obj_id = -1 if spec.mem_obj_id is not None: mem_obj_id = spec.mem_obj_id + new_id = len(self.values) self.values.append( vk_graph_schema.VkValue( value=vk_graph_schema.VkTensor( @@ -138,60 +142,158 @@ def create_single_vk_value(self, node: Node) -> int: ) return new_id - def create_vk_values_for(self, node: Node): + def create_node_value(self, node: Node) -> int: spec = node.meta.get("spec") if isinstance(spec, TensorSpec): - return self.create_single_vk_value(node) + constant_id = self.maybe_add_constant_tensor(node) + new_id = self.create_tensor_value(spec, constant_id) + self.node_to_value_ids[node] = new_id + return new_id + elif isinstance(spec, tuple): + # Create a Value for each element in the tuple, wrap Values in a + # ValueList, and map the Node to the ValueList id. + new_id = self.create_value_list_value(spec) + self.node_to_value_ids[node] = new_id + return new_id else: - raise RuntimeError( - "Creating values for nodes with collection types is not supported yet." + raise RuntimeError(f"Cannot create value for spec of type {type(spec)}") + + def create_value_list_value(self, arg: List[Node] | tuple) -> int: + self.values.append( + vk_graph_schema.VkValue( + vk_graph_schema.ValueList( + items=[self.get_or_create_value_for(e) for e in arg] + ) + ) + ) + return len(self.values) - 1 + + def create_scalar_value(self, scalar: _ScalarType) -> int: + new_id = len(self.values) + if isinstance(scalar, bool): + self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Bool(scalar))) + elif isinstance(scalar, int): + self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar))) + elif isinstance(scalar, float): + self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar))) + return new_id + + def create_scalar_list_value(self, arg: List[_ScalarType]) -> int: + new_id = len(self.values) + if isinstance(arg[0], bool): + self.values.append( + vk_graph_schema.VkValue( + vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg]) + ) + ) + elif isinstance(arg[0], int): + self.values.append( + vk_graph_schema.VkValue( + vk_graph_schema.IntList(items=[cast(int, e) for e in arg]) + ) + ) + elif isinstance(arg[0], float): + self.values.append( + vk_graph_schema.VkValue( + vk_graph_schema.DoubleList(items=[cast(float, e) for e in arg]) + ) ) + return new_id + + def create_string_value(self, string: str) -> int: + new_id = len(self.values) + self.values.append( + vk_graph_schema.VkValue(vk_graph_schema.String(string_val=string)) + ) + return new_id + + def get_or_create_value_for(self, arg: _Argument): + if isinstance(arg, Node): + # If the Node has already been processed, return the existing id. + if arg in self.node_to_value_ids: + return self.node_to_value_ids[arg] + return self.create_node_value(arg) + elif isinstance(arg, list) and isinstance(arg[0], Node): + # pyre-ignore[6] + return self.create_value_list_value(arg) + elif isinstance(arg, TensorSpec): + return self.create_tensor_value(arg) + elif isinstance(arg, _ScalarType): + return self.create_scalar_value(arg) + elif isinstance(arg, list) and isinstance(arg[0], _ScalarType): + # pyre-ignore[6] + return self.create_scalar_list_value(arg) + elif isinstance(arg, str): + return self.create_string_value(arg) + else: + raise RuntimeError(f"Cannot create value for arg of type {type(arg)}") def process_placeholder_node(self, node: Node) -> None: - ids = self.create_vk_values_for(node) + ids = self.create_node_value(node) if not self.is_param_node(node): if isinstance(ids, int): self.input_ids.append(ids) else: self.input_ids += ids + def process_getitem_node(self, node: Node) -> None: + # Find ValueList id from the collection node. + collection_node = node.all_input_nodes[0] + list_id = self.node_to_value_ids[collection_node] + + # Extract the target Value id from ValueList. + valuelist_id = node.args[1] + value_id = self.values[list_id].value.items[valuelist_id] + + # Map Node to Value id. + self.node_to_value_ids[node] = value_id + def process_call_function_node(self, node) -> None: - args = [] - # Add input nodes - for inp_node in node.all_input_nodes: - if inp_node not in self.node_to_value_ids: - raise AssertionError( - "Cannot find input to current node in node_to_value_ids. This means " - "this node is being serialized before its input which is not allowed." - ) - args.append(self.node_to_value_ids[inp_node]) + operator_call_args = [] + + for i, schema_arg in enumerate(node.target._schema.arguments): + if not schema_arg.kwarg_only and i < len(node.args): + function_arg = node.args[i] + elif schema_arg.name in node.kwargs: + function_arg = node.kwargs[schema_arg.name] + else: + function_arg = schema_arg.default_value + + # Create a Value for each function argument. If the argument has been + # previously encountered, then use the existing Value id. + operator_call_args.append(self.get_or_create_value_for(function_arg)) + # Add output node - args.append(self.create_vk_values_for(node)) + operator_call_args.append(self.create_node_value(node)) self.chain.append( vk_graph_schema.OperatorCall( name=node.target.__name__, - args=args, + args=operator_call_args, ), ) def process_getattr_node(self, node: Node) -> None: - self.create_vk_values_for(node) + self.create_node_value(node) def process_output_node(self, node: Node) -> None: - if node.all_input_nodes[0] not in self.node_to_value_ids: - raise AssertionError( - "Cannot find input to output node in node_to_value_ids. This means the " - "output node is being serialized before its corresponding internal node " - "which is not allowed." - ) - self.output_ids.append(self.node_to_value_ids[node.all_input_nodes[0]]) + for out_node in node.all_input_nodes: + if out_node not in self.node_to_value_ids: + raise AssertionError( + "Cannot find input to output node in node_to_value_ids. This means " + "the output node is being serialized before its corresponding " + "internal node which is not allowed." + ) + self.output_ids.append(self.node_to_value_ids[out_node]) def process_node(self, node: Node) -> None: if node.op == "placeholder": self.process_placeholder_node(node) elif node.op == "call_function": - self.process_call_function_node(node) + if node.target == operator.getitem: + self.process_getitem_node(node) + else: + self.process_call_function_node(node) elif node.op == "get_attr": self.process_getattr_node(node) elif node.op == "output": diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index eeb1589a2a..2edb02a910 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from enum import IntEnum -from typing import List +from typing import List, Union @dataclass @@ -22,7 +22,12 @@ class OperatorCall: class VkDataType(IntEnum): - fp32 = 0 + BOOL = 0 + UINT8 = 1 + INT8 = 2 + INT32 = 3 + FLOAT16 = 4 + FLOAT32 = 5 @dataclass @@ -34,13 +39,67 @@ class VkTensor: @dataclass -class VkScalar: +class Null: pass +@dataclass +class Int: + int_val: int + + +@dataclass +class Bool: + bool_val: bool + + +@dataclass +class Double: + double_val: float + + +@dataclass +class IntList: + items: List[int] + + +@dataclass +class DoubleList: + items: List[float] + + +@dataclass +class BoolList: + items: List[bool] + + +@dataclass +class ValueList: + items: List[int] + + +@dataclass +class String: + string_val: str + + +GraphTypes = Union[ + Null, + Int, + Double, + Bool, + VkTensor, + IntList, + BoolList, + DoubleList, + ValueList, + String, +] + + @dataclass class VkValue: - value: VkTensor + value: "GraphTypes" @dataclass diff --git a/backends/vulkan/serialization/vulkan_graph_serialize.py b/backends/vulkan/serialization/vulkan_graph_serialize.py index 83a9e75f6c..37785f4752 100644 --- a/backends/vulkan/serialization/vulkan_graph_serialize.py +++ b/backends/vulkan/serialization/vulkan_graph_serialize.py @@ -12,7 +12,6 @@ from dataclasses import dataclass from typing import ClassVar, List -# pyre-ignore[21]: Could not find module `executorch.exir._serialize._bindings`. import pkg_resources import torch @@ -35,7 +34,6 @@ def convert_to_flatbuffer(vk_graph: VkGraph) -> bytes: json_path = os.path.join(d, "schema.json") with open(json_path, "wb") as json_file: json_file.write(vk_graph_json.encode("ascii")) - # pyre-ignore _flatc_compile(d, schema_path, json_path) output_path = os.path.join(d, "schema.bin") with open(output_path, "rb") as output_file: diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 345f18801f..1e7670d1cc 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -1,6 +1,55 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") -def define_common_targets(): +def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False): + gen_aten_vulkan_spv_target = "//caffe2/tools:gen_aten_vulkan_spv_bin" + glslc_path = "//caffe2/fb/vulkan/dotslash:glslc" + if is_fbcode: + gen_aten_vulkan_spv_target = "//caffe2:gen_vulkan_spv_bin" + glslc_path = "//caffe2/fb/vulkan/tools:glslc" + + glsl_paths = [] + + # TODO(ssjia): remove the need for subpath once subdir_glob is enabled in OSS + for target, subpath in spv_filegroups.items(): + glsl_paths.append("$(location {})/{}".format(target, subpath)) + + genrule_cmd = [ + "$(exe {})".format(gen_aten_vulkan_spv_target), + "--glsl-paths {}".format(" ".join(glsl_paths)), + "--output-path $OUT", + "--glslc-path=$(exe {})".format(glslc_path), + "--tmp-dir-path=$OUT", + ] + + genrule_name = "gen_{}_cpp".format(name) + runtime.genrule( + name = genrule_name, + outs = { + "{}.cpp".format(name): ["spv.cpp"], + }, + cmd = " ".join(genrule_cmd), + default_outs = ["."], + labels = ["uses_dotslash"], + ) + + runtime.cxx_library( + name = name, + srcs = [ + ":{}[{}.cpp]".format(genrule_name, name), + ], + define_static_target = False, + # Static initialization is used to register shaders to the global shader registry, + # therefore link_whole must be True to make sure unused symbols are not discarded. + # @lint-ignore BUCKLINT: Avoid `link_whole=True` + link_whole = True, + # Define a soname that can be used for dynamic loading in Java, Python, etc. + soname = "lib{}.$(ext)".format(name), + exported_deps = [ + "//caffe2:torch_vulkan_api", + ], + ) + +def define_common_targets(is_fbcode = False): runtime.genrule( name = "gen_vk_delegate_schema", srcs = [ @@ -38,6 +87,21 @@ def define_common_targets(): ], ) + runtime.filegroup( + name = "vulkan_graph_runtime_shaders", + srcs = native.glob([ + "runtime/graph/ops/glsl/*", + ]), + ) + + vulkan_spv_shader_lib( + name = "vulkan_graph_runtime_shaderlib", + spv_filegroups = { + ":vulkan_graph_runtime_shaders": "runtime/graph/ops/glsl", + }, + is_fbcode = is_fbcode, + ) + runtime.cxx_library( name = "vulkan_graph_runtime", srcs = native.glob([ @@ -53,10 +117,15 @@ def define_common_targets(): "@EXECUTORCH_CLIENTS", ], exported_deps = [ - "//caffe2:torch_vulkan_ops", - "//caffe2:torch_vulkan_spv", + ":vulkan_graph_runtime_shaderlib", ], define_static_target = False, + # Static initialization is used to register operators to the global operator registry, + # therefore link_whole must be True to make sure unused symbols are not discarded. + # @lint-ignore BUCKLINT: Avoid `link_whole=True` + link_whole = True, + # Define an soname that can be used for dynamic loading in Java, Python, etc. + soname = "libvulkan_graph_runtime.$(ext)", ) runtime.cxx_library( @@ -77,6 +146,7 @@ def define_common_targets(): ":vk_delegate_schema", ":vulkan_graph_runtime", "//executorch/runtime/backend:interface", + "//executorch/runtime/core/exec_aten/util:tensor_util", ], define_static_target = False, # VulkanBackend.cpp needs to compile with executor as whole diff --git a/backends/vulkan/test/glsl/all_shaders.yaml b/backends/vulkan/test/glsl/all_shaders.yaml new file mode 100644 index 0000000000..e6b4ca2cca --- /dev/null +++ b/backends/vulkan/test/glsl/all_shaders.yaml @@ -0,0 +1,62 @@ +binary_op_nobroadcast__test: + parameter_names_with_default_values: + DTYPE: float + OPERATOR: X + Y + generate_variant_forall: + DTYPE: + - VALUE: "half" + SUFFIX: "half" + - VALUE: "float" + SUFFIX: "float" + shader_variants: + - NAME: binary_add_nobroadcast__test + OPERATOR: X + Y + - NAME: binary_sub_nobroadcast__test + OPERATOR: X - Y + - NAME: binary_mul_nobroadcast__test + OPERATOR: X * Y + - NAME: binary_div_nobroadcast__test + OPERATOR: X / Y + - NAME: binary_pow_nobroadcast__test + OPERATOR: pow(X, Y) + +fill_texture__test: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: "half" + SUFFIX: "half" + - VALUE: "float" + SUFFIX: "float" + shader_variants: + - NAME: fill_texture__test + +image_to_nchw__test: + parameter_names_with_default_values: + NDIM: 3 + DTYPE: float + PACKING: CHANNELS_PACKED + generate_variant_forall: + DTYPE: + - VALUE: "half" + SUFFIX: "half" + - VALUE: "float" + SUFFIX: "float" + shader_variants: + - NAME: image3d_to_nchw__test_C_packed + +nchw_to_image__test: + parameter_names_with_default_values: + NDIM: 3 + DTYPE: float + PACKING: CHANNELS_PACKED + generate_variant_forall: + DTYPE: + - VALUE: "half" + SUFFIX: "half" + - VALUE: "float" + SUFFIX: "float" + shader_variants: + - NAME: nchw_to_image3d__test_C_packed diff --git a/backends/vulkan/test/glsl/binary_op_nobroadcast__test.glsl b/backends/vulkan/test/glsl/binary_op_nobroadcast__test.glsl new file mode 100644 index 0000000000..f5e5d6b4e4 --- /dev/null +++ b/backends/vulkan/test/glsl/binary_op_nobroadcast__test.glsl @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core +// clang-format off +#define PRECISION ${PRECISION} + +#define OP(X, Y) ${OPERATOR} +// clang-format on + +layout(std430) buffer; + +// clang-format off +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D image_out; +// clang-format on +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; +layout(set = 0, binding = 2) uniform PRECISION sampler3D image_other; + +layout(set = 0, binding = 3) uniform PRECISION restrict OutExtents { + uvec4 data; +} +out_extents; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, out_extents.data.xyz))) { + return; + } + + vec4 in_texel = texelFetch(image_in, pos, 0); + vec4 other_texel = texelFetch(image_other, pos, 0); + + imageStore(image_out, pos, OP(in_texel, other_texel)); +} diff --git a/backends/vulkan/test/glsl/fill_texture__test.glsl b/backends/vulkan/test/glsl/fill_texture__test.glsl new file mode 100644 index 0000000000..fafad11d49 --- /dev/null +++ b/backends/vulkan/test/glsl/fill_texture__test.glsl @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core +#define PRECISION ${PRECISION} + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +// clang-format off +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} uOutput; +// clang-format on +layout(set = 0, binding = 1) uniform PRECISION restrict Block { + ivec3 size; + int fill; + vec4 vals; +} params; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, params.size))) { + return; + } + + imageStore(uOutput, pos, params.vals); +} diff --git a/backends/vulkan/test/glsl/image_to_nchw__test.glsl b/backends/vulkan/test/glsl/image_to_nchw__test.glsl new file mode 100644 index 0000000000..b5563b080f --- /dev/null +++ b/backends/vulkan/test/glsl/image_to_nchw__test.glsl @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core +// clang-format off +#define PRECISION ${PRECISION} +// clang-format on + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in; +layout(set = 0, binding = 1) buffer PRECISION restrict writeonly Buffer { + ${T[DTYPE]} data[]; +} +buffer_out; + +layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { + ivec4 data; +} +gpu_sizes; + +layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes { + ivec4 data; +} +cpu_sizes; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 coord = POS_TO_COORD_${PACKING}(pos, gpu_sizes.data); + + if (any(greaterThanEqual(coord, gpu_sizes.data))) { + return; + } + + const ${VEC4_T[DTYPE]} intex = texelFetch(image_in, pos, 0); + + const int base_index = COORD_TO_BUFFER_IDX(coord, cpu_sizes.data); + const ivec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * (gpu_sizes.data.x * gpu_sizes.data.y); + + if (coord.z < cpu_sizes.data.z) { + buffer_out.data[buf_indices.x] = intex.x; + } + if (coord.z + 1 < cpu_sizes.data.z) { + buffer_out.data[buf_indices.y] = intex.y; + } + if (coord.z + 2 < cpu_sizes.data.z) { + buffer_out.data[buf_indices.z] = intex.z; + } + if (coord.z + 3 < cpu_sizes.data.z) { + buffer_out.data[buf_indices.w] = intex.w; + } +} diff --git a/backends/vulkan/test/glsl/indexing_utils.h b/backends/vulkan/test/glsl/indexing_utils.h new file mode 100644 index 0000000000..d3f005c1ee --- /dev/null +++ b/backends/vulkan/test/glsl/indexing_utils.h @@ -0,0 +1,14 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#define POS_TO_COORD_CHANNELS_PACKED(pos, sizes) \ + ivec4(pos.x, pos.y, (pos.z * 4) % sizes.z, (pos.z * 4) / sizes.z) + +#define COORD_TO_BUFFER_IDX(coord, sizes) \ + coord.x + coord.y* sizes.x + coord.z* sizes.y* sizes.x + \ + coord.w* sizes.z* sizes.y* sizes.x; diff --git a/backends/vulkan/test/glsl/nchw_to_image__test.glsl b/backends/vulkan/test/glsl/nchw_to_image__test.glsl new file mode 100644 index 0000000000..1a41fd88d0 --- /dev/null +++ b/backends/vulkan/test/glsl/nchw_to_image__test.glsl @@ -0,0 +1,64 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core +// clang-format off +#define PRECISION ${PRECISION} +// clang-format on + +#include "indexing_utils.h" + +layout(std430) buffer; + +// clang-format off +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +// clang-format on +layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { + ${T[DTYPE]} data[]; +} +buffer_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { + ivec4 data; +} +gpu_sizes; + +layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes { + ivec4 data; +} +cpu_sizes; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 coord = POS_TO_COORD_${PACKING}(pos, gpu_sizes.data); + + if (any(greaterThanEqual(coord, gpu_sizes.data))) { + return; + } + + const int base_index = COORD_TO_BUFFER_IDX(coord, cpu_sizes.data); + const ivec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * (gpu_sizes.data.x * gpu_sizes.data.y); + + ${T[DTYPE]} val_x = buffer_in.data[buf_indices.x]; + ${T[DTYPE]} val_y = buffer_in.data[buf_indices.y]; + ${T[DTYPE]} val_z = buffer_in.data[buf_indices.z]; + ${T[DTYPE]} val_w = buffer_in.data[buf_indices.w]; + + ${VEC4_T[DTYPE]} texel = ${VEC4_T[DTYPE]}(val_x, val_y, val_z, val_w); + + if (coord.z + 3 >= cpu_sizes.data.z) { + ivec4 c_ind = ivec4(coord.z) + ivec4(0, 1, 2, 3); + vec4 valid_c = vec4(lessThan(c_ind, ivec4(cpu_sizes.data.z))); + texel = texel * valid_c; + } + + imageStore(image_out, ${GET_POS[NDIM]("pos")}, texel); +} diff --git a/backends/vulkan/test/glsl/test_shader.glsl b/backends/vulkan/test/glsl/test_shader.glsl new file mode 100644 index 0000000000..39edc92cc6 --- /dev/null +++ b/backends/vulkan/test/glsl/test_shader.glsl @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core +#define PRECISION ${PRECISION} +#define FORMAT ${FORMAT} + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const vec4 intex = texelFetch(uInput, pos, 0); + imageStore( + uOutput, + pos, + intex + 5); + } +} diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 8a491497c3..d18ca74588 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -14,7 +14,7 @@ from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend from executorch.exir import EdgeProgramManager, to_edge -from torch.export import export, ExportedProgram +from torch.export import Dim, export, ExportedProgram ctypes.CDLL("libvulkan.so.1") @@ -54,13 +54,17 @@ def lower_module_and_test_output( sample_inputs: Tuple[torch.Tensor], atol=1e-03, rtol=1e-01, + dynamic_shapes=None, + test_inputs=None, ): """ Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with the given sample inputs. It then runs the lowered module and compares its outputs with the outputs of the eager module. """ - program: ExportedProgram = export(model, sample_inputs) + program: ExportedProgram = export( + model, sample_inputs, dynamic_shapes=dynamic_shapes + ) edge_program: EdgeProgramManager = to_edge(program) edge_program = edge_program.to_backend(VulkanPartitioner()) @@ -80,6 +84,19 @@ def lower_module_and_test_output( self.assert_outputs_equal(model_output, ref_output, atol=atol, rtol=rtol) + if test_inputs is not None: + for test_input in test_inputs: + # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`. + test_inputs_flattened, _ = tree_flatten(test_input) + model_output = executorch_module.run_method( + "forward", tuple(test_inputs_flattened) + ) + ref_output = model(*test_input) + + self.assert_outputs_equal( + model_output, ref_output, atol=atol, rtol=rtol + ) + def test_vulkan_backend_add(self): # This test is the simplest test by manually lowering some submodules, we can use paritioner for auto detecting lowerable parts class AddModule(torch.nn.Module): @@ -93,12 +110,12 @@ def forward(self, x, y): return z add_module = AddModule() - model_inputs = ( + sample_inputs = ( torch.rand(size=(2, 3), dtype=torch.float32), torch.rand(size=(2, 3), dtype=torch.float32), ) - self.lower_module_and_test_output(add_module, model_inputs) + self.lower_module_and_test_output(add_module, sample_inputs) def test_vulkan_backend_internal_data(self): class InternalDataModule(torch.nn.Module): @@ -107,19 +124,19 @@ def __init__(self): self.weight = torch.rand(size=(2, 3), dtype=torch.float32) def forward(self, x, y): - z = x + y - z = z + x + z = torch.add(x, y, alpha=2) + z = torch.add(x, y, alpha=3.14) z = z + x z = z + self.weight return z internal_data_module = InternalDataModule() - model_inputs = ( + sample_inputs = ( torch.rand(size=(2, 3), dtype=torch.float32), torch.rand(size=(2, 3), dtype=torch.float32), ) - self.lower_module_and_test_output(internal_data_module, model_inputs) + self.lower_module_and_test_output(internal_data_module, sample_inputs) def test_vulkan_backend_sub(self): class SubModule(torch.nn.Module): @@ -127,18 +144,18 @@ def __init__(self): super().__init__() def forward(self, x, y): - z = x - y - z = z - x + z = torch.sub(x, y, alpha=2) + z = torch.sub(z, x, alpha=3.14) z = z - x return z sub_module = SubModule() - model_inputs = ( + sample_inputs = ( torch.rand(size=(2, 3), dtype=torch.float32), torch.rand(size=(2, 3), dtype=torch.float32), ) - self.lower_module_and_test_output(sub_module, model_inputs) + self.lower_module_and_test_output(sub_module, sample_inputs) def test_vulkan_backend_mul(self): class MulModule(torch.nn.Module): @@ -152,12 +169,12 @@ def forward(self, x, y): return z mul_module = MulModule() - model_inputs = ( + sample_inputs = ( torch.rand(size=(2, 3), dtype=torch.float32), torch.rand(size=(2, 3), dtype=torch.float32), ) - self.lower_module_and_test_output(mul_module, model_inputs) + self.lower_module_and_test_output(mul_module, sample_inputs) def test_vulkan_backend_div(self): class DivModule(torch.nn.Module): @@ -171,12 +188,12 @@ def forward(self, x, y): return z div_module = DivModule() - model_inputs = ( + sample_inputs = ( torch.rand(size=(2, 3), dtype=torch.float32), torch.rand(size=(2, 3), dtype=torch.float32), ) - self.lower_module_and_test_output(div_module, model_inputs) + self.lower_module_and_test_output(div_module, sample_inputs) def test_vulkan_backend_arithmetic(self): class ArithmeticModule(torch.nn.Module): @@ -192,12 +209,12 @@ def forward(self, x, y): return z arithmetic_module = ArithmeticModule() - model_inputs = ( + sample_inputs = ( torch.rand(size=(2, 3), dtype=torch.float32), torch.rand(size=(2, 3), dtype=torch.float32), ) - self.lower_module_and_test_output(arithmetic_module, model_inputs) + self.lower_module_and_test_output(arithmetic_module, sample_inputs) def test_vulkan_backend_floor_div(self): class FloorDivModule(torch.nn.Module): @@ -209,14 +226,14 @@ def forward(self, x, y): return z floor_div_module = FloorDivModule() - model_inputs = ( + sample_inputs = ( torch.rand(size=(2, 3), dtype=torch.float32) * 10.0, torch.rand(size=(2, 3), dtype=torch.float32) + 1.0, ) # absolute tolerance is 1 because of flooring self.lower_module_and_test_output( - floor_div_module, model_inputs, atol=1.0 + 1e-03 + floor_div_module, sample_inputs, atol=1.0 + 1e-03 ) def test_vulkan_backend_pow(self): @@ -229,12 +246,12 @@ def forward(self, x, y): return z pow_module = PowModule() - model_inputs = ( + sample_inputs = ( torch.rand(size=(2, 3), dtype=torch.float32), torch.rand(size=(2, 3), dtype=torch.float32), ) - self.lower_module_and_test_output(pow_module, model_inputs) + self.lower_module_and_test_output(pow_module, sample_inputs) def test_vulkan_backend_partial(self): class SimpleModel(torch.nn.Module): @@ -248,6 +265,41 @@ def forward(self, x): return self.linear(x + self.offset_1) - self.offset_2 model = SimpleModel() - model_inputs = (torch.rand(size=(2, 10), dtype=torch.float32),) + sample_inputs = (torch.rand(size=(2, 10), dtype=torch.float32),) + + self.lower_module_and_test_output(model, sample_inputs) + + def test_vulkan_backend_partial_dynamic_shapes(self): + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.branch1 = torch.nn.Sequential( + torch.nn.Linear(64, 64), torch.nn.ReLU() + ) + self.branch2 = torch.nn.Sequential( + torch.nn.Linear(128, 64), torch.nn.ReLU() + ) + self.buffer_1 = torch.ones((1, 64)) * 0.5 + self.buffer_2 = torch.ones((1, 64)) * 1.4 + + def forward(self, x1, x2): + out1 = self.branch1(x1) + out2 = self.branch2(x2) + return (out1 + self.buffer_1 + out2) * self.buffer_2 + + model = SimpleModel() + sample_inputs = (torch.randn(32, 64), torch.randn(32, 128)) + batch = Dim("batch", max=32) + dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} + + test_inputs = [ + (torch.randn(15, 64), torch.randn(15, 128)), + (torch.randn(6, 64), torch.randn(6, 128)), + (torch.randn(30, 64), torch.randn(30, 128)), + (torch.randn(20, 64), torch.randn(20, 128)), + (torch.randn(19, 64), torch.randn(19, 128)), + ] - self.lower_module_and_test_output(model, model_inputs) + self.lower_module_and_test_output( + model, sample_inputs, dynamic_shapes=dynamic_shapes, test_inputs=test_inputs + ) diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp new file mode 100644 index 0000000000..45b91ce99b --- /dev/null +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -0,0 +1,233 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +// +// Operator Recording Functions +// + +void record_nchw_to_buffer_op( + api::Context* const context, + api::VulkanBuffer& src_buffer, + vTensor& v_dst) { + uint32_t buf_len = api::utils::safe_downcast(v_dst.gpu_numel()); + api::utils::uvec3 global_size = {buf_len, 1u, 1u}; + api::utils::uvec3 local_size = {32u, 1u, 1u}; + + api::UniformParamsBuffer cpu_buffer_metadata( + context, v_dst.get_cpu_buffer_metadata()); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + VK_KERNEL(buffer_to_buffer), + pipeline_barrier, + global_size, + local_size, + VK_NULL_HANDLE, + v_dst.buffer( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_dst.buffer_metadata(), + src_buffer, + cpu_buffer_metadata.buffer()); +} + +bool record_buffer_to_nchw_op( + api::Context* const context, + vTensor& v_src, + api::VulkanBuffer& dst_buffer) { + uint32_t buf_len = api::utils::safe_downcast(v_src.numel()); + api::utils::uvec3 global_size = {buf_len, 1u, 1u}; + api::utils::uvec3 local_size = {4u, 1u, 1u}; + + api::UniformParamsBuffer cpu_buffer_metadata( + context, v_src.get_cpu_buffer_metadata()); + api::PipelineBarrier pipeline_barrier{}; + + return context->submit_compute_job( + VK_KERNEL(buffer_to_buffer), + pipeline_barrier, + global_size, + local_size, + VK_NULL_HANDLE, + dst_buffer, + cpu_buffer_metadata.buffer(), + v_src.buffer( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_src.buffer_metadata()); +} + +void record_nchw_to_image_op( + api::Context* const context, + api::VulkanBuffer& src_buffer, + vTensor& v_dst) { + api::PipelineBarrier pipeline_barrier{}; + api::ShaderInfo compute_shader = + VK_KERNEL(nchw_to_image3d__test_C_packed_half); + if (v_dst.image().format() == VK_FORMAT_R32G32B32A32_SFLOAT) { + compute_shader = VK_KERNEL(nchw_to_image3d__test_C_packed_float); + } + context->submit_compute_job( + compute_shader, + pipeline_barrier, + v_dst.virtual_extents(), + adaptive_work_group_size(v_dst.virtual_extents()), + VK_NULL_HANDLE, + v_dst.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + src_buffer, + v_dst.gpu_sizes_ubo()->buffer(), + v_dst.cpu_sizes_ubo()->buffer()); +} + +void record_image_to_nchw_op( + api::Context* const context, + vTensor& v_src, + api::VulkanBuffer& dst_buffer) { + api::ShaderInfo compute_shader = + VK_KERNEL(image3d_to_nchw__test_C_packed_half); + if (v_src.image().format() == VK_FORMAT_R32G32B32A32_SFLOAT) { + compute_shader = VK_KERNEL(image3d_to_nchw__test_C_packed_float); + } + api::PipelineBarrier pipeline_barrier{}; + context->submit_compute_job( + compute_shader, + pipeline_barrier, + v_src.virtual_extents(), + adaptive_work_group_size(v_src.virtual_extents()), + VK_NULL_HANDLE, + v_src.image(pipeline_barrier, api::PipelineStage::COMPUTE), + dst_buffer, + v_src.gpu_sizes_ubo()->buffer(), + v_src.cpu_sizes_ubo()->buffer()); +} + +void record_binary_op( + api::Context* const context, + const std::string& op_name, + vTensor& v_in1, + vTensor& v_in2, + vTensor& v_dst) { + std::stringstream kernel_name; + kernel_name << "binary_" << op_name << "_nobroadcast__test"; + apply_dtype_suffix(kernel_name, v_dst); + + api::PipelineBarrier pipeline_barrier{}; + context->submit_compute_job( + VK_KERNEL_FROM_STR(kernel_name.str()), + pipeline_barrier, + v_dst.virtual_extents(), + adaptive_work_group_size(v_dst.virtual_extents()), + VK_NULL_HANDLE, + v_dst.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_in1.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_in2.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_dst.extents_ubo()->buffer()); +} + +void execute_and_check_add( + vTensor& a, + vTensor& b, + vTensor& c, + float a_val, + float b_val) { + // Add shader kernel + api::ShaderInfo kernel = VK_KERNEL(binary_add_nobroadcast__test_half); + if (c.image().format() == VK_FORMAT_R32G32B32A32_SFLOAT) { + kernel = VK_KERNEL(nchw_to_image3d__test_C_packed_float); + } + + // Fill input tensors + fill_vtensor(a, a_val); + fill_vtensor(b, b_val); + + // a + b = c + record_binary_op(api::context(), "add", a, b, c); + + // Extract output tensor + std::vector data_out = extract_vtensor(c); + + // Check output + for (const auto& d : data_out) { + EXPECT_TRUE(d == (a_val + b_val)); + } +} + +// +// Input & Output Utilities +// + +void fill_vtensor(vTensor& vten, std::vector& data) { + api::StorageBuffer staging_buffer(api::context(), api::kFloat, data.size()); + + copy_ptr_to_staging(data.data(), staging_buffer, vten.gpu_nbytes()); + + if (vten.storage_type() == api::StorageType::BUFFER) { + record_nchw_to_buffer_op(api::context(), staging_buffer.buffer(), vten); + } else { + record_nchw_to_image_op(api::context(), staging_buffer.buffer(), vten); + } +} + +void fill_vtensor(ComputeGraph& graph, const IOValueRef idx, float val) { + std::vector data(graph.get_val(idx.value).toTensor().gpu_numel()); + std::fill(data.begin(), data.end(), val); + + graph.copy_into_staging(idx.staging, data.data(), data.size()); +} + +void extract_vtensor(vTensor& vten, std::vector& data) { + api::StorageBuffer staging_buffer( + api::context(), api::kFloat, vten.gpu_numel()); + + if (vten.storage_type() == api::StorageType::BUFFER) { + record_buffer_to_nchw_op(api::context(), vten, staging_buffer.buffer()); + } else { + record_image_to_nchw_op(api::context(), vten, staging_buffer.buffer()); + } + + api::VulkanFence fence = api::context()->fences().get_fence(); + api::context()->submit_cmd_to_gpu(fence.get_submit_handle()); + fence.wait(); + + copy_staging_to_ptr(staging_buffer, data.data(), vten.gpu_nbytes()); +} + +// +// Context Management +// + +void submit_to_gpu() { + api::VulkanFence fence = api::context()->fences().get_fence(); + api::context()->submit_cmd_to_gpu(fence.get_submit_handle()); + fence.wait(); +} + +api::MemoryAllocation allocate_memory_for(const vTensor& vten) { + return api::context()->adapter_ptr()->vma().create_allocation( + vten.get_memory_requirements(), vten.get_allocation_create_info()); +} + +VmaTotalStatistics get_vma_stats() { + return api::context()->adapter_ptr()->vma().get_memory_statistics(); +} + +size_t get_vma_allocation_count() { + return get_vma_stats().total.statistics.allocationCount; +} diff --git a/backends/vulkan/test/utils/test_utils.h b/backends/vulkan/test/utils/test_utils.h new file mode 100644 index 0000000000..8c946107ef --- /dev/null +++ b/backends/vulkan/test/utils/test_utils.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +#include +#include + +using namespace at::native::vulkan; + +#define CREATE_FLOAT_TEXTURE(sizes, allocate_memory) \ + vTensor( \ + api::context(), \ + sizes, \ + api::kFloat, \ + api::StorageType::TEXTURE_3D, \ + api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, \ + allocate_memory); + +#define CREATE_FLOAT_BUFFER(sizes, allocate_memory) \ + vTensor( \ + api::context(), \ + sizes, \ + api::kFloat, \ + api::StorageType::BUFFER, \ + api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, \ + allocate_memory); + +#define DEFINE_STAGING_BUFFER_AND_RECORD_TO_GPU_FOR(tensor) \ + api::StorageBuffer staging_buffer_##tensor( \ + api::context(), api::kFloat, tensor.gpu_numel()); \ + record_nchw_to_image_op( \ + api::context(), staging_buffer_##tensor.buffer(), tensor); + +#define DEFINE_STAGING_BUFFER_AND_RECORD_FROM_GPU_FOR(tensor) \ + api::StorageBuffer staging_buffer_##tensor( \ + api::context(), api::kFloat, tensor.gpu_numel()); \ + record_image_to_nchw_op( \ + api::context(), tensor, staging_buffer_##tensor.buffer()); + +// +// Operator Recording +// + +void record_nchw_to_buffer_op( + api::Context* const context, + api::VulkanBuffer& src_buffer, + vTensor& v_dst); + +bool record_buffer_to_nchw_op( + api::Context* const context, + vTensor& v_src, + api::VulkanBuffer& dst_buffer); + +void record_nchw_to_image_op( + api::Context* const context, + api::VulkanBuffer& src_buffer, + vTensor& v_dst); + +void record_image_to_nchw_op( + api::Context* const context, + vTensor& v_src, + api::VulkanBuffer& dst_buffer); + +void record_binary_op( + api::Context* const context, + const std::string& op_name, + vTensor& v_in1, + vTensor& v_in2, + vTensor& v_dst); + +void execute_and_check_add( + vTensor& a, + vTensor& b, + vTensor& c, + float a_val, + float b_val); + +// +// Input & Output Utilities +// + +inline void +fill_staging(api::StorageBuffer& staging, float val, int numel = -1) { + if (numel < 0) { + numel = staging.numel(); + } + std::vector data(numel); + std::fill(data.begin(), data.end(), val); + copy_ptr_to_staging(data.data(), staging, sizeof(float) * numel); +} + +void fill_vtensor(vTensor& vten, std::vector& data); + +inline void fill_vtensor(vTensor& vten, float val) { + std::vector vten_data(vten.gpu_numel()); + std::fill(vten_data.begin(), vten_data.end(), val); + + fill_vtensor(vten, vten_data); +} + +void fill_vtensor(ComputeGraph& graph, const IOValueRef idx, float val); + +void extract_vtensor(vTensor& vten, std::vector& data); + +inline std::vector extract_vtensor(vTensor& vten) { + std::vector data_out(vten.gpu_numel()); + extract_vtensor(vten, data_out); + return data_out; +} + +inline void +check_staging_buffer(api::StorageBuffer& staging, float val, int numel = -1) { + if (numel < 0) { + numel = staging.numel(); + } + std::vector data(numel); + copy_staging_to_ptr(staging, data.data(), sizeof(float) * numel); + + for (const auto& d : data) { + EXPECT_TRUE(d == val); + } +} + +// +// Context Management +// + +void submit_to_gpu(); + +api::MemoryAllocation allocate_memory_for(const vTensor& vten); + +VmaTotalStatistics get_vma_stats(); + +size_t get_vma_allocation_count(); diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 0692d8c709..eaaca78749 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -10,136 +10,18 @@ #include -#include -#include +#include -#include -#include +#include -using namespace at::native::vulkan; - -// -// Utilities -// - -#define CREATE_FLOAT_TEXTURE(sizes, allocate_memory) \ - vTensor( \ - api::context(), \ - sizes, \ - api::kFloat, \ - api::StorageType::TEXTURE_3D, \ - api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, \ - allocate_memory); - -#define CREATE_FLOAT_BUFFER(sizes, allocate_memory) \ - vTensor( \ - api::context(), \ - sizes, \ - api::kFloat, \ - api::StorageType::BUFFER, \ - api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, \ - allocate_memory); - -void fill_vtensor(vTensor& vten, std::vector& data) { - api::StorageBuffer staging_buffer(api::context(), api::kFloat, data.size()); - - copy_ptr_to_staging(data.data(), staging_buffer, vten.gpu_nbytes()); - - if (vten.storage_type() == api::StorageType::BUFFER) { - packing::record_nchw_to_buffer_op( - api::context(), staging_buffer.buffer(), vten, {}, VK_NULL_HANDLE); - } else { - api::ShaderInfo compute_shader = packing::get_nchw_to_image_shader(vten); - packing::record_nchw_to_image_op( - api::context(), - compute_shader, - staging_buffer.buffer(), - vten, - {}, - VK_NULL_HANDLE); - } -} - -void fill_vtensor(ComputeGraph& graph, const IOValueRef idx, float val) { - std::vector data(graph.get_val(idx.value).toTensor().gpu_numel()); - std::fill(data.begin(), data.end(), val); - - graph.copy_into_staging(idx.staging, data.data(), data.size()); -} - -void extract_vtensor(vTensor& vten, std::vector& data) { - api::StorageBuffer staging_buffer( - api::context(), api::kFloat, vten.gpu_numel()); - - if (vten.storage_type() == api::StorageType::BUFFER) { - packing::record_buffer_to_nchw_op( - api::context(), vten, staging_buffer.buffer(), {}, VK_NULL_HANDLE); - } else { - api::ShaderInfo compute_shader = packing::get_image_to_nchw_shader(vten); - packing::record_image_to_nchw_op( - api::context(), - compute_shader, - vten, - staging_buffer.buffer(), - {}, - VK_NULL_HANDLE); - } - - api::VulkanFence fence = api::context()->fences().get_fence(); - api::context()->submit_cmd_to_gpu(fence.get_submit_handle()); - fence.wait(); - - copy_staging_to_ptr(staging_buffer, data.data(), vten.gpu_nbytes()); -} - -api::MemoryAllocation allocate_memory_for(const vTensor& vten) { - return api::context()->adapter_ptr()->vma().create_allocation( - vten.get_memory_requirements(), vten.get_allocation_create_info()); -} - -VmaTotalStatistics get_vma_stats() { - return api::context()->adapter_ptr()->vma().get_memory_statistics(); -} - -size_t get_vma_allocation_count() { - return get_vma_stats().total.statistics.allocationCount; -} - -GraphConfig generate_graph_config() { - const uint32_t submit_frequency = UINT32_MAX; - - const api::CommandPoolConfig cmd_config{ - 4u, // cmdPoolInitialSize - 2u, // cmdPoolBatchSize - }; - - const api::DescriptorPoolConfig descriptor_pool_config{ - 1024u, // descriptorPoolMaxSets - 1024u, // descriptorUniformBufferCount - 1024u, // descriptorStorageBufferCount - 1024u, // descriptorCombinedSamplerCount - 1024u, // descriptorStorageImageCount - 32u, // descriptorPileSizes - }; +#include - const api::QueryPoolConfig query_pool_config{}; +#include - const api::ContextConfig context_config{ - submit_frequency, // cmdSubmitFrequency - cmd_config, // cmdPoolConfig - descriptor_pool_config, // descriptorPoolConfig - query_pool_config, // queryPoolConfig - }; - - const GraphConfig graph_config{ - context_config, - }; - - return graph_config; -} +using namespace at::native::vulkan; // -// Test Wrapper +// Compute API Tests // class VulkanComputeAPITest : public ::testing::Test { @@ -157,63 +39,67 @@ class VulkanComputeAPITest : public ::testing::Test { } }; -// -// Compute API Tests -// +TEST_F(VulkanComputeAPITest, retrieve_custom_shader_test) { + // Try to get shader from custom shader library + const api::ShaderInfo& kernel = VK_KERNEL(test_shader); -TEST_F(VulkanComputeAPITest, buffer_copy_sanity_check) { - // Simple test that copies data into a and reads from a - std::vector sizes = {4, 4, 1}; - vTensor a = CREATE_FLOAT_BUFFER(sizes, /*allocate_memory = */ true); - - // Input data - std::vector data_in(a.gpu_numel()); - std::fill(data_in.begin(), data_in.end(), 2.524f); + EXPECT_TRUE(kernel.kernel_name == "test_shader"); +} - // Fill input tensor - fill_vtensor(a, data_in); +TEST_F(VulkanComputeAPITest, update_params_between_submit) { + api::context()->set_cmd(/*reusable = */ true); + std::vector sizes = {4, 4, 2}; + vTensor a = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); - // Read back data - std::vector data_out(a.gpu_numel()); - extract_vtensor(a, data_out); + std::stringstream kernel_name; + kernel_name << "fill_texture__test"; + apply_dtype_suffix(kernel_name, a); - // Check output - for (const auto& d : data_out) { - EXPECT_TRUE(d == 2.524f); - } -} + struct Params final { + api::utils::ivec3 size; + int32_t fill; + api::utils::vec4 values; + }; -TEST_F(VulkanComputeAPITest, buffer_deferred_allocation_test) { - // Same as buffer_copy_sanity_check, but defers memory allocation + Params block{ + {2, 4, 1}, + 0, + {5.0, 5.0, 5.0, 5.0}, + }; - std::vector sizes = {4, 4, 1}; - vTensor a = CREATE_FLOAT_BUFFER(sizes, /*allocate_memory = */ false); + api::UniformParamsBuffer params(api::context(), block); - // For buffer storage, a small uniform buffer is allocated containing size and - // stride data, which is why the check is for 1 allocation below. - EXPECT_TRUE(get_vma_allocation_count() == 1); + { + api::PipelineBarrier pipeline_barrier{}; + api::context()->submit_compute_job( + VK_KERNEL_FROM_STR(kernel_name.str()), + pipeline_barrier, + {4, 4, 4}, + {4, 4, 4}, + VK_NULL_HANDLE, + a.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + params.buffer()); + } - // Input data - std::vector data_in(a.gpu_numel()); - std::fill(data_in.begin(), data_in.end(), 1.234f); + api::StorageBuffer staging_buffer(api::context(), api::kFloat, a.gpu_numel()); + record_image_to_nchw_op(api::context(), a, staging_buffer.buffer()); - // Allocate memory at the last possible opportunity - api::MemoryAllocation a_mem = allocate_memory_for(a); - a.buffer().bind_allocation(a_mem); + submit_to_gpu(); + check_staging_buffer(staging_buffer, 5.0f); - EXPECT_TRUE(get_vma_allocation_count() == 2); - - // Fill input tensor - fill_vtensor(a, data_in); + Params new_block{ + {2, 4, 1}, + 0, + {4.0, 4.0, 4.0, 4.0}, + }; - // Read back data - std::vector data_out(a.gpu_numel()); - extract_vtensor(a, data_out); + params.update(new_block); - // Check output - for (const auto& d : data_out) { - EXPECT_TRUE(d == 1.234f); - } + submit_to_gpu(); + check_staging_buffer(staging_buffer, 4.0f); } TEST_F(VulkanComputeAPITest, texture_add_sanity_check) { @@ -224,25 +110,15 @@ TEST_F(VulkanComputeAPITest, texture_add_sanity_check) { vTensor b = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); vTensor c = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); - // Input data - std::vector data_a(a.gpu_numel()); - std::fill(data_a.begin(), data_a.end(), 2.5f); - std::vector data_b(b.gpu_numel()); - std::fill(data_b.begin(), data_b.end(), 1.5f); - - // Add shader kernel - api::ShaderInfo kernel = arithmetic::get_shader(arithmetic::OpType::ADD); - // Fill input tensors - fill_vtensor(a, data_a); - fill_vtensor(b, data_b); + fill_vtensor(a, 2.5f); + fill_vtensor(b, 1.5f); // a + b -> c - arithmetic::record_op(api::context(), kernel, a, b, c, 1.0f); + record_binary_op(api::context(), "add", a, b, c); // Extract output tensor - std::vector data_out(c.gpu_numel()); - extract_vtensor(c, data_out); + std::vector data_out = extract_vtensor(c); // Check output for (const auto& d : data_out) { @@ -267,8 +143,6 @@ TEST_F(VulkanComputeAPITest, texture_deferred_allocation_test) { std::vector data_b(b.gpu_numel()); std::fill(data_b.begin(), data_b.end(), 1.5f); - api::ShaderInfo kernel = arithmetic::get_shader(arithmetic::OpType::ADD); - // Allocate memory at the last possible opportunity api::MemoryAllocation a_mem = allocate_memory_for(a); a.image().bind_allocation(a_mem); @@ -283,7 +157,7 @@ TEST_F(VulkanComputeAPITest, texture_deferred_allocation_test) { fill_vtensor(a, data_a); fill_vtensor(b, data_b); - arithmetic::record_op(api::context(), kernel, a, b, c, 1.0f); + record_binary_op(api::context(), "add", a, b, c); std::vector data_c(c.gpu_numel()); extract_vtensor(c, data_c); @@ -332,21 +206,18 @@ TEST_F(VulkanComputeAPITest, texture_resource_aliasing_test) { std::vector data_d(b.gpu_numel()); std::fill(data_d.begin(), data_d.end(), 1.0f); - // Get shader kernel for add - api::ShaderInfo kernel = arithmetic::get_shader(arithmetic::OpType::ADD); - // First, fill a and b with data fill_vtensor(a, data_a); fill_vtensor(b, data_b); // a + b -> c - arithmetic::record_op(api::context(), kernel, a, b, c, 1.0f); + record_binary_op(api::context(), "add", a, b, c); // Now d can be filled with data fill_vtensor(d, data_d); // c + d -> e - arithmetic::record_op(api::context(), kernel, c, d, e, 1.0f); + record_binary_op(api::context(), "add", c, d, e); // Extract data from e std::vector data_e(e.gpu_numel()); @@ -408,6 +279,103 @@ TEST_F(VulkanComputeAPITest, use_non_bound_textures_fails) { EXPECT_THROW(fill_vtensor(a, data_a), api::Error); } +TEST_F(VulkanComputeAPITest, tensor_reallocation_test) { + std::vector sizes = {4, 4, 1}; + vTensor a = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); + vTensor b = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); + vTensor c = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); + + execute_and_check_add(a, b, c, 3.0f, 5.0f); + + // Redo with new sizes + std::vector new_sizes = {4, 6, 3}; + a.reallocate(new_sizes); + b.reallocate(new_sizes); + c.reallocate(new_sizes); + + // Flush everything + api::context()->flush(); + + execute_and_check_add(a, b, c, 12.0f, 10.0f); +} + +TEST_F( + VulkanComputeAPITest, + tensor_reallocation_with_deferred_allocation_test) { + std::vector sizes = {8, 8, 8}; + vTensor a = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); + vTensor b = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); + vTensor c = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); + + api::MemoryAllocation a_mem = allocate_memory_for(a); + a.image().bind_allocation(a_mem); + api::MemoryAllocation b_mem = allocate_memory_for(b); + b.image().bind_allocation(b_mem); + api::MemoryAllocation c_mem = allocate_memory_for(c); + c.image().bind_allocation(c_mem); + + execute_and_check_add(a, b, c, 4.0f, 8.0f); + + std::vector> new_sizes_list = { + {4, 3, 5}, {4, 1, 7}, {8, 3, 2}, {8, 7, 2}}; + + for (auto& new_sizes : new_sizes_list) { + // Redo with new sizes + a.reallocate(new_sizes); + b.reallocate(new_sizes); + c.reallocate(new_sizes); + + // Flush everything + api::context()->flush(); + + a.image().bind_allocation(a_mem); + b.image().bind_allocation(b_mem); + c.image().bind_allocation(c_mem); + + execute_and_check_add( + a, b, c, float(new_sizes[1] + 4.5f), float(new_sizes[2] + 13.0f)); + } +} + +TEST_F(VulkanComputeAPITest, texture_virtual_resize) { + api::context()->set_cmd(/*reusable = */ true); + std::vector sizes = {8, 12, 12}; + vTensor a = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); + vTensor b = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); + vTensor c = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); + + DEFINE_STAGING_BUFFER_AND_RECORD_TO_GPU_FOR(a) + DEFINE_STAGING_BUFFER_AND_RECORD_TO_GPU_FOR(b) + + fill_staging(staging_buffer_a, 11.5f); + fill_staging(staging_buffer_b, 12.5f); + + record_binary_op(api::context(), "add", a, b, c); + + DEFINE_STAGING_BUFFER_AND_RECORD_FROM_GPU_FOR(c) + + submit_to_gpu(); + check_staging_buffer(staging_buffer_c, 24.0f); + + std::vector> new_sizes_list = { + {4, 2, 4}, {4, 3, 6}, {8, 12, 12}, {8, 1, 1}, {8, 11, 10}}; + + for (auto& new_sizes : new_sizes_list) { + a.virtual_resize(new_sizes); + b.virtual_resize(new_sizes); + c.virtual_resize(new_sizes); + + fill_staging(staging_buffer_a, float(new_sizes[1] + 1.5f), a.gpu_numel()); + fill_staging(staging_buffer_b, float(new_sizes[2] + 55.0f), b.gpu_numel()); + + submit_to_gpu(); + check_staging_buffer( + staging_buffer_c, + float(new_sizes[1] + new_sizes[2] + 56.5f), + c.gpu_numel()); + } +} + // // Compute Graph Tests // @@ -417,12 +385,66 @@ TEST_F(VulkanComputeAPITest, use_non_bound_textures_fails) { graph.get_val(name.value).toTensor().gpu_numel()); \ graph.copy_from_staging(name.staging, data_##name.data(), data_##name.size()); +TEST(VulkanComputeGraphTest, test_values_scalars) { + GraphConfig config; + ComputeGraph graph(config); + + ValueRef idx; + + idx = graph.add_scalar(4); + EXPECT_TRUE(graph.get_val(idx).toInt() == 4); + + idx = graph.add_scalar(5.5f); + EXPECT_TRUE(graph.get_val(idx).toDouble() == 5.5f); +} + +TEST(VulkanComputeGraphTest, test_values_scalar_list_inplace_constructed) { + GraphConfig config; + ComputeGraph graph(config); + + ValueRef idx = graph.add_scalar_list({1, 2, 3, 4}); + std::vector& arr = graph.get_val(idx).toIntList(); + EXPECT_TRUE(arr.size() == 4); + for (int i = 0; i < 4; i++) { + EXPECT_TRUE(arr[i] == i + 1); + } +} + +TEST(VulkanComputeGraphTest, test_values_scalar_list_outside_constructed) { + GraphConfig config; + ComputeGraph graph(config); + + ValueRef idx; + { + std::vector data = {5.0, 4.0, 3.0, 2.0, 1.0}; + idx = graph.add_scalar_list(std::move(data)); + } + std::vector& arr = graph.get_val(idx).toDoubleList(); + EXPECT_TRUE(arr.size() == 5); + for (int i = 0; i < 5; i++) { + EXPECT_TRUE(arr[i] == (5 - i)); + } +} + +TEST(VulkanComputeGraphTest, test_values_string) { + GraphConfig config; + ComputeGraph graph(config); + + ValueRef idx; + { + std::string data = "hello, world"; + idx = graph.add_string(std::move(data)); + } + std::string& stored = graph.get_val(idx).toString(); + EXPECT_TRUE(stored == "hello, world"); +} + TEST(VulkanComputeGraphTest, test_simple_graph) { - GraphConfig config = generate_graph_config(); + GraphConfig config; ComputeGraph graph(config); - std::vector size_big = {4, 4, 4}; - std::vector size_small = {4, 4, 1}; + std::vector size_big = {8, 64, 124}; + std::vector size_small = {8, 1, 124}; // Build graph @@ -431,10 +453,14 @@ TEST(VulkanComputeGraphTest, test_simple_graph) { IOValueRef out = {}; - out.value = add_arithmetic_node(graph, a.value, b.value, 1.0, VK_KERNEL(add)); + out.value = graph.add_tensor(size_big, api::kFloat); + + auto addFn = VK_GET_OP_FN("aten.add.Tensor"); + addFn(graph, {a.value, b.value, kDummyValueRef, out.value}); out.staging = graph.set_output_tensor(out.value); + graph.prepare(); graph.encode_execute(); // Run graph @@ -464,11 +490,11 @@ TEST(VulkanComputeGraphTest, test_simple_graph) { ValueRef name = graph.add_tensorref(sizes, api::kFloat, data_##name.data()); TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) { - GraphConfig config = generate_graph_config(); + GraphConfig config; ComputeGraph graph(config); - std::vector size_big = {4, 4, 4}; - std::vector size_small = {4, 4, 1}; + std::vector size_big = {8, 73, 62}; + std::vector size_small = {8, 73, 1}; CREATE_WEIGHT_TENSOR(w1, size_small, 3.5f); CREATE_WEIGHT_TENSOR(w2, size_small, 3.0f); @@ -477,13 +503,21 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) { IOValueRef a = graph.add_input_tensor(size_big, api::kFloat); - ValueRef c = add_arithmetic_node(graph, a.value, w1, 1.0, VK_KERNEL(add)); - ValueRef e = add_arithmetic_node(graph, c, w2, 1.0, VK_KERNEL(mul)); + ValueRef c = graph.add_tensor(size_big, api::kFloat); + ValueRef e = graph.add_tensor(size_big, api::kFloat); + + auto addFn = VK_GET_OP_FN("aten.add.Tensor"); + addFn(graph, {a.value, w1, kDummyValueRef, c}); + + auto mulFn = VK_GET_OP_FN("aten.mul.Tensor"); + mulFn(graph, {c, w2, e}); IOValueRef out = {}; out.value = e; out.staging = graph.set_output_tensor(out.value); + graph.prepare(); + graph.encode_prepack(); graph.prepack(); @@ -508,12 +542,12 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) { } } -TEST(VulkanComputeGraphTest, test_simple_shared_objects) { - GraphConfig config = generate_graph_config(); +TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { + GraphConfig config; ComputeGraph graph(config); - std::vector size_big = {4, 4, 4}; - std::vector size_small = {4, 4, 1}; + std::vector size_big = {12, 64, 64}; + std::vector size_small = {12, 64, 64}; // Build graph @@ -526,59 +560,70 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) { api::kFloat, /*shared_object_idx = */ 4); - // Allocation count will be 2: - // 1 staging buffer for each input tensor - EXPECT_TRUE(get_vma_allocation_count() == 2); + // Allocation count will be 6: + // 4: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for each staging shader + // 2: staging buffer for each input tensor + EXPECT_TRUE(get_vma_allocation_count() == 6); - ValueRef c = add_arithmetic_node( - graph, - a.value, - b.value, - 1.0, - VK_KERNEL(add), + ValueRef c = graph.add_tensor( + size_big, + api::kFloat, /*shared_object_idx = */ 6); + auto addFn = VK_GET_OP_FN("aten.add.Tensor"); + addFn(graph, {a.value, b.value, kDummyValueRef, c}); + IOValueRef d = graph.add_input_tensor( size_small, api::kFloat, /*shared_object_idx = */ 2); - // Allocation count will be 4, two are new: - // 1 uniform buffer for arithmetic shader params - // 1 staging buffer for the input tensor - EXPECT_TRUE(get_vma_allocation_count() == 4); - - ValueRef e = add_arithmetic_node( - graph, - c, - d.value, - 1.0, - VK_KERNEL(mul), + // Allocation count will be 11, 5 are new: + // 2: out.gpu_sizes_ubo(), alpha UBO for arithmetic shader + // 2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() uniform buffer for staging shader + // 1: staging buffer for the input tensor + EXPECT_TRUE(get_vma_allocation_count() == 11); + + ValueRef e = graph.add_tensor( + size_big, + api::kFloat, /*shared_object_idx = */ 4); + auto mulFn = VK_GET_OP_FN("aten.mul.Tensor"); + mulFn(graph, {c, d.value, e}); + IOValueRef out = {}; out.value = e; out.staging = graph.set_output_tensor(out.value); - // Allocation count will be 6, three are new: - // 1 uniform buffer for arithmetic shader params + // Allocation count will be 15, 4 are new: + // 1: alpha UBO for arithmetic shader + // 2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for staging shader // 1 staging buffer for the input tensor - EXPECT_TRUE(get_vma_allocation_count() == 6); + EXPECT_TRUE(get_vma_allocation_count() == 15); + graph.prepare(); graph.encode_execute(); - // Allocation count will be 13: - // 4 staging buffers for each I/O tensor - // 6 uniform buffers to store params for each shader dispatch - // 3 shared objects to back tensor memory - EXPECT_TRUE(get_vma_allocation_count() == 13); + // Allocation count will be 18, 3 are new: + // 3: shared memory allocations for tensors + EXPECT_TRUE(get_vma_allocation_count() == 18); // Run graph - for (float i = 4.0f; i < 30.0f; i += 7.0f) { - float val_a = i + 2.0f; - float val_b = i + 1.5f; - float val_d = i + 1.0f; + std::vector> new_sizes_list = { + {8, 44, 34}, {4, 13, 56}, {8, 12, 64}, {12, 55, 33}, {4, 54, 10}}; + + for (auto& new_sizes : new_sizes_list) { + graph.get_val(a.value).toTensor().virtual_resize(new_sizes); + graph.get_val(b.value).toTensor().virtual_resize(new_sizes); + graph.get_val(c).toTensor().virtual_resize(new_sizes); + graph.get_val(d.value).toTensor().virtual_resize(new_sizes); + graph.get_val(e).toTensor().virtual_resize(new_sizes); + + float val_a = new_sizes[1] + 4.0f; + float val_b = new_sizes[2] + 1.5f; + float val_d = new_sizes[0] + 2.0f; float val_out = (val_a + val_b) * val_d; fill_vtensor(graph, a, val_a); @@ -595,4 +640,91 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) { EXPECT_TRUE(val == val_out); } } + + std::vector> new_sizes_list_2 = { + {8, 44, 34}, {4, 13, 56}, {8, 12, 64}, {12, 55, 33}, {4, 54, 10}}; + + for (auto& new_sizes : new_sizes_list_2) { + graph.resize_input(0, new_sizes); + graph.resize_input(1, new_sizes); + graph.resize_input(2, new_sizes); + graph.propagate_resize(); + + // Check output shape + EXPECT_TRUE(graph.get_val(out.value).toTensor().sizes() == new_sizes); + + float val_a = new_sizes[1] + 6.0f; + float val_b = new_sizes[2] + 2.5f; + float val_d = new_sizes[0] + 4.0f; + float val_out = (val_a + val_b) * val_d; + + fill_vtensor(graph, a, val_a); + fill_vtensor(graph, b, val_b); + fill_vtensor(graph, d, val_d); + + // Execute graph + graph.execute(); + + EXTRACT_TENSOR(out); + + // Sanity check that the values are correct + for (const auto& val : data_out) { + ASSERT_TRUE(val == val_out); + } + } +} + +TEST(VulkanComputeGraphTest, test_large_graph) { + GraphConfig config; + ComputeGraph graph(config); + + int64_t input_w = 256; + int64_t input_h = 256; + int64_t input_c = 8; + + std::vector size_big = {input_c, input_h, input_w}; + std::vector size_small = {input_c, input_h, 1}; + + // Build graph + + IOValueRef a = graph.add_input_tensor(size_big, api::kFloat, 2); + IOValueRef b = graph.add_input_tensor(size_small, api::kFloat, 4); + + ValueRef c = graph.add_tensor(size_big, api::kFloat, 6); + + auto addFn = VK_GET_OP_FN("aten.add.Tensor"); + addFn(graph, {a.value, b.value, kDummyValueRef, c}); + + int n = 100; + + for (int i = 0; i < n; i++) { + addFn(graph, {c, b.value, kDummyValueRef, a.value}); + + addFn(graph, {a.value, b.value, kDummyValueRef, c}); + } + + IOValueRef out = {}; + out.value = c; + out.staging = graph.set_output_tensor(out.value); + + graph.prepare(); + graph.encode_execute(); + + for (int i = 0; i < 10; i++) { + float val_a = 1.0f; + float val_b = 2.0f; + + float val_e = val_a + val_b * (2 * n + 1); + + fill_vtensor(graph, a, val_a); + fill_vtensor(graph, b, val_b); + + graph.execute(); + + EXTRACT_TENSOR(out); + + for (const auto& val : data_out) { + EXPECT_TRUE(val == val_e); + } + } } diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 293d114e8d..91a85f15a1 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -6,8 +6,6 @@ from typing import final, List -import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema - from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder from executorch.backends.vulkan.serialization.vulkan_graph_serialize import ( serialize_vulkan_graph, @@ -22,21 +20,15 @@ from executorch.exir.passes import MemoryPlanningPass, SpecPropPass +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass + from executorch.exir.program._program import _copy_module -from torch import dtype, float32 DEFAULT_DEBUG_HANDLE = 65535 @final class VulkanBackend(BackendDetails): - @staticmethod - def get_vk_datatype(torch_dtype: dtype) -> vk_graph_schema.VkDataType: - if torch_dtype == float32: - return vk_graph_schema.VkDataType.fp32 - else: - raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") - @classmethod # pyre-ignore def preprocess( # noqa: C901 @@ -46,6 +38,7 @@ def preprocess( # noqa: C901 ) -> PreprocessResult: passes = [ SpecPropPass(), + ConstraintBasedSymShapeEvalPass(), MemoryPlanningPass("greedy"), ] diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index ef593e0609..c184f60f08 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -149,8 +149,8 @@ Error defineTensor( ValuePtr value, GraphPtr flatbuffer_graph, const uint8_t* constant_data_ptr, - XNNExecutor* executor, - MemoryAllocator* runtime_allocator) { + std::vector& input_ids, + std::vector& output_ids) { const fb_xnnpack::XNNTensorValue* tensor_value = nullptr; const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr; @@ -272,7 +272,6 @@ Error defineTensor( /*external_id=*/tensor_value->external_id(), /*flags=*/tensor_value->flags(), /*id_out=*/&float_id); - executor->addDynamicQinput(float_id); // Define dynamic conversion from float to qdint8 status = xnn_define_convert( @@ -391,10 +390,13 @@ Error defineTensor( // map serialized id to newly generated id remapped_ids.emplace(std::make_pair(tensor_value->id_out(), id)); - // Append this external id to the arg list for execute(*args) to extract from - // as args[external_id] - if (tensor_value->external_id() != XNN_INVALID_VALUE_ID) { - executor->append_arg(tensor_value->external_id()); + + // Add external ids to either list of input or output ids + if (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_INPUT) { + input_ids.push_back(tensor_value->external_id()); + } + if (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_OUTPUT) { + output_ids.push_back(tensor_value->external_id()); } return Error::Ok; @@ -1594,6 +1596,9 @@ __ET_NODISCARD Error XNNCompiler::compileModel( // Invalid ids do not need to be remapped remapped_ids.emplace(XNN_INVALID_VALUE_ID, XNN_INVALID_VALUE_ID); + // External Ids for inputs and outputs + std::vector input_ids; + std::vector output_ids; Error err = Error::Ok; for (auto value : *flatbuffer_graph->xvalues()) { err = defineTensor( @@ -1602,8 +1607,8 @@ __ET_NODISCARD Error XNNCompiler::compileModel( value, flatbuffer_graph, constant_data, - executor, - runtime_allocator); + input_ids, + output_ids); if (err != Error::Ok) { return err; @@ -1635,47 +1640,10 @@ __ET_NODISCARD Error XNNCompiler::compileModel( "XNN Runtime creation failed with code: %s", xnn_status_to_string(status)); - executor->initialize(runtime_ptr); // NOLINT: runtime_ptr is non-null as - // error is checked above. - - // HACK FOR FC/BC this is only to support old dq_datatype - if (executor->qinputs_.size() > 0) { - // qinputs_ is only set when using the old dq linear path. At which point - // We need to overide the input_ids_ This workse based off the assumption - // old dqlinear path will be single node single input delegate - for (uint32_t id : executor->qinputs_) { - executor->input_ids_.emplace_back(id); - } - } else { - for (auto old_id : *flatbuffer_graph->input_ids()) { - executor->input_ids_.emplace_back(remapped_ids.at(old_id)); - } - } - // External ids need to be in order for wiring with args - std::sort(executor->input_ids_.begin(), executor->input_ids_.end()); - - for (auto old_id : *flatbuffer_graph->output_ids()) { - executor->output_ids_.emplace_back(remapped_ids.at(old_id)); - } - // External ids need to be in order for wiring with args - std::sort(executor->output_ids_.begin(), executor->output_ids_.end()); - - if (!executor->qinputs_.empty() && flatbuffer_graph->xnodes()->size() > 0 && - flatbuffer_graph->xnodes()->Get(0)->xnode_union_type() == - fb_xnnpack::XNodeUnion::XNNFullyConnected) { -#ifdef ENABLE_DYNAMIC_QUANTIZATION - // This delegate is for DQLinear which supports dynamic input shapes - if (executor->getNumInputs() < 1 || executor->getNumOutputs() != 1) { - ET_LOG( - Error, - "DQLinear should have at least one input and exactly one output"); - return Error::NotSupported; - } -#else - ET_LOG(Error, "DQ Linear is not supported"); - return Error::NotSupported; -#endif - } + err = executor->initialize( // NOLINT: runtime_ptr is non-null + runtime_ptr, + std::move(input_ids), + std::move(output_ids)); return err; }; diff --git a/backends/xnnpack/runtime/XNNExecutor.cpp b/backends/xnnpack/runtime/XNNExecutor.cpp index 592d9574d1..a5cd3b6737 100644 --- a/backends/xnnpack/runtime/XNNExecutor.cpp +++ b/backends/xnnpack/runtime/XNNExecutor.cpp @@ -13,34 +13,208 @@ namespace executor { namespace xnnpack { namespace delegate { -Error XNNExecutor::set_external_input( - uint32_t id, - Tensor* input, - struct XNNShape* shape) { - // TODO(T165403530): Test ensure accuracy for int64 --> float32 conversion - if (input->scalar_type() == ScalarType::Long) { - // Input data type is int64. However, XNNPACK doesn't support - // int64. This means that the data needs to be casted to float - // In order for XNNPACK to properly use it. - const int64_t* data_64 = input->const_data_ptr(); - float* data_f32 = input->mutable_data_ptr(); - for (int j = 0; j < input->numel(); j++) { - data_f32[j] = data_64[j]; +using Tensor = exec_aten::Tensor; +using ScalarType = exec_aten::ScalarType; +using SizesType = exec_aten::SizesType; + +/** + * Initializes the XNNExecutor with the runtime and given number of + * inputs/outputs externals_ is resized to the total number of inputs and + * outputs + */ +__ET_NODISCARD Error XNNExecutor::initialize( + xnn_runtime_t runtime, + std::vector&& input_ids, + std::vector&& output_ids) { + runtime_ = std::unique_ptr( + runtime, xnn_delete_runtime); + + auto error = profiler_.initialize(runtime); + if (error != Error::Ok) { + ET_LOG( + Error, + "Failed to start profiling: %u.", + static_cast(error)); + } + + // Initialize the external values for inputs and outputs + // mapping the executorch arg idx to external IDs + input_ids_ = std::move(input_ids); + std::sort(input_ids_.begin(), input_ids_.end()); + + output_ids_ = std::move(output_ids); + std::sort(output_ids_.begin(), output_ids_.end()); + + externals_.resize(input_ids_.size() + output_ids_.size()); + + return Error::Ok; +} + +/** + * Prepares the args for XNNPACK Runtime. + * + * Creates an array of xnn_externals_values from the EValues passed in. + * Reshapes all the external input tensors, in case any input shapes have + * changed. The reshapes the entire runtime, propagating shape information + * through the runtime. + * + * Note: the external ids given to the external tensors in the XNNPACK + * runtime correspond to their index in the list of arg passed into + * delegate->execute() + */ +__ET_NODISCARD Error XNNExecutor::prepare_args(EValue** args) { + // Create xnn_externals_value from evalue args + xnn_status status; + for (uint32_t i = 0; i < externals_.size(); ++i) { + if (i < input_ids_.size()) { + externals_[i].id = input_ids_[i]; + } else { + externals_[i].id = output_ids_[i - input_ids_.size()]; + } + uint32_t ext_id = externals_[i].id; + + ET_CHECK_OR_RETURN_ERROR( + args[ext_id]->isTensor(), + InvalidArgument, + "Expected argument to delegate at index %u to be a Tensor, but got %" PRIu32, + i, + static_cast(args[ext_id]->tag)); + + Tensor* tensor = &args[ext_id]->toTensor(); + externals_[i].data = tensor->mutable_data_ptr(); + + // Reshape runtime inputs + if (i < input_ids_.size()) { + size_t num_dims = tensor->dim(); + size_t dims[XNN_MAX_TENSOR_DIMS]; + for (int d = 0; d < num_dims; ++d) { + dims[d] = tensor->size(d); + } + status = + xnn_reshape_external_value(runtime_.get(), ext_id, num_dims, dims); + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Internal Error: Reshape Input Tensor Failed with code: %s", + xnn_status_to_string(status)); } } - if (input->dim() != shape->num_dims) { - ET_LOG(Error, "Input dim mismatch between tensor and shape struct"); + // // Propagate Input Shape and Memory Plan for increased allocation + status = xnn_reshape_runtime(runtime_.get()); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Internal Error: Propagating input shapes failed with code: %s", + xnn_status_to_string(status)); + + return Error::Ok; +} + +/** + * Runs the XNNPACK Runtime. + * + * We first setup the runtime by feeding the externals_ to runtime setup. + * After which we then execute the runtime through invoke_runtime. + */ +__ET_NODISCARD Error XNNExecutor::forward(BackendExecutionContext& context) { + ET_CHECK_OR_RETURN_ERROR( + runtime_ != nullptr, + Internal, + "XNNPACK Delegate did not compile correctly"); + + xnn_status status = xnn_setup_runtime_v2( + runtime_.get(), externals_.size(), externals_.data()); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Internal Error: Setting up the runtime failed with code: %s", + xnn_status_to_string(status)); + + auto error = profiler_.start(context.event_tracer()); + if (error != Error::Ok) { + ET_LOG( + Error, + "Failed to start profiling: %u.", + static_cast(error)); + } + + status = xnn_invoke_runtime(runtime_.get()); + + error = profiler_.end(); + if (error != Error::Ok) { + ET_LOG( + Error, + "Failed to end profiling: %u.", + static_cast(error)); + } + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "XNN Runtime invoke failed with code: %s", + xnn_status_to_string(status)); + + return Error::Ok; +} + +/** + * Prepares the outputs for ExecuTorch + * + * Resizes the output tensors based on the output shapes returned by + * the xnnpack runtime. + * + * Note: For arg_max pooling, we recast the output index tensor. Since + * XNNPACK gives the index tensor to us as int32, we need to convert it + * back to int64 for ExecuTorch. + */ +__ET_NODISCARD Error XNNExecutor::resize_outputs(EValue** args) const { + size_t output_idx_start = input_ids_.size(); + for (size_t i = output_idx_start; i < externals_.size(); ++i) { + uint32_t ext_id = externals_[i].id; + Tensor* out_tensor = &args[ext_id]->toTensor(); + + size_t num_dim; + size_t dims[XNN_MAX_TENSOR_DIMS]; + + // Fetch the updated output shapes from xnnpack runtime + xnn_status status = + xnn_get_external_value_shape(runtime_.get(), ext_id, &num_dim, dims); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Internal Error: Failed to retrieve graph output shapes"); + + // Convert new output shape into SizesType + SizesType expected_output_size[kTensorDimensionLimit]; + for (size_t d = 0; d < num_dim; ++d) { + expected_output_size[d] = static_cast(dims[d]); + } + + exec_aten::ArrayRef output_size{ + expected_output_size, static_cast(num_dim)}; + + ET_LOG(Debug, "Resizing output tensor to a new shape"); + Error err = resize_tensor(*out_tensor, output_size); + if (err != Error::Ok) { + ET_LOG(Error, "Failed to resize output tensor for XNNExecutor"); + return err; + } + + // Output datatype is int64. However, XNNPACK doesn't support + // int64. This means that the data was put into this tensor + // by XNNPACK as int32 and needs to be copied to int64 form + if (out_tensor->scalar_type() == ScalarType::Long) { + int64_t* data_64 = out_tensor->mutable_data_ptr(); + const int32_t* data_32 = out_tensor->const_data_ptr(); + for (size_t j = out_tensor->numel() - 1; j >= 0; --j) { + data_64[j] = data_32[j]; + } + } } -#ifdef ENABLE_DYNAMIC_QUANTIZATION - externals_.emplace_back(xnn_external_value{ - id, - input->mutable_data_ptr(), - static_cast(shape->num_dims), - shape->dim}); -#else - externals_.emplace_back(xnn_external_value{id, input->mutable_data_ptr()}); -#endif return Error::Ok; } diff --git a/backends/xnnpack/runtime/XNNExecutor.h b/backends/xnnpack/runtime/XNNExecutor.h index d4a4677392..b13951bdd1 100644 --- a/backends/xnnpack/runtime/XNNExecutor.h +++ b/backends/xnnpack/runtime/XNNExecutor.h @@ -24,11 +24,6 @@ namespace executor { namespace xnnpack { namespace delegate { -struct XNNShape { - size_t num_dims; - size_t dim[XNN_MAX_TENSOR_DIMS]; -}; - class XNNExecutor { private: std::unique_ptr runtime_{ @@ -38,42 +33,11 @@ class XNNExecutor { profiling::XNNProfiler profiler_; std::vector input_ids_; std::vector output_ids_; - std::vector external_id_args_; - bool is_sorted_args_list_ = false; std::vector externals_; - std::vector qinputs_; - - Error set_external_input(uint32_t id, Tensor* input, struct XNNShape* shape); public: XNNExecutor() = default; - inline void append_arg(uint32_t id) { - external_id_args_.push_back(id); - // Insertion order is not guaranteed here. - is_sorted_args_list_ = false; - } - - inline size_t get_args_size() { - return external_id_args_.size(); - } - - inline uint32_t get_arg_index(size_t i) { - if (!is_sorted_args_list_) { - // Could have been inserted out of order. - sort(external_id_args_.begin(), external_id_args_.end()); - is_sorted_args_list_ = true; - } - - size_t ret = external_id_args_.size(); - ET_CHECK_MSG( - i < ret, - "Invalid arg index, requested: %zu, total args consumed by xnnpack: %zu\n", - i, - ret); - return external_id_args_[i]; - } - inline size_t getNumInputs() { return input_ids_.size(); } @@ -82,147 +46,35 @@ class XNNExecutor { return output_ids_.size(); } - inline void initialize(xnn_runtime_t runtime) { - runtime_ = std::unique_ptr( - runtime, xnn_delete_runtime); - - auto error = profiler_.initialize(runtime); - ET_CHECK_MSG( - error == Error::Ok, - "Failed to initialize profiler with error: %d", - static_cast(error)); - } - - inline void addDynamicQinput(uint32_t id) { - qinputs_.emplace_back(id); - } - - __ET_NODISCARD Error set_inputs( - std::vector& inputs, - std::vector& outputs, - std::vector& input_shapes, - std::vector& output_shapes) { - externals_.clear(); - - ET_CHECK_OR_RETURN_ERROR( - inputs.size() == input_ids_.size(), - InvalidArgument, - "Expected %zu inputs but given %zu", - input_ids_.size(), - inputs.size()); - - for (int i = 0; i < inputs.size(); i++) { - auto err = set_external_input(input_ids_[i], inputs[i], &input_shapes[i]); - ET_CHECK_OR_RETURN_ERROR( - err == Error::Ok, Internal, "Failed to set_external_input"); - } - ET_CHECK_OR_RETURN_ERROR( - outputs.size() == output_ids_.size(), - InvalidArgument, - "Expected %zu outputs gut given %zu", - output_ids_.size(), - outputs.size()); - - for (int i = 0; i < outputs.size(); i++) { -#ifdef ENABLE_DYNAMIC_QUANTIZATION - externals_.emplace_back(xnn_external_value{ - output_ids_[i], - outputs[i]->mutable_data_ptr(), - static_cast(output_shapes[i].num_dims), - output_shapes[i].dim}); -#else - externals_.emplace_back(xnn_external_value{ - output_ids_[i], outputs[i]->mutable_data_ptr()}); -#endif - } - - return Error::Ok; - } - - __ET_NODISCARD Error forward(BackendExecutionContext& context) { - ET_CHECK_OR_RETURN_ERROR( - runtime_ != nullptr, - Internal, - "XNNPACK Delegate did not compile correctly"); - xnn_status status = - xnn_setup_runtime(runtime_.get(), externals_.size(), externals_.data()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "XNN Runtime setup failed with code: %s", - xnn_status_to_string(status)); - - auto error = profiler_.start(context.event_tracer()); - if (error != Error::Ok) { - ET_LOG( - Error, - "Failed to start profiling: %u.", - static_cast(error)); - } - - status = xnn_invoke_runtime(runtime_.get()); - - error = profiler_.end(); - if (error != Error::Ok) { - ET_LOG( - Error, - "Failed to end profiling: %u.", - static_cast(error)); - } - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "XNN Runtime invoke failed with code: %s", - xnn_status_to_string(status)); - - return Error::Ok; - } - - /** Resize output tensor to support dynamic input shapes */ - __ET_NODISCARD Error resizeOutput( - exec_aten::Tensor* output_tensor, - struct XNNShape* output_shape) const { - const size_t n_dim = output_tensor->dim(); - - // Rank can't change - if (n_dim != output_shape->num_dims) { - ET_LOG( - Error, - "Found output shape with a different number of dimensions than the output tensor. Expected: %zu, Actual: %zu", - n_dim, - output_shape->num_dims); - return Error::NotSupported; - } - - // Early exit? - bool same_shape = true; - for (size_t i = 0; (i < n_dim) && same_shape; i++) { - same_shape = (output_tensor->size(i) == output_shape->dim[i]); - } - if (same_shape) { - return Error::Ok; - } - - exec_aten::SizesType expected_output_size[kTensorDimensionLimit]; - for (size_t i = 0; i < n_dim; i++) { - expected_output_size[i] = - static_cast(output_shape->dim[i]); - } - - exec_aten::ArrayRef output_size{ - expected_output_size, static_cast(output_tensor->dim())}; - - // Ok to dereference pointer here because resize_tensor takes in a tensor - // and not a tensor& - ET_LOG(Debug, "Resizing output tensor to a new shape"); - Error err = resize_tensor(*output_tensor, output_size); - if (err != Error::Ok) { - ET_LOG(Error, "Failed to resize output tensor for XNNExecutor"); - } - return err; - } + /** + * Initialize the XNNExecutor with a given runtime and input/output ids. + * The input/output ids are expected to be sorted in order of their + * flatbuffer id_outs + */ + __ET_NODISCARD Error initialize( + xnn_runtime_t runtime, + std::vector&& input_ids, + std::vector&& output_ids); + + /** + * Prepares the arguments for runtime graph execution. + * args is an array of EValues that will be passed into the runtime. + * input shapes will be propagated through the runtime, and perform + * any additional memory planning as needed + */ + __ET_NODISCARD Error prepare_args(EValue** args); + + /** + * Executes the graph using the args prepared at prepare_args(). + */ + __ET_NODISCARD Error forward(BackendExecutionContext& context); + + /** + * Prepares the outputs to be returned by the delegate + * + * Performs any post processing of outputs like tensor resizing + */ + __ET_NODISCARD Error resize_outputs(EValue** args) const; friend class XNNCompiler; }; diff --git a/backends/xnnpack/runtime/XNNPACKBackend.cpp b/backends/xnnpack/runtime/XNNPACKBackend.cpp index 7fd921ad9d..33d7ebebfe 100644 --- a/backends/xnnpack/runtime/XNNPACKBackend.cpp +++ b/backends/xnnpack/runtime/XNNPACKBackend.cpp @@ -60,84 +60,20 @@ class XnnpackBackend final : public PyTorchBackendInterface { EValue** args) const override { auto executor = static_cast(handle); - // TODO merge these two in a single struct? - std::vector input_pointers; - std::vector output_pointers; - std::vector input_shapes; - std::vector output_shapes; - - ET_CHECK_OR_RETURN_ERROR( - executor->get_args_size() == - executor->getNumInputs() + executor->getNumOutputs(), - Internal, - "External id and expected delegate args mismatch"); - - // Intialize XNNShapes for both inputs and outputs. - // That will allow us to gradually build up shape inference support for - // xnnpack ops without breaking existing models when we always try to resize - // delegate output tensor(s). - input_shapes.resize(executor->getNumInputs()); - output_shapes.resize(executor->getNumOutputs()); - - for (int i = 0; i < executor->get_args_size(); i++) { - int index = executor->get_arg_index(i); - - if (!args[index]->isTensor()) { - ET_LOG(Error, "Expected argument to be a tensor"); - } - - Tensor* tensor = &args[index]->toTensor(); - size_t num_dims = tensor->dim(); - struct xnnpack::delegate::XNNShape* shape = nullptr; - - if (i < executor->getNumInputs()) { - input_pointers.push_back(tensor); - shape = &input_shapes[i]; - } else { - output_pointers.push_back(tensor); - shape = &output_shapes[i - executor->getNumInputs()]; - } - - shape->num_dims = num_dims; - for (int d = 0; d < num_dims; ++d) { - shape->dim[d] = tensor->size(d); - } - } - - Error err = executor->set_inputs( - input_pointers, output_pointers, input_shapes, output_shapes); - + // Prepare Inputs/Outputs and Propagate Input Shapes + Error err = executor->prepare_args(args); if (err != Error::Ok) { return err; } err = executor->forward(context); - // Resize output tensors - should be a no-op for static models - for (int i = 0; i < executor->getNumOutputs(); i++) { - err = executor->resizeOutput(output_pointers[i], &output_shapes[i]); - if (err != Error::Ok) { - return err; - } + if (err != Error::Ok) { + return err; } - for (int i = executor->getNumInputs(); - i < executor->getNumInputs() + executor->getNumOutputs(); - i++) { - if (args[i]->isTensor()) { - exec_aten::Tensor output_tensor = args[i]->toTensor(); - if (output_tensor.scalar_type() == ScalarType::Long) { - // Output datatype is int64. However, XNNPACK doesn't support - // int64. This means that the data was put into this tensor - // by XNNPACK as int32 and needs to be copied to int64 form - int64_t* data_64 = output_tensor.mutable_data_ptr(); - const int32_t* data_32 = output_tensor.const_data_ptr(); - for (int j = output_tensor.numel() - 1; j >= 0; j--) { - data_64[j] = data_32[j]; - } - } - } - } + // Resize outputs and recast pointers if necessary + err = executor->resize_outputs(args); return err; } diff --git a/backends/xnnpack/targets.bzl b/backends/xnnpack/targets.bzl index 92043b5fea..915549155d 100644 --- a/backends/xnnpack/targets.bzl +++ b/backends/xnnpack/targets.bzl @@ -37,7 +37,7 @@ def define_common_targets(): ], preprocessor_flags = [ # "-DENABLE_XNNPACK_PROFILING", - ] + ([] if runtime.is_oss else ["-DENABLE_DYNAMIC_QUANTIZATION"]), + ], deps = [ third_party_dep("XNNPACK"), "//executorch/runtime/backend:interface", diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index ae8fcc23db..ec03fa2529 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -501,12 +501,25 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): assert len(model_output) == len(ref_output) for i in range(len(model_output)): + model = model_output[i] + ref = ref_output[i] assert torch.allclose( - model_output[i], - ref_output[i], + model, + ref, atol=atol, rtol=rtol, - ), f" Output {i} does not match reference output. Max difference: {torch.max(torch.abs(model_output[i] - ref_output[i]))}" + ), ( + f"Output {i} does not match reference output.\n" + f"\tGiven atol: {atol}, rtol: {rtol}.\n" + f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n" + f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}.\n" + f"\t-- Model vs. Reference --\n" + f"\t Numel: {model.numel()}, {ref.numel()}\n" + f"\tMedian: {model.median()}, {ref.median()}\n" + f"\t Mean: {model.mean()}, {ref.mean()}\n" + f"\t Max: {model.max()}, {ref.max()}\n" + f"\t Min: {model.min()}, {ref.min()}\n" + ) def compare_outputs(self, atol=1e-03, rtol=1e-03, qtol=0): """ diff --git a/backends/xnnpack/third-party/XNNPACK b/backends/xnnpack/third-party/XNNPACK index d9cce341f8..fcbf55af6c 160000 --- a/backends/xnnpack/third-party/XNNPACK +++ b/backends/xnnpack/third-party/XNNPACK @@ -1 +1 @@ -Subproject commit d9cce341f86a207da9d851d05e26cd50b508b73c +Subproject commit fcbf55af6cf28a4627bcd1f703ab7ad843f0f3a2 diff --git a/backends/xnnpack/third-party/generate-xnnpack-wrappers.py b/backends/xnnpack/third-party/generate-xnnpack-wrappers.py index d378f39489..bda7952717 100644 --- a/backends/xnnpack/third-party/generate-xnnpack-wrappers.py +++ b/backends/xnnpack/third-party/generate-xnnpack-wrappers.py @@ -36,6 +36,7 @@ "PROD_AVX512F_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", "PROD_AVX512SKX_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", "PROD_AVX512VBMI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", + "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", "PROD_AVX512VNNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", "PROD_RVV_MICROKERNEL_SRCS": "defined(__riscv) || defined(__riscv__)", "PROD_AVXVNNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)", @@ -80,6 +81,7 @@ "PROD_AVX512F_MICROKERNEL_SRCS", "PROD_AVX512SKX_MICROKERNEL_SRCS", "PROD_AVX512VBMI_MICROKERNEL_SRCS", + "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS", "PROD_AVX512VNNI_MICROKERNEL_SRCS", "PROD_RVV_MICROKERNEL_SRCS", "PROD_AVXVNNI_MICROKERNEL_SRCS", diff --git a/backends/xnnpack/third-party/xnnpack.buck.bzl b/backends/xnnpack/third-party/xnnpack.buck.bzl index 4f3dafb5cb..a1add44664 100644 --- a/backends/xnnpack/third-party/xnnpack.buck.bzl +++ b/backends/xnnpack/third-party/xnnpack.buck.bzl @@ -17,6 +17,7 @@ load( "PROD_AVX512F_MICROKERNEL_SRCS", "PROD_AVX512SKX_MICROKERNEL_SRCS", "PROD_AVX512VBMI_MICROKERNEL_SRCS", + "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS", "PROD_AVX512VNNI_MICROKERNEL_SRCS", "PROD_AVXVNNI_MICROKERNEL_SRCS", "PROD_AVX_MICROKERNEL_SRCS", @@ -1125,9 +1126,45 @@ def define_xnnpack(): ], ) + AVX512VNNIGFNI_COMPILER_FLAGS = AVX512VNNI_COMPILER_FLAGS + [ + "-mgfni", + ] + + # @lint-ignore BUCKLINT: native and fb_native are explicitly forbidden in fbcode. + native.cxx_library( + name = "ukernels_avx512vnnigfni", + srcs = select({ + "DEFAULT": PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS, + "ovr_config//cpu:arm32": DEFAULT_DUMMY_SRC, + "ovr_config//cpu:arm64": DEFAULT_DUMMY_SRC, + }), + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ("XNNPACK/src", "**/*.c"), + ]), + header_namespace = "", + compiler_flags = [ + "-O2", + "-Wno-error=missing-braces", # required since the SGX toolchain does not have this by default + ] + select({ + "DEFAULT": AVX512VNNIGFNI_COMPILER_FLAGS, + "ovr_config//cpu:arm32": [], + "ovr_config//cpu:arm64": [], + }), + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + exported_deps = [ + ":interface", + ], + ) + AVXVNNI_COMPILER_FLAGS = [ "-mavx2", "-mavxvnni", + "-mf16c", + "-mfma", ] # @lint-ignore BUCKLINT: native and fb_native are explicitly forbidden in fbcode. @@ -1180,6 +1217,7 @@ def define_xnnpack(): ":ukernels_ssse3", ":ukernels_xop", ":ukernels_avx512vbmi", + ":ukernels_avx512vnnigfni", ":ukernels_avx512vnni", ":ukernels_avxvnni", ] diff --git a/backends/xnnpack/third-party/xnnpack_src_defs.bzl b/backends/xnnpack/third-party/xnnpack_src_defs.bzl index a9e60e4f17..0a0beba7ef 100644 --- a/backends/xnnpack/third-party/xnnpack_src_defs.bzl +++ b/backends/xnnpack/third-party/xnnpack_src_defs.bzl @@ -2,99 +2,32 @@ Auto-generated by generate-wrappers.py script. Do not modify """ -SUBGRAPH_SRCS = [ - "XNNPACK/src/memory-planner.c", - "XNNPACK/src/runtime.c", - "XNNPACK/src/subgraph.c", - "XNNPACK/src/subgraph/abs.c", - "XNNPACK/src/subgraph/add2.c", - "XNNPACK/src/subgraph/argmax-pooling-2d.c", - "XNNPACK/src/subgraph/average-pooling-2d.c", - "XNNPACK/src/subgraph/bankers-rounding.c", - "XNNPACK/src/subgraph/batch-matrix-multiply.c", - "XNNPACK/src/subgraph/ceiling.c", - "XNNPACK/src/subgraph/clamp.c", - "XNNPACK/src/subgraph/concatenate.c", - "XNNPACK/src/subgraph/convert.c", - "XNNPACK/src/subgraph/convolution-2d.c", - "XNNPACK/src/subgraph/copy.c", - "XNNPACK/src/subgraph/deconvolution-2d.c", - "XNNPACK/src/subgraph/depth-to-space-2d.c", - "XNNPACK/src/subgraph/depthwise-convolution-2d.c", - "XNNPACK/src/subgraph/divide.c", - "XNNPACK/src/subgraph/elu.c", - "XNNPACK/src/subgraph/even-split.c", - "XNNPACK/src/subgraph/floor.c", - "XNNPACK/src/subgraph/fully-connected-sparse.c", - "XNNPACK/src/subgraph/fully-connected.c", - "XNNPACK/src/subgraph/global-average-pooling.c", - "XNNPACK/src/subgraph/global-sum-pooling.c", - "XNNPACK/src/subgraph/hardswish.c", - "XNNPACK/src/subgraph/leaky-relu.c", - "XNNPACK/src/subgraph/max-pooling-2d.c", - "XNNPACK/src/subgraph/maximum2.c", - "XNNPACK/src/subgraph/minimum2.c", - "XNNPACK/src/subgraph/multiply2.c", - "XNNPACK/src/subgraph/negate.c", - "XNNPACK/src/subgraph/prelu.c", - "XNNPACK/src/subgraph/scaled-dot-product-attention.c", - "XNNPACK/src/subgraph/sigmoid.c", - "XNNPACK/src/subgraph/softmax.c", - "XNNPACK/src/subgraph/space-to-depth-2d.c", - "XNNPACK/src/subgraph/square-root.c", - "XNNPACK/src/subgraph/square.c", - "XNNPACK/src/subgraph/squared-difference.c", - "XNNPACK/src/subgraph/static-constant-pad.c", - "XNNPACK/src/subgraph/static-mean.c", - "XNNPACK/src/subgraph/static-reshape.c", - "XNNPACK/src/subgraph/static-resize-bilinear-2d.c", - "XNNPACK/src/subgraph/static-slice.c", - "XNNPACK/src/subgraph/static-transpose.c", - "XNNPACK/src/subgraph/subtract.c", - "XNNPACK/src/subgraph/tanh.c", - "XNNPACK/src/subgraph/unpooling-2d.c", - "XNNPACK/src/subgraph/validation.c", - "XNNPACK/src/tensor.c", -] - -PROD_NEONI8MM_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neoni8mm.c", -] - -PROD_SSE41_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/sse41.c", -] - -PROD_SSSE3_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/ssse3.c", +PROD_NEONDOTFP16ARITH_AARCH64_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neondotfp16-aarch64.c", ] -PROD_AVX512VBMI_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx512vbmi.c", +PROD_FP16ARITH_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/fp16arith.c", ] -LOGGING_SRCS = [ - "XNNPACK/src/enums/datatype-strings.c", - "XNNPACK/src/enums/microkernel-type.c", - "XNNPACK/src/enums/node-type.c", - "XNNPACK/src/enums/operator-type.c", - "XNNPACK/src/log.c", +PROD_SSE_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/sse.c", ] -PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neondotfp16arith.c", +PROD_FMA3_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/fma3.c", ] -PROD_F16C_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/f16c.c", +PROD_SSE2_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/sse2.c", ] -PROD_NEON_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neon.c", +PROD_NEONV8_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonv8.c", ] -PROD_RVV_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/rvv.c", +PROD_AVX512SKX_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512skx.c", ] AARCH32_ASM_MICROKERNEL_SRCS = [ @@ -124,28 +57,9 @@ AARCH32_ASM_MICROKERNEL_SRCS = [ "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a75.S", "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-ld64.S", "XNNPACK/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S", + "XNNPACK/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S", "XNNPACK/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-asm-aarch32-neondot-cortex-a55.S", "XNNPACK/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondot-cortex-a55.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x8c4-minmax-rndnu-asm-aarch32-neondot-cortex-a55.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x8c4-minmax-rndnu-asm-aarch32-neondot-ld64.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x8c4-minmax-rndnu-asm-aarch32-neondot-cortex-a55.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x8c4-minmax-rndnu-asm-aarch32-neondot-ld64.S", "XNNPACK/src/qs8-qc8w-dwconv/qs8-qc8w-dwconv-3p8c-minmax-fp32-asm-aarch32-neonv8-mla8-cortex-a35.S", "XNNPACK/src/qs8-qc8w-dwconv/qs8-qc8w-dwconv-3p16c-minmax-fp32-asm-aarch32-neonv8-mla8-cortex-a35.S", "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", @@ -206,44 +120,17 @@ AARCH32_ASM_MICROKERNEL_SRCS = [ "XNNPACK/src/u32-filterbank-accumulate/u32-filterbank-accumulate-asm-aarch32-neon-x2.S", ] -PROD_AVX_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx.c", -] - -PROD_XOP_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/xop.c", -] - -PROD_NEONFP16ARITH_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonfp16arith.c", -] - -PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neondot-aarch64.c", -] - -PROD_AVX512SKX_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx512skx.c", -] - -PROD_ARMSIMD32_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/armsimd32.c", -] - -PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonfp16arith-aarch64.c", -] - -PROD_NEONDOTFP16ARITH_AARCH64_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neondotfp16-aarch64.c", +PROD_FMA_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/fma.c", ] -PROD_AVX512VNNI_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx512vnni.c", +PROD_NEON_AARCH64_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neon-aarch64.c", + "XNNPACK/src/amalgam/gen/neonfma-aarch64.c", ] -PROD_SSE2_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/sse2.c", +PROD_SSSE3_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/ssse3.c", ] TABLE_SRCS = [ @@ -258,31 +145,20 @@ TABLE_SRCS = [ "XNNPACK/src/tables/vlog.c", ] -PROD_FP16ARITH_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/fp16arith.c", -] - -PROD_NEONFMA_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonfma.c", -] - -JIT_SRCS = [ - "XNNPACK/src/jit/aarch32-assembler.cc", - "XNNPACK/src/jit/aarch64-assembler.cc", - "XNNPACK/src/jit/assembler.cc", +PROD_AVX512VBMI_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512vbmi.c", ] PROD_AVX2_MICROKERNEL_SRCS = [ "XNNPACK/src/amalgam/gen/avx2.c", ] -PROD_AVXVNNI_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avxvnni.c", +PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neondotfp16arith.c", ] -PROD_NEON_AARCH64_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neon-aarch64.c", - "XNNPACK/src/amalgam/gen/neonfma-aarch64.c", +PROD_AVXVNNI_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avxvnni.c", ] OPERATOR_SRCS = [ @@ -315,20 +191,80 @@ OPERATOR_SRCS = [ "XNNPACK/src/operators/unpooling-nhwc.c", ] -PROD_NEONDOT_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neondot.c", +PROD_NEONFP16ARITH_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonfp16arith.c", ] -PROD_SCALAR_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/scalar.c", +PROD_F16C_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/f16c.c", +] + +PROD_XOP_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/xop.c", ] PROD_AVX512F_MICROKERNEL_SRCS = [ "XNNPACK/src/amalgam/gen/avx512f.c", ] -PROD_FMA3_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/fma3.c", +SUBGRAPH_SRCS = [ + "XNNPACK/src/memory-planner.c", + "XNNPACK/src/runtime.c", + "XNNPACK/src/subgraph.c", + "XNNPACK/src/subgraph/abs.c", + "XNNPACK/src/subgraph/add2.c", + "XNNPACK/src/subgraph/argmax-pooling-2d.c", + "XNNPACK/src/subgraph/average-pooling-2d.c", + "XNNPACK/src/subgraph/bankers-rounding.c", + "XNNPACK/src/subgraph/batch-matrix-multiply.c", + "XNNPACK/src/subgraph/ceiling.c", + "XNNPACK/src/subgraph/clamp.c", + "XNNPACK/src/subgraph/concatenate.c", + "XNNPACK/src/subgraph/convert.c", + "XNNPACK/src/subgraph/convolution-2d.c", + "XNNPACK/src/subgraph/copy.c", + "XNNPACK/src/subgraph/deconvolution-2d.c", + "XNNPACK/src/subgraph/depth-to-space-2d.c", + "XNNPACK/src/subgraph/depthwise-convolution-2d.c", + "XNNPACK/src/subgraph/divide.c", + "XNNPACK/src/subgraph/elu.c", + "XNNPACK/src/subgraph/even-split.c", + "XNNPACK/src/subgraph/floor.c", + "XNNPACK/src/subgraph/fully-connected-sparse.c", + "XNNPACK/src/subgraph/fully-connected.c", + "XNNPACK/src/subgraph/global-average-pooling.c", + "XNNPACK/src/subgraph/global-sum-pooling.c", + "XNNPACK/src/subgraph/hardswish.c", + "XNNPACK/src/subgraph/leaky-relu.c", + "XNNPACK/src/subgraph/max-pooling-2d.c", + "XNNPACK/src/subgraph/maximum2.c", + "XNNPACK/src/subgraph/minimum2.c", + "XNNPACK/src/subgraph/multiply2.c", + "XNNPACK/src/subgraph/negate.c", + "XNNPACK/src/subgraph/prelu.c", + "XNNPACK/src/subgraph/reshape-helpers.c", + "XNNPACK/src/subgraph/scaled-dot-product-attention.c", + "XNNPACK/src/subgraph/sigmoid.c", + "XNNPACK/src/subgraph/softmax.c", + "XNNPACK/src/subgraph/space-to-depth-2d.c", + "XNNPACK/src/subgraph/square-root.c", + "XNNPACK/src/subgraph/square.c", + "XNNPACK/src/subgraph/squared-difference.c", + "XNNPACK/src/subgraph/static-constant-pad.c", + "XNNPACK/src/subgraph/static-mean.c", + "XNNPACK/src/subgraph/static-reshape.c", + "XNNPACK/src/subgraph/static-resize-bilinear-2d.c", + "XNNPACK/src/subgraph/static-slice.c", + "XNNPACK/src/subgraph/static-transpose.c", + "XNNPACK/src/subgraph/subtract.c", + "XNNPACK/src/subgraph/tanh.c", + "XNNPACK/src/subgraph/unpooling-2d.c", + "XNNPACK/src/subgraph/validation.c", + "XNNPACK/src/tensor.c", +] + +PROD_RVV_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/rvv.c", ] AARCH64_ASM_MICROKERNEL_SRCS = [ @@ -514,87 +450,13 @@ AARCH64_ASM_MICROKERNEL_SRCS = [ "XNNPACK/src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x8-minmax-asm-aarch64-neonfma-ld128.S", "XNNPACK/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondot-ld128.S", "XNNPACK/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondotfp16arith-cortex-a55.S", + "XNNPACK/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-cortex-a55.S", + "XNNPACK/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-ld128.S", "XNNPACK/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondot-cortex-a55.S", "XNNPACK/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondot-ld64.S", "XNNPACK/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondot-ld128.S", "XNNPACK/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-cortex-a55.S", "XNNPACK/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-ld128.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-fp32-asm-aarch64-neondot-ld32.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-fp32-asm-aarch64-neondot-ld64.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-rndnu-asm-aarch64-neondot-ld32.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-rndnu-asm-aarch64-neondot-ld64.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-fp32-asm-aarch64-neon-mull.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mull.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-2x8c16-minmax-fp32-asm-aarch64-neon-mlal.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-2x8c16-minmax-rndnu-asm-aarch64-neon-mlal.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-ld64.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld32.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld64.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld32.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld64.S", - "XNNPACK/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-2x8c16-minmax-fp32-asm-aarch64-neon-mlal.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-2x8c16-minmax-rndnu-asm-aarch64-neon-mlal.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-ld64.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld64.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld64.S", - "XNNPACK/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53-prfm.S", "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S", "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-prfm.S", @@ -691,18 +553,68 @@ XNNPACK_SRCS = [ "XNNPACK/src/params.c", ] -PROD_FMA_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/fma.c", +JIT_SRCS = [ + "XNNPACK/src/jit/aarch32-assembler.cc", + "XNNPACK/src/jit/aarch64-assembler.cc", + "XNNPACK/src/jit/assembler.cc", ] -PROD_NEONV8_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonv8.c", +PROD_NEONFMA_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonfma.c", +] + +PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonfp16arith-aarch64.c", ] PROD_NEONFP16_MICROKERNEL_SRCS = [ "XNNPACK/src/amalgam/gen/neonfp16.c", ] -PROD_SSE_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/sse.c", +PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512vnnigfni.c", +] + +PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neondot-aarch64.c", +] + +PROD_ARMSIMD32_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/armsimd32.c", +] + +PROD_NEONDOT_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neondot.c", +] + +PROD_SCALAR_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/scalar.c", +] + +PROD_SSE41_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/sse41.c", +] + +PROD_NEONI8MM_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neoni8mm.c", +] + +LOGGING_SRCS = [ + "XNNPACK/src/enums/datatype-strings.c", + "XNNPACK/src/enums/microkernel-type.c", + "XNNPACK/src/enums/node-type.c", + "XNNPACK/src/enums/operator-type.c", + "XNNPACK/src/log.c", +] + +PROD_NEON_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neon.c", +] + +PROD_AVX512VNNI_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512vnni.c", +] + +PROD_AVX_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx.c", ] diff --git a/backends/xnnpack/third-party/xnnpack_wrapper_defs.bzl b/backends/xnnpack/third-party/xnnpack_wrapper_defs.bzl index 256633ff55..2dbb41ff01 100644 --- a/backends/xnnpack/third-party/xnnpack_wrapper_defs.bzl +++ b/backends/xnnpack/third-party/xnnpack_wrapper_defs.bzl @@ -115,6 +115,10 @@ PROD_AVX512VBMI_MICROKERNEL_SRCS = [ "xnnpack_wrappers/amalgam/gen/avx512vbmi.c", ] +PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/amalgam/gen/avx512vnnigfni.c", +] + PROD_AVX512VNNI_MICROKERNEL_SRCS = [ "xnnpack_wrappers/amalgam/gen/avx512vnni.c", ] @@ -154,28 +158,9 @@ AARCH32_ASM_MICROKERNEL_SRCS = [ "xnnpack_wrappers/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a75.S", "xnnpack_wrappers/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-ld64.S", "xnnpack_wrappers/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S", + "xnnpack_wrappers/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S", "xnnpack_wrappers/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-asm-aarch32-neondot-cortex-a55.S", "xnnpack_wrappers/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondot-cortex-a55.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x8c4-minmax-rndnu-asm-aarch32-neondot-cortex-a55.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x8c4-minmax-rndnu-asm-aarch32-neondot-ld64.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x8c4-minmax-rndnu-asm-aarch32-neondot-cortex-a55.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x8c4-minmax-rndnu-asm-aarch32-neondot-ld64.S", "xnnpack_wrappers/qs8-qc8w-dwconv/qs8-qc8w-dwconv-3p8c-minmax-fp32-asm-aarch32-neonv8-mla8-cortex-a35.S", "xnnpack_wrappers/qs8-qc8w-dwconv/qs8-qc8w-dwconv-3p16c-minmax-fp32-asm-aarch32-neonv8-mla8-cortex-a35.S", "xnnpack_wrappers/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", @@ -419,87 +404,13 @@ AARCH64_ASM_MICROKERNEL_SRCS = [ "xnnpack_wrappers/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x8-minmax-asm-aarch64-neonfma-ld128.S", "xnnpack_wrappers/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondot-ld128.S", "xnnpack_wrappers/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondotfp16arith-cortex-a55.S", + "xnnpack_wrappers/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-cortex-a55.S", + "xnnpack_wrappers/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-ld128.S", "xnnpack_wrappers/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondot-cortex-a55.S", "xnnpack_wrappers/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondot-ld64.S", "xnnpack_wrappers/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondot-ld128.S", "xnnpack_wrappers/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-cortex-a55.S", "xnnpack_wrappers/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-ld128.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-fp32-asm-aarch64-neondot-ld32.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-fp32-asm-aarch64-neondot-ld64.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-rndnu-asm-aarch64-neondot-ld32.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-rndnu-asm-aarch64-neondot-ld64.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-fp32-asm-aarch64-neon-mull.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mull.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-2x8c16-minmax-fp32-asm-aarch64-neon-mlal.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-2x8c16-minmax-rndnu-asm-aarch64-neon-mlal.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x8-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-cortex-a53.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-ld64-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-ld64.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld32.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld64.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld32.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld64.S", - "xnnpack_wrappers/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-1x8c8-minmax-rndnu-asm-aarch64-neon-mlal.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-fp32-asm-aarch64-neon-mlal.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal-cortex-a53.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-2x8c8-minmax-rndnu-asm-aarch64-neon-mlal.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-2x8c16-minmax-fp32-asm-aarch64-neon-mlal.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-2x8c16-minmax-rndnu-asm-aarch64-neon-mlal.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x8-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-cortex-a53.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-ld64-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16-minmax-fp32-asm-aarch64-neon-mlal-lane-ld64.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld64.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld64.S", - "xnnpack_wrappers/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", "xnnpack_wrappers/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53-prfm.S", "xnnpack_wrappers/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S", "xnnpack_wrappers/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-prfm.S", diff --git a/backends/xnnpack/third-party/xnnpack_wrappers/amalgam/gen/avx512vnnigfni.c b/backends/xnnpack/third-party/xnnpack_wrappers/amalgam/gen/avx512vnnigfni.c new file mode 100644 index 0000000000..c02b5a09ae --- /dev/null +++ b/backends/xnnpack/third-party/xnnpack_wrappers/amalgam/gen/avx512vnnigfni.c @@ -0,0 +1,5 @@ +/* Auto-generated by generate-wrappers.py script. Do not modify */ + +#if defined(__i386__) || defined(__i686__) || defined(__x86_64__) +#include +#endif /* defined(__i386__) || defined(__i686__) || defined(__x86_64__) */ diff --git a/backends/xnnpack/third-party/xnnpack_wrappers/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-cortex-a55.S b/backends/xnnpack/third-party/xnnpack_wrappers/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-cortex-a55.S new file mode 100644 index 0000000000..54c754b5d8 --- /dev/null +++ b/backends/xnnpack/third-party/xnnpack_wrappers/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-cortex-a55.S @@ -0,0 +1,5 @@ +/* Auto-generated by generate-wrappers.py script. Do not modify */ + +#if defined(__aarch64__) +#include +#endif /* defined(__aarch64__) */ diff --git a/backends/xnnpack/third-party/xnnpack_wrappers/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-ld128.S b/backends/xnnpack/third-party/xnnpack_wrappers/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-ld128.S new file mode 100644 index 0000000000..41bee5072d --- /dev/null +++ b/backends/xnnpack/third-party/xnnpack_wrappers/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x16c4-minmax-asm-aarch64-neondot-ld128.S @@ -0,0 +1,5 @@ +/* Auto-generated by generate-wrappers.py script. Do not modify */ + +#if defined(__aarch64__) +#include +#endif /* defined(__aarch64__) */ diff --git a/backends/xnnpack/third-party/xnnpack_wrappers/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S b/backends/xnnpack/third-party/xnnpack_wrappers/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S new file mode 100644 index 0000000000..db2eda8704 --- /dev/null +++ b/backends/xnnpack/third-party/xnnpack_wrappers/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S @@ -0,0 +1,5 @@ +/* Auto-generated by generate-wrappers.py script. Do not modify */ + +#if defined(__arm__) +#include +#endif /* defined(__arm__) */ diff --git a/build/build_apple_frameworks.sh b/build/build_apple_frameworks.sh index 5ac8cd2c1d..61e7c2725f 100755 --- a/build/build_apple_frameworks.sh +++ b/build/build_apple_frameworks.sh @@ -11,20 +11,21 @@ PLATFORMS=("iphoneos" "iphonesimulator") PLATFORM_FLAGS=("OS" "SIMULATOR") SOURCE_ROOT_DIR="" OUTPUT="cmake-out" -MODE="Debug" +MODE="Release" TOOLCHAIN="" -BUCK2="/tmp/buck2" +BUCK2="buck2" PYTHON=$(which python3) -FLATC="" +FLATC="flatc" IOS_DEPLOYMENT_TARGET="17.0" COREML=OFF MPS=OFF +PORTABLE=OFF XNNPACK=OFF HEADERS_PATH="include" -EXECUTORCH_FRAMEWORK="executorch:libexecutorch.a,libextension_data_loader.a,libextension_module.a:$HEADERS_PATH" -PORTABLE_FRAMEWORK="portable_backend:libportable_kernels.a,libportable_ops_lib.a:" +EXECUTORCH_FRAMEWORK="executorch:libexecutorch.a,libextension_apple.a,libextension_data_loader.a,libextension_module.a:$HEADERS_PATH" COREML_FRAMEWORK="coreml_backend:libcoremldelegate.a:" MPS_FRAMEWORK="mps_backend:libmpsdelegate.a:" +PORTABLE_FRAMEWORK="portable_backend:libportable_kernels.a,libportable_ops_lib.a:" XNNPACK_FRAMEWORK="xnnpack_backend:libXNNPACK.a,libcpuinfo.a,libpthreadpool.a,libxnnpack_backend.a:" usage() { @@ -34,14 +35,15 @@ usage() { echo echo "Options:" echo " --output=DIR Output directory. Default: 'cmake-out'" - echo " --Release Use Release build mode. Default: 'Debug'" + echo " --Debug Use Debug build mode. Default: 'Release'" echo " --toolchain=FILE Cmake toolchain file. Default: '\$SOURCE_ROOT_DIR/third-party/pytorch/cmake/iOS.cmake'" echo " --buck2=FILE Buck2 executable path. Default: '/tmp/buck2'" echo " --python=FILE Python executable path. Default: Path of python3 found in the current \$PATH" echo " --flatc=FILE FlatBuffers Compiler executable path. Default: '\$SOURCE_ROOT_DIR/third-party/flatbuffers/cmake-out/flatc'" - echo " --coreml Include this flag to build Core ML backend." - echo " --mps Include this flag to build Metal Performance Shaders backend." - echo " --xnnpack Include this flag to build XNNPACK backend." + echo " --coreml Include this flag to build the Core ML backend." + echo " --mps Include this flag to build the Metal Performance Shaders backend." + echo " --portable Include this flag to build the Portable backend." + echo " --xnnpack Include this flag to build the XNNPACK backend." echo echo "Example:" echo " $0 /path/to/source/root --output=cmake-out --Release --toolchain=/path/to/cmake/toolchain --buck2=/path/to/buck2 --python=/path/to/python3 --coreml --mps --xnnpack" @@ -52,13 +54,14 @@ for arg in "$@"; do case $arg in -h|--help) usage ;; --output=*) OUTPUT="${arg#*=}" ;; - --Release) MODE="Release" ;; + --Debug) MODE="Debug" ;; --toolchain=*) TOOLCHAIN="${arg#*=}" ;; --buck2=*) BUCK2="${arg#*=}" ;; --python=*) PYTHON="${arg#*=}" ;; --flatc=*) FLATC="${arg#*=}" ;; --ios-deployment-target=*) IOS_DEPLOYMENT_TARGET="${arg#*=}" ;; --coreml) COREML=ON ;; + --portable) PORTABLE=ON ;; --mps) MPS=ON ;; --xnnpack) XNNPACK=ON ;; *) @@ -105,12 +108,14 @@ cmake_build() { echo "Building for $platform with flag $platform_flag" mkdir "$platform" && cd "$platform" || exit 1 cmake "$SOURCE_ROOT_DIR" -G Xcode \ + -DCMAKE_BUILD_TYPE="$MODE" \ -DCMAKE_TOOLCHAIN_FILE="$TOOLCHAIN" \ -DCMAKE_XCODE_ATTRIBUTE_CLANG_CXX_LANGUAGE_STANDARD="c++17" \ -DCMAKE_XCODE_ATTRIBUTE_CLANG_CXX_LIBRARY="libc++" \ -DBUCK2="$BUCK2" \ -DPYTHON_EXECUTABLE="$PYTHON" \ -DFLATC_EXECUTABLE="$FLATC" \ + -DEXECUTORCH_BUILD_EXTENSION_APPLE=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DCMAKE_ARCHIVE_OUTPUT_DIRECTORY="$(pwd)" \ @@ -135,6 +140,8 @@ mkdir -p "$HEADERS_PATH" //extension/module: \ | rsync -av --files-from=- "$SOURCE_ROOT_DIR" "$HEADERS_PATH/executorch" +cp "$SOURCE_ROOT_DIR/extension/apple/ExecuTorch/Exported/"{*.h,*.modulemap} "$HEADERS_PATH" + echo "Creating frameworks" for platform in "${PLATFORMS[@]}"; do @@ -152,7 +159,7 @@ append_framework_flag() { } append_framework_flag "ON" "$EXECUTORCH_FRAMEWORK" -append_framework_flag "ON" "$PORTABLE_FRAMEWORK" +append_framework_flag "$PORTABLE" "$PORTABLE_FRAMEWORK" append_framework_flag "$COREML" "$COREML_FRAMEWORK" append_framework_flag "$MPS" "$MPS_FRAMEWORK" append_framework_flag "$XNNPACK" "$XNNPACK_FRAMEWORK" diff --git a/build/cmake_deps.toml b/build/cmake_deps.toml index bf1899534a..962eedf1dc 100644 --- a/build/cmake_deps.toml +++ b/build/cmake_deps.toml @@ -31,7 +31,7 @@ filters = [ buck_targets = [ # //kernels/portable:operators would be more appropriate, but buck2 doesn't # think it has any "inputs" since its srcs list is empty. - "//kernels/portable:generated_lib_all_ops", + "//kernels/portable:generated_lib", ] filters = [ ".cpp$", @@ -246,9 +246,10 @@ excludes = [ ] deps = [ "executorch", + "extension_data_loader", "extension_module", "portable_kernels", "quantized_kernels", "xnnpack_backend", ] -# ---------------------------------- LLama start ---------------------------------- +# ---------------------------------- LLama end ---------------------------------- diff --git a/build/packaging/post_build_script.sh b/build/packaging/post_build_script.sh new file mode 100644 index 0000000000..fd71b18565 --- /dev/null +++ b/build/packaging/post_build_script.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -eux + +echo "This script is run after building ExecuTorch binaries" diff --git a/build/packaging/pre_build_script.sh b/build/packaging/pre_build_script.sh new file mode 100644 index 0000000000..3940168c40 --- /dev/null +++ b/build/packaging/pre_build_script.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -eux + +echo "This script is run before building ExecuTorch binaries" diff --git a/build/packaging/smoke_test.py b/build/packaging/smoke_test.py new file mode 100644 index 0000000000..5273a457f1 --- /dev/null +++ b/build/packaging/smoke_test.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +def main(): + """ + Run ExecuTorch binary smoke tests. This is a placeholder for future tests. See + https://github.com/pytorch/test-infra/wiki/Using-Nova-Reusable-Build-Workflows + for more information about Nova binary workflow. + """ + + +if __name__ == "__main__": + main() diff --git a/build/test_ios.sh b/build/test_ios.sh index 8c65eb6080..7b0b31262e 100755 --- a/build/test_ios.sh +++ b/build/test_ios.sh @@ -70,7 +70,7 @@ curl -LO "https://github.com/facebook/buck2/releases/download/$BUCK2_RELEASE_DAT zstd -cdq "$BUCK2_ARCHIVE" > "$BUCK2" && chmod +x "$BUCK2" rm "$BUCK2_ARCHIVE" -./install_requirements.sh +./install_requirements.sh --pybind coreml mps xnnpack export PATH="$(realpath third-party/flatbuffers/cmake-out):$PATH" ./build/install_flatc.sh @@ -82,18 +82,10 @@ say "Installing MPS Backend Requirements" ./backends/apple/mps/install_requirements.sh -say "Installing Python Bindings" - -EXECUTORCH_BUILD_PYBIND=ON \ -BUCK="$(pwd)/$BUCK2" \ -CMAKE_ARGS="-DEXECUTORCH_BUILD_COREML=ON -DEXECUTORCH_BUILD_MPS=ON -DEXECUTORCH_BUILD_XNNPACK=ON" \ -CMAKE_BUILD_PARALLEL_LEVEL=9 \ -pip install . --no-build-isolation -v - say "Exporting Models" python3 -m examples.portable.scripts.export --model_name="$MODEL_NAME" -python3 -m examples.apple.coreml.scripts.export_and_delegate --model_name="$MODEL_NAME" +python3 -m examples.apple.coreml.scripts.export --model_name="$MODEL_NAME" python3 -m examples.apple.mps.scripts.mps_example --model_name="$MODEL_NAME" python3 -m examples.xnnpack.aot_compiler --model_name="$MODEL_NAME" --delegate @@ -107,7 +99,7 @@ curl https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt \ say "Building Frameworks" -./build/build_apple_frameworks.sh --buck2="$(realpath $BUCK2)" --Release --coreml --mps --xnnpack +./build/build_apple_frameworks.sh --buck2="$(realpath $BUCK2)" --coreml --mps --portable --xnnpack mv cmake-out "$APP_PATH/Frameworks" say "Creating Simulator" diff --git a/build/test_ios_ci.sh b/build/test_ios_ci.sh index 41aeb127eb..8ca8cb4caa 100755 --- a/build/test_ios_ci.sh +++ b/build/test_ios_ci.sh @@ -39,18 +39,10 @@ say "Installing MPS Backend Requirements" ./backends/apple/mps/install_requirements.sh -say "Installing Python Bindings" - -EXECUTORCH_BUILD_PYBIND=ON \ -BUCK="$(which buck2)" \ -CMAKE_ARGS="-DEXECUTORCH_BUILD_COREML=ON -DEXECUTORCH_BUILD_MPS=ON -DEXECUTORCH_BUILD_XNNPACK=ON" \ -CMAKE_BUILD_PARALLEL_LEVEL=9 \ -pip install . --no-build-isolation -v - say "Exporting Models" python3 -m examples.portable.scripts.export --model_name="$MODEL_NAME" --segment_alignment=0x4000 -python3 -m examples.apple.coreml.scripts.export_and_delegate --model_name="$MODEL_NAME" +python3 -m examples.apple.coreml.scripts.export --model_name="$MODEL_NAME" python3 -m examples.apple.mps.scripts.mps_example --model_name="$MODEL_NAME" python3 -m examples.xnnpack.aot_compiler --model_name="$MODEL_NAME" --delegate @@ -64,7 +56,7 @@ curl https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt \ say "Building Frameworks" -./build/build_apple_frameworks.sh --buck2="$(which buck2)" --flatc="$(which flatc)" --coreml --mps --xnnpack +./build/build_apple_frameworks.sh --buck2="$(which buck2)" --flatc="$(which flatc)" --coreml --mps --portable --xnnpack mv cmake-out "$APP_PATH/Frameworks" say "Creating Simulator" diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index 86d47dd561..7d259b2e14 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -62,7 +62,7 @@ {% block footer %} {{ super() }} {{ super() }} diff --git a/docs/source/build-run-coreml.md b/docs/source/build-run-coreml.md index 66983a8f0f..c442b2cc6b 100644 --- a/docs/source/build-run-coreml.md +++ b/docs/source/build-run-coreml.md @@ -59,7 +59,7 @@ xcode-select --install cd executorch # Generates ./mv3_coreml_all.pte file. -python3 -m examples.apple.coreml.scripts.export_and_delegate --model_name mv3 +python3 -m examples.apple.coreml.scripts.export --model_name mv3 ``` - Core ML backend uses [coremltools](https://apple.github.io/coremltools/docs-guides/source/overview-coremltools.html) to lower [Edge dialect](ir-exir.md#edge-dialect) to Core ML format and then bundles it in the `.pte` file. diff --git a/docs/source/index.rst b/docs/source/index.rst index e685f4cebe..871f4aba87 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -79,6 +79,17 @@ Topics in this section will help you get started with ExecuTorch. getting-started-setup runtime-build-and-cross-compilation +.. toctree:: + :glob: + :maxdepth: 2 + :caption: Working with LLMs + :hidden: + + llm/introduction + llm/mobile/index + llm/desktop/index + llm/advanced-flow/index + .. toctree:: :glob: :maxdepth: 1 diff --git a/docs/source/llm/advanced-flow/advanced-flow.md b/docs/source/llm/advanced-flow/advanced-flow.md new file mode 100644 index 0000000000..f8b596340f --- /dev/null +++ b/docs/source/llm/advanced-flow/advanced-flow.md @@ -0,0 +1,7 @@ +# Advanced Flows + +## Custom quantization + +## Bring GGUF to PyTorch ecosystem + +## TorchTune interoperability diff --git a/docs/source/llm/advanced-flow/index.rst b/docs/source/llm/advanced-flow/index.rst new file mode 100644 index 0000000000..f1ab697bb8 --- /dev/null +++ b/docs/source/llm/advanced-flow/index.rst @@ -0,0 +1,9 @@ +Enabling LLMs on Advanced Flow +======================= + +This section will walk you through + +.. toctree:: + :maxdepth: 1 + + advanced-flow diff --git a/docs/source/llm/desktop/benchmarks.md b/docs/source/llm/desktop/benchmarks.md new file mode 100644 index 0000000000..9a827aca07 --- /dev/null +++ b/docs/source/llm/desktop/benchmarks.md @@ -0,0 +1,7 @@ +# Enabling LLMs on Desktop + +## Local Llama on Desktop Benchmarks + +**Results** + +**Instructions** diff --git a/docs/source/llm/desktop/index.rst b/docs/source/llm/desktop/index.rst new file mode 100644 index 0000000000..71e77a9f11 --- /dev/null +++ b/docs/source/llm/desktop/index.rst @@ -0,0 +1,9 @@ +Enabling LLMs on Desktop +======================= + +This section will walk you through + +.. toctree:: + :maxdepth: 1 + + benchmarks diff --git a/docs/source/llm/introduction.md b/docs/source/llm/introduction.md new file mode 100644 index 0000000000..a97310eb4a --- /dev/null +++ b/docs/source/llm/introduction.md @@ -0,0 +1,7 @@ +# Introduction + +## Current landscape of local LLMs + +## What is our offering? + +## Why and when should you use it? diff --git a/docs/source/llm/mobile/benchmarks.md b/docs/source/llm/mobile/benchmarks.md new file mode 100644 index 0000000000..e056d59d95 --- /dev/null +++ b/docs/source/llm/mobile/benchmarks.md @@ -0,0 +1,6 @@ +# Mobile Benchmarks for Local Llama + +## Results + + +## Instructions diff --git a/docs/source/llm/mobile/customization-examples.md b/docs/source/llm/mobile/customization-examples.md new file mode 100644 index 0000000000..1364304f8d --- /dev/null +++ b/docs/source/llm/mobile/customization-examples.md @@ -0,0 +1,9 @@ +# Customization examples + +## Custom tokenization + +## Custom sampler + +## Speculative decoding + +## Modify a mobile app to use a different LLM model diff --git a/docs/source/llm/mobile/getting-started.md b/docs/source/llm/mobile/getting-started.md new file mode 100644 index 0000000000..c3d0b11d6d --- /dev/null +++ b/docs/source/llm/mobile/getting-started.md @@ -0,0 +1,16 @@ +# Getting Started with LLMs via ExecuTorch + + +## Simple “Hello World” example + + +## Use Mobile Acceleration + + +## Quantization via XNNPACKQuantizer + + +## Debugging and Profiling + + +## Build Mobile LLM chat App Examples diff --git a/docs/source/llm/mobile/index.rst b/docs/source/llm/mobile/index.rst new file mode 100644 index 0000000000..8012c01f8c --- /dev/null +++ b/docs/source/llm/mobile/index.rst @@ -0,0 +1,12 @@ +Enabling LLMs on Mobile +======================= + +This section will walk you through + +.. toctree:: + :maxdepth: 1 + + benchmarks + getting-started + customization-examples + validating-other-models diff --git a/docs/source/llm/mobile/validating-other-models.md b/docs/source/llm/mobile/validating-other-models.md new file mode 100644 index 0000000000..6097241777 --- /dev/null +++ b/docs/source/llm/mobile/validating-other-models.md @@ -0,0 +1,3 @@ +# Validating other models + +## Exportability results diff --git a/docs/source/runtime-build-and-cross-compilation.md b/docs/source/runtime-build-and-cross-compilation.md index 31d9b38b71..22246b8f8c 100644 --- a/docs/source/runtime-build-and-cross-compilation.md +++ b/docs/source/runtime-build-and-cross-compilation.md @@ -140,7 +140,7 @@ Assuming Android NDK is available, run: rm -rf cmake-android-out && mkdir cmake-android-out && cd cmake-android-out # point -DCMAKE_TOOLCHAIN_FILE to the location where ndk is installed -# Run `which buck2`, if it returns empty (meaning the system doesn't know where buck2 is installed), pass in pass in this flag `-DBUCK2=/path/to/buck2` pointing to buck2 +# Run `which buck2`, if it returns empty (meaning the system doesn't know where buck2 is installed), pass in this flag `-DBUCK2=/path/to/buck2` pointing to buck2 cmake -DCMAKE_TOOLCHAIN_FILE=/Users/{user_name}/Library/Android/sdk/ndk/25.2.9519653/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a .. cd .. diff --git a/docs/source/tutorials_source/export-to-executorch-tutorial.py b/docs/source/tutorials_source/export-to-executorch-tutorial.py index b5c5c00f47..49fc2c42b7 100644 --- a/docs/source/tutorials_source/export-to-executorch-tutorial.py +++ b/docs/source/tutorials_source/export-to-executorch-tutorial.py @@ -130,11 +130,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: aten_dialect: ExportedProgram = export(f, example_args) # Works correctly -print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3))) +print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3))) # Errors try: - print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2))) + print(aten_dialect.module()(torch.ones(3, 2), torch.ones(3, 2))) except Exception: tb.print_exc() @@ -175,18 +175,18 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # Now let's try running the model with different shapes: # Works correctly -print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3))) -print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2))) +print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3))) +print(aten_dialect.module()(torch.ones(3, 2), torch.ones(3, 2))) # Errors because it violates our constraint that input 0, dim 1 <= 10 try: - print(aten_dialect(torch.ones(3, 15), torch.ones(3, 15))) + print(aten_dialect.module()(torch.ones(3, 15), torch.ones(3, 15))) except Exception: tb.print_exc() # Errors because it violates our constraint that input 0, dim 1 == input 1, dim 1 try: - print(aten_dialect(torch.ones(3, 3), torch.ones(3, 2))) + print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 2))) except Exception: tb.print_exc() @@ -287,23 +287,25 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # there is only one program, it will by default be saved to the name "forward". -def encode(x): - return torch.nn.functional.linear(x, torch.randn(5, 10)) +class Encode(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.linear(x, torch.randn(5, 10)) -def decode(x): - return torch.nn.functional.linear(x, torch.randn(10, 5)) +class Decode(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.linear(x, torch.randn(10, 5)) encode_args = (torch.randn(1, 10),) aten_encode: ExportedProgram = export( - capture_pre_autograd_graph(encode, encode_args), + capture_pre_autograd_graph(Encode(), encode_args), encode_args, ) decode_args = (torch.randn(1, 5),) aten_decode: ExportedProgram = export( - capture_pre_autograd_graph(decode, decode_args), + capture_pre_autograd_graph(Decode(), decode_args), decode_args, ) @@ -486,17 +488,18 @@ def forward(self, x): # ``LoweredBackendModule`` for each of those subgraphs. -def f(a, x, b): - y = torch.mm(a, x) - z = y + b - a = z - a - y = torch.mm(a, x) - z = y + b - return z +class Foo(torch.nn.Module): + def forward(self, a, x, b): + y = torch.mm(a, x) + z = y + b + a = z - a + y = torch.mm(a, x) + z = y + b + return z example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2)) -pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args) +pre_autograd_aten_dialect = capture_pre_autograd_graph(Foo(), example_args) aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args) edge_program: EdgeProgramManager = to_edge(aten_dialect) exported_program = edge_program.exported_program() @@ -520,17 +523,18 @@ def f(a, x, b): # call ``to_backend`` on it: -def f(a, x, b): - y = torch.mm(a, x) - z = y + b - a = z - a - y = torch.mm(a, x) - z = y + b - return z +class Foo(torch.nn.Module): + def forward(self, a, x, b): + y = torch.mm(a, x) + z = y + b + a = z - a + y = torch.mm(a, x) + z = y + b + return z example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2)) -pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args) +pre_autograd_aten_dialect = capture_pre_autograd_graph(Foo(), example_args) aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args) edge_program: EdgeProgramManager = to_edge(aten_dialect) exported_program = edge_program.exported_program() diff --git a/examples/apple/coreml/README.md b/examples/apple/coreml/README.md index f576476d5c..a10f3efcc9 100644 --- a/examples/apple/coreml/README.md +++ b/examples/apple/coreml/README.md @@ -36,7 +36,7 @@ cd executorch python3 -m examples.portable.scripts.export -h # Generates ./add_coreml_all.pte file if successful. -python3 -m examples.apple.coreml.scripts.export_and_delegate --model_name add +python3 -m examples.apple.coreml.scripts.export --model_name add ``` 4. Once we have the **Core ML** delegated model binary (pte) file, then let's run it with the **ExecuTorch** runtime using the `coreml_executor_runner`. @@ -52,7 +52,7 @@ cd executorch ``` ## Frequently encountered errors and resolution. -- The `examples.apple.coreml.scripts.export_and_delegate` could fail if the model is not supported by the **Core ML** backend. The following models from the examples models list (` python3 -m examples.portable.scripts.export -h`)are currently supported by the **Core ML** backend. +- The `examples.apple.coreml.scripts.export` could fail if the model is not supported by the **Core ML** backend. The following models from the examples models list (` python3 -m examples.portable.scripts.export -h`)are currently supported by the **Core ML** backend. ``` add diff --git a/examples/apple/coreml/scripts/export_and_delegate.py b/examples/apple/coreml/scripts/export.py similarity index 80% rename from examples/apple/coreml/scripts/export_and_delegate.py rename to examples/apple/coreml/scripts/export.py index 51f9981210..65d7d840a8 100644 --- a/examples/apple/coreml/scripts/export_and_delegate.py +++ b/examples/apple/coreml/scripts/export.py @@ -10,8 +10,14 @@ import executorch.exir as exir +import torch + from executorch.backends.apple.coreml.compiler import CoreMLBackend +from executorch.backends.apple.coreml.partition.coreml_partitioner import ( + CoreMLPartitioner, +) + from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.sdk.etrecord import generate_etrecord @@ -30,6 +36,37 @@ _check_ir_validity=False, ) +compute_units = ["cpu_only", "cpu_and_gpu", "cpu_and_ane", "all"] + + +def parse_args() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument( + "-m", + "--model_name", + required=True, + help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}", + ) + + parser.add_argument( + "-c", + "--compute_units", + required=False, + default="all", + help=f"Provide compute units. Valid ones: {compute_units}", + ) + parser.add_argument("--use_partitioner", action=argparse.BooleanOptionalAction) + parser.add_argument("--generate_etrecord", action=argparse.BooleanOptionalAction) + parser.add_argument("--save_processed_bytes", action=argparse.BooleanOptionalAction) + + args = parser.parse_args() + return args + + +def partition_module_to_coreml(module): + module = module.eval() + def lower_module_to_coreml(module, compute_units): module = module.eval() @@ -81,30 +118,8 @@ def save_processed_bytes(processed_bytes, model_name, compute_units): return -compute_units = ["cpu_only", "cpu_and_gpu", "cpu_and_ane", "all"] - if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "-m", - "--model_name", - required=True, - help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}", - ) - - parser.add_argument( - "-c", - "--compute_units", - required=False, - default="all", - help=f"Provide compute units. Valid ones: {compute_units}", - ) - - parser.add_argument("--generate_etrecord", action=argparse.BooleanOptionalAction) - - parser.add_argument("--save_processed_bytes", action=argparse.BooleanOptionalAction) - - args = parser.parse_args() + args = parse_args() if args.model_name not in MODEL_NAME_TO_MODEL: raise RuntimeError( @@ -122,15 +137,22 @@ def save_processed_bytes(processed_bytes, model_name, compute_units): *MODEL_NAME_TO_MODEL[args.model_name] ) - lowered_module, edge_copy = lower_module_to_coreml( - model, - args.compute_units, - ) - - exec_program = export_lowered_module_to_executorch_program( - lowered_module, - example_inputs, - ) + if args.use_partitioner: + model.eval() + exir_program_aten = torch.export.export(model, example_inputs) + edge_program_manager = exir.to_edge(exir_program_aten) + edge_copy = copy.deepcopy(edge_program_manager) + delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner()) + exec_program = delegated_program_manager.to_executorch() + else: + lowered_module, edge_copy = lower_module_to_coreml( + model, + args.compute_units, + ) + exec_program = export_lowered_module_to_executorch_program( + lowered_module, + example_inputs, + ) save_executorch_program(exec_program, args.model_name, args.compute_units) generate_etrecord(f"{args.model_name}_coreml_etrecord.bin", edge_copy, exec_program) diff --git a/examples/apple/mps/executor_runner/targets.bzl b/examples/apple/mps/executor_runner/targets.bzl index cacc81b43e..b14a8105e4 100644 --- a/examples/apple/mps/executor_runner/targets.bzl +++ b/examples/apple/mps/executor_runner/targets.bzl @@ -25,7 +25,7 @@ def define_common_targets(): "//executorch/runtime/executor:program", "//executorch/extension/evalue_util:print_evalue", "//executorch/extension/data_loader:file_data_loader", - "//executorch/kernels/portable:generated_lib_all_ops", + "//executorch/kernels/portable:generated_lib", "//executorch/extension/data_loader:file_data_loader", "//executorch/sdk/etdump:etdump_flatcc", "//executorch/extension/data_loader:buffer_data_loader", diff --git a/examples/apple/mps/scripts/mps_example.py b/examples/apple/mps/scripts/mps_example.py index 34f22cc288..866bc85a3a 100644 --- a/examples/apple/mps/scripts/mps_example.py +++ b/examples/apple/mps/scripts/mps_example.py @@ -143,13 +143,10 @@ if not args.use_fp16: extension = "fp32" model_name = f"{model_name}_{extension}" - program_buffer = bundled_program_buffer - else: - program_buffer = executorch_program.buffer if args.generate_etrecord: etrecord_path = "etrecord.bin" logging.info("generating etrecord.bin") generate_etrecord(etrecord_path, edge_program_manager_copy, executorch_program) - save_pte_program(program_buffer, model_name) + save_pte_program(executorch_program, model_name) diff --git a/examples/arm/ethos-u-setup/ethos-u-vela/patches/0001-Improve-rescale-codegen-for-TOSA.patch b/examples/arm/ethos-u-setup/ethos-u-vela/patches/0001-Improve-rescale-codegen-for-TOSA.patch deleted file mode 100644 index e131ca76ee..0000000000 --- a/examples/arm/ethos-u-setup/ethos-u-vela/patches/0001-Improve-rescale-codegen-for-TOSA.patch +++ /dev/null @@ -1,129 +0,0 @@ -From ef07230fbb15edbf27ecaf48994fb157430a5e7c Mon Sep 17 00:00:00 2001 -From: Rob Elliott -Date: Thu, 5 Oct 2023 16:45:42 +0000 -Subject: [PATCH] Improve rescale codegen for TOSA - -Signed-off-by: Rob Elliott ---- - ethosu/vela/tosa_graph_optimiser.py | 56 +++++++++++------------------ - ethosu/vela/tosa_mapping.py | 2 +- - 2 files changed, 22 insertions(+), 36 deletions(-) - -diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py -index df6b575..b2e3697 100644 ---- a/ethosu/vela/tosa_graph_optimiser.py -+++ b/ethosu/vela/tosa_graph_optimiser.py -@@ -337,7 +337,8 @@ def rewrite_concat(op): - - def remove_memory_ops(op, arch): - if op.run_on_npu and op.type in (Op.Reshape, Op.Identity): -- bypass_memory_only_ops(op) -+ # TODO: is this ok - function doesn't use arch or nng -+ bypass_memory_only_ops(op, arch, None) - - - def rewrite_activation(op, arch, nng): -@@ -357,7 +358,6 @@ def rewrite_activation(op, arch, nng): - - return op - -- - def rewrite_rescale(op, arch, nng): - if op.type == Op.Rescale: - ifm = op.ifm -@@ -368,7 +368,7 @@ def rewrite_rescale(op, arch, nng): - prev_op = ifm.ops[0] - - # TODO currently not supported -- assert len(ifm.consumer_list) == 1 -+ #assert len(ifm.consumer_list) == 1 - - input_zp = op.attrs["input_zp"] - output_zp = op.attrs["output_zp"] -@@ -390,6 +390,9 @@ def rewrite_rescale(op, arch, nng): - assert False - ifm.quantization.zero_point = input_zp - ofm.quantization.zero_point = output_zp -+ -+ assert False == per_channel, "Don't like per_channel!" -+ - for s, m in zip(shift, multiplier): - # TODO these are the TOSA limitations - assert m >= 0 -@@ -403,45 +406,28 @@ def rewrite_rescale(op, arch, nng): - else: - rounding_mode = RoundingMode.HalfUp - -- if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected: -+ fuse = len(ifm.ops) == 1 and prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() -+ if fuse: -+ # TODO: ERROR: bias.values didn't exist for an op like Add - presumably not a capability of that op - assert len(multiplier) == len(shift) == len(prev_op.bias.values) -- -- if ifm.dtype == DataType.int32 and per_channel: -- prev_op.explicit_scaling = explicit_scaling -- prev_op.rounding_mode = rounding_mode -- -- # Bypass op -- prev_op.set_output_tensor(ofm) -- DebugDatabase.add_optimised(op, prev_op) -- return op -- else: -- print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type) -- assert False -- # TODO which are the cases we need to and can do standalone Rescale? -- # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops? -- # origin might be TFLite op QUANTIZE, should we look to see if they can be translated to QUANTIZE? -- # limited to these at the moment: -- elif ( -- (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8) -- or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8) -- or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8) -- ): -- # Create NOP performing the RESCALE -+ # TODO: generate replacement fusion code from below -+ assert False, "Fusion possible but i've not implemented it" -+ else: -+ # Generate Rescale behaviour attached to a compatible NOP -+ # TODO: I assume this attaches a new operator into the graph?? - avgpool_op = replace_rescale_with_avg_pool(op) - avgpool_op.rounding_mode = rounding_mode -- -+ - if per_channel: -- # TODO -- avgpool_op.explicit_scaling = explicit_scaling -- print("Warning, unsupported TOSA Rescale") -- assert False -+ assert False, "Assert above removed but still not implemented... :/" - else: - avgpool_op.explicit_scaling = explicit_scaling -- else: -- print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type) -- assert False -- return op - -+ #print( len(multiplier), len(shift), len(prev_op.get_bias_tensors()) ) -+ #print( ifm.dtype, "PC:", per_channel, op.type ) -+ #print( ifm.dtype, ofm.dtype ) -+ -+ return op - - def convert_pad_in_width(op): - """ -diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py -index 2dafd81..ed5aa2e 100644 ---- a/ethosu/vela/tosa_mapping.py -+++ b/ethosu/vela/tosa_mapping.py -@@ -148,7 +148,7 @@ transpose_conv_attrs = AttrSerializer( - ) - transpose_attrs = AttrSerializer("TransposeAttribute", (("perms", is_vec),)) - axis_attrs = AttrSerializer("AxisAttribute", ("axis",)) --reshape_attrs = AttrSerializer("ReshapeAttribute", (("shape", is_vec),)) -+reshape_attrs = AttrSerializer("ReshapeAttribute", (("newShape", is_vec),)) - slice_attrs = AttrSerializer("SliceAttribute", (("start", is_vec), ("size", is_vec))) - tile_attrs = AttrSerializer("TileAttribute", (("multiplies", is_vec),)) - resize_attrs = AttrSerializer( --- -2.41.0 - diff --git a/examples/arm/ethos-u-setup/ethos-u-vela/patches/0002-Use-TOSA-0.80.1_new-12f0e94aca6c17d0c6dc9b463277ab38.patch b/examples/arm/ethos-u-setup/ethos-u-vela/patches/0002-Use-TOSA-0.80.1_new-12f0e94aca6c17d0c6dc9b463277ab38.patch deleted file mode 100644 index 5d9a560b08..0000000000 --- a/examples/arm/ethos-u-setup/ethos-u-vela/patches/0002-Use-TOSA-0.80.1_new-12f0e94aca6c17d0c6dc9b463277ab38.patch +++ /dev/null @@ -1,26 +0,0 @@ -From 394636ef2063e386f54abde094298bb0e40a2cb7 Mon Sep 17 00:00:00 2001 -From: Zingo Andersen -Date: Sat, 20 Jan 2024 10:34:45 +0100 -Subject: [PATCH 2/2] Use TOSA 0.80.1 - -Signed-off-by: Zingo Andersen ---- - ethosu/vela/tosa_reader.py | 2 +- - 1 file changed, 1 insertion(+), 1 deletion(-) - -diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py -index 56af59d..7cb2bf3 100644 ---- a/ethosu/vela/tosa_reader.py -+++ b/ethosu/vela/tosa_reader.py -@@ -294,7 +294,7 @@ class TosaGraph: - def check_version(self, tosa_graph): - version = tosa_graph.Version() - version_str = f"{version._Major()}.{version._Minor()}.{version._Patch()}" -- if version_str != "0.80.0": -+ if version_str != "0.80.1": - print(f"Unsupported TOSA version: {version_str}") - assert False - --- -2.25.1 - diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index 991772166a..d1eeb84173 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -215,7 +215,7 @@ function setup_vela() { if [[ ! -e ethos-u-vela ]]; then git clone https://review.mlplatform.org/ml/ethos-u/ethos-u-vela repo_dir="${root_dir}/ethos-u-vela" - base_rev=00a15db3e1a188b25065d095152d701f4394cdc5 + base_rev=78b9412b07e0a46e58e8ecb9da8d661399c006a5 patch_repo fi cd "${root_dir}/ethos-u-vela" diff --git a/examples/demo-apps/android/ExecuTorchDemo/README.md b/examples/demo-apps/android/ExecuTorchDemo/README.md index 802e1d9493..cffac7e612 100644 --- a/examples/demo-apps/android/ExecuTorchDemo/README.md +++ b/examples/demo-apps/android/ExecuTorchDemo/README.md @@ -116,7 +116,7 @@ cmake .. \ -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ -DANDROID_ABI=arm64-v8a \ -DBUCK2=/tmp/buck2 \ - -DEXECUTORCH_BUILD_ANDROID_DEMO_APP_JNI=ON \ + -DEXECUTORCH_BUILD_ANDROID_JNI=ON \ -DEXECUTORCH_BUILD_XNNPACK=ON \ -DEXECUTORCH_BUILD_FLATC=OFF \ -DEXECUTORCH_BUILD_QNN=ON \ diff --git a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/ClassificationActivity.java b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/ClassificationActivity.java index 8c4dd8f8de..93235720b4 100644 --- a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/ClassificationActivity.java +++ b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/ClassificationActivity.java @@ -71,7 +71,7 @@ public void run() { TensorImageUtils.TORCHVISION_NORM_STD_RGB); // running the model - final Tensor outputTensor = module.forward(EValue.from(inputTensor)).toTensor(); + final Tensor outputTensor = module.forward(EValue.from(inputTensor))[0].toTensor(); // getting tensor content as java array of floats final float[] scores = outputTensor.getDataAsFloatArray(); diff --git a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java index cdb1ac1983..25f1853e96 100644 --- a/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java +++ b/examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java @@ -189,7 +189,7 @@ public void run() { final float[] inputs = inputTensor.getDataAsFloatArray(); final long startTime = SystemClock.elapsedRealtime(); - Tensor outputTensor = mModule.forward(EValue.from(inputTensor)).toTensor(); + Tensor outputTensor = mModule.forward(EValue.from(inputTensor))[0].toTensor(); final long inferenceTime = SystemClock.elapsedRealtime() - startTime; Log.d("ImageSegmentation", "inference time (ms): " + inferenceTime); diff --git a/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts b/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts index 3b5e415e19..7297aa60ea 100644 --- a/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts +++ b/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts @@ -12,11 +12,11 @@ plugins { } android { - namespace = "com.example.executorchdemo" + namespace = "com.example.executorchllamademo" compileSdk = 34 defaultConfig { - applicationId = "com.example.executorchdemo" + applicationId = "com.example.executorchllamademo" minSdk = 24 targetSdk = 33 versionCode = 1 diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java deleted file mode 100644 index c24c367860..0000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchdemo/MainActivity.java +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package com.example.executorchllamademo; - -import android.app.Activity; -import android.app.AlertDialog; -import android.content.Context; -import android.os.Bundle; -import android.widget.Button; -import android.widget.EditText; -import android.widget.TextView; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import org.pytorch.executorch.LlamaCallback; -import org.pytorch.executorch.LlamaModule; - -public class MainActivity extends Activity implements Runnable, LlamaCallback { - private EditText mEditTextMessage; - private TextView mTextViewChat; - private Button mSendButton; - private Button mStopButton; - private Button mModelButton; - private LlamaModule mModule = null; - private String mResult = null; - - private static String assetFilePath(Context context, String assetName) throws IOException { - File file = new File(context.getFilesDir(), assetName); - if (file.exists() && file.length() > 0) { - return file.getAbsolutePath(); - } - - try (InputStream is = context.getAssets().open(assetName)) { - try (OutputStream os = new FileOutputStream(file)) { - byte[] buffer = new byte[4 * 1024]; - int read; - while ((read = is.read(buffer)) != -1) { - os.write(buffer, 0, read); - } - os.flush(); - } - return file.getAbsolutePath(); - } - } - - @Override - public void onResult(String result) { - System.out.println("onResult: " + result); - mResult = result; - run(); - } - - private void setModel(String modelPath, String tokenizerPath) { - try { - String model = MainActivity.assetFilePath(getApplicationContext(), modelPath); - String tokenizer = MainActivity.assetFilePath(getApplicationContext(), tokenizerPath); - mModule = new LlamaModule(model, tokenizer, 0.8f); - } catch (IOException e) { - finish(); - } - } - - private void setLocalModel(String modelPath, String tokenizerPath) { - mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f); - } - - private void modelDialog() { - AlertDialog.Builder builder = new AlertDialog.Builder(this); - builder.setTitle("Select a Model"); - builder.setSingleChoiceItems( - new String[] {"stories", "language"}, - -1, - new android.content.DialogInterface.OnClickListener() { - public void onClick(android.content.DialogInterface dialog, int item) { - switch (item) { - case 0: - setModel("stories110M.pte", "tokenizer.bin"); - break; - case 1: - setLocalModel("/data/local/tmp/language.pte", "/data/local/tmp/language.bin"); - break; - } - mEditTextMessage.setText(""); - mTextViewChat.setText(""); - dialog.dismiss(); - } - }); - AlertDialog alert = builder.create(); - alert.show(); - } - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - setContentView(R.layout.activity_main); - - mEditTextMessage = findViewById(R.id.editTextMessage); - mTextViewChat = findViewById(R.id.textViewChat); - mSendButton = findViewById(R.id.sendButton); - mStopButton = findViewById(R.id.stopButton); - mModelButton = findViewById(R.id.modelButton); - - mSendButton.setOnClickListener( - view -> { - String prompt = mEditTextMessage.getText().toString(); - mTextViewChat.append(prompt); - mEditTextMessage.setText(""); - Runnable runnable = - new Runnable() { - @Override - public void run() { - mModule.generate(prompt, MainActivity.this); - } - }; - new Thread(runnable).start(); - }); - - mStopButton.setOnClickListener( - view -> { - mModule.stop(); - }); - - mModelButton.setOnClickListener( - view -> { - mModule.stop(); - modelDialog(); - }); - - setModel("stories110M.pte", "tokenizer.bin"); - } - - @Override - public void run() { - runOnUiThread( - new Runnable() { - @Override - public void run() { - mTextViewChat.append(mResult); - } - }); - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java new file mode 100644 index 0000000000..dd7cbfe50f --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -0,0 +1,215 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +import android.app.Activity; +import android.app.ActivityManager; +import android.app.AlertDialog; +import android.content.Context; +import android.os.Bundle; +import android.widget.Button; +import android.widget.EditText; +import android.widget.ImageButton; +import android.widget.ListView; +import java.io.File; +import org.pytorch.executorch.LlamaCallback; +import org.pytorch.executorch.LlamaModule; + +public class MainActivity extends Activity implements Runnable, LlamaCallback { + private EditText mEditTextMessage; + private Button mSendButton; + private ImageButton mModelButton; + private ListView mMessagesView; + private MessageAdapter mMessageAdapter; + private LlamaModule mModule = null; + private Message mResultMessage = null; + + private int mNumTokens = 0; + private long mRunStartTime = 0; + private String mModelFilePath = ""; + private String mTokenizerFilePath = ""; + + @Override + public void onResult(String result) { + System.out.println("onResult: " + result); + mResultMessage.appendText(result); + mNumTokens++; + run(); + } + + private static String[] listLocalFile(String path, String suffix) { + File directory = new File(path); + if (directory.exists() && directory.isDirectory()) { + File[] files = directory.listFiles((dir, name) -> name.toLowerCase().endsWith(suffix)); + String[] result = new String[files.length]; + for (int i = 0; i < files.length; i++) { + if (files[i].isFile() && files[i].getName().endsWith(suffix)) { + result[i] = files[i].getAbsolutePath(); + } + } + return result; + } + return null; + } + + private void setLocalModel(String modelPath, String tokenizerPath) { + long runStartTime = System.currentTimeMillis(); + mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f); + int loadResult = mModule.load(); + if (loadResult != 0) { + AlertDialog.Builder builder = new AlertDialog.Builder(this); + builder.setTitle("Load failed: " + loadResult); + AlertDialog alert = builder.create(); + alert.show(); + } + + long runDuration = System.currentTimeMillis() - runStartTime; + String modelInfo = + "Model path: " + + modelPath + + "\nTokenizer path: " + + tokenizerPath + + "\nModel loaded time: " + + runDuration + + " ms"; + Message modelLoadedMessage = new Message(modelInfo, false); + mMessageAdapter.add(modelLoadedMessage); + mMessageAdapter.notifyDataSetChanged(); + } + + private String memoryInfo() { + final ActivityManager am = (ActivityManager) getSystemService(Context.ACTIVITY_SERVICE); + ActivityManager.MemoryInfo memInfo = new ActivityManager.MemoryInfo(); + am.getMemoryInfo(memInfo); + return "Total RAM: " + + Math.floorDiv(memInfo.totalMem, 1000000) + + " MB. Available RAM: " + + Math.floorDiv(memInfo.availMem, 1000000) + + " MB."; + } + + private void modelDialog() { + String[] pteFiles = listLocalFile("/data/local/tmp/llama/", ".pte"); + String[] binFiles = listLocalFile("/data/local/tmp/llama/", ".bin"); + AlertDialog.Builder modelPathBuilder = new AlertDialog.Builder(this); + modelPathBuilder.setTitle("Select model path"); + AlertDialog.Builder tokenizerPathBuilder = new AlertDialog.Builder(this); + tokenizerPathBuilder.setTitle("Select tokenizer path"); + modelPathBuilder.setSingleChoiceItems( + binFiles, + -1, + (dialog, item) -> { + mTokenizerFilePath = binFiles[item]; + mEditTextMessage.setText(""); + dialog.dismiss(); + tokenizerPathBuilder.create().show(); + }); + + tokenizerPathBuilder.setSingleChoiceItems( + pteFiles, + -1, + (dialog, item) -> { + mModelFilePath = pteFiles[item]; + setLocalModel(mModelFilePath, mTokenizerFilePath); + dialog.dismiss(); + }); + + modelPathBuilder.create().show(); + } + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + mEditTextMessage = findViewById(R.id.editTextMessage); + mSendButton = findViewById(R.id.sendButton); + mModelButton = findViewById(R.id.modelButton); + mMessagesView = findViewById(R.id.messages_view); + mMessageAdapter = new MessageAdapter(this, R.layout.sent_message); + mMessagesView.setAdapter(mMessageAdapter); + mModelButton.setOnClickListener( + view -> { + mModule.stop(); + mMessageAdapter.clear(); + mMessageAdapter.notifyDataSetChanged(); + modelDialog(); + }); + + setLocalModel("/data/local/tmp/llama/stories110M.pte", "/data/local/tmp/llama/tokenizer.bin"); + onModelRunStopped(); + } + + private void onModelRunStarted() { + mSendButton.setText("Stop"); + mSendButton.setOnClickListener( + view -> { + mModule.stop(); + }); + + mRunStartTime = System.currentTimeMillis(); + } + + private void onModelRunStopped() { + setTitle(memoryInfo()); + long runDuration = System.currentTimeMillis() - mRunStartTime; + if (mResultMessage != null) { + mResultMessage.setTokensPerSecond(1.0f * mNumTokens / (runDuration / 1000.0f)); + } + mSendButton.setText("Generate"); + mSendButton.setOnClickListener( + view -> { + String prompt = mEditTextMessage.getText().toString(); + mMessageAdapter.add(new Message(prompt, true)); + mMessageAdapter.notifyDataSetChanged(); + mEditTextMessage.setText(""); + mResultMessage = new Message("", false); + mMessageAdapter.add(mResultMessage); + Runnable runnable = + new Runnable() { + @Override + public void run() { + runOnUiThread( + new Runnable() { + @Override + public void run() { + onModelRunStarted(); + } + }); + + mModule.generate(prompt, MainActivity.this); + + runOnUiThread( + new Runnable() { + @Override + public void run() { + onModelRunStopped(); + } + }); + } + }; + new Thread(runnable).start(); + }); + mNumTokens = 0; + mRunStartTime = 0; + mMessageAdapter.notifyDataSetChanged(); + } + + @Override + public void run() { + runOnUiThread( + new Runnable() { + @Override + public void run() { + mMessageAdapter.notifyDataSetChanged(); + setTitle(memoryInfo()); + } + }); + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/Message.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/Message.java new file mode 100644 index 0000000000..81b77b1aba --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/Message.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +public class Message { + private String text; + private boolean isSent; + private float tokensPerSecond; + + public Message(String text, boolean isSent) { + this.text = text; + this.isSent = isSent; + } + + public String getText() { + return text; + } + + public void appendText(String text) { + this.text += text; + } + + public boolean getIsSent() { + return isSent; + } + + public void setTokensPerSecond(float tokensPerSecond) { + this.tokensPerSecond = tokensPerSecond; + } + + public float getTokensPerSecond() { + return tokensPerSecond; + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java new file mode 100644 index 0000000000..656da1967d --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +import android.view.LayoutInflater; +import android.view.View; +import android.view.ViewGroup; +import android.widget.ArrayAdapter; +import android.widget.TextView; + +public class MessageAdapter extends ArrayAdapter { + public MessageAdapter(android.content.Context context, int resource) { + super(context, resource); + } + + @Override + public View getView(int position, View convertView, ViewGroup parent) { + Message currentMessage = getItem(position); + + int layoutIdForListItem = + currentMessage.getIsSent() ? R.layout.sent_message : R.layout.received_message; + View listItemView = + LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false); + TextView messageTextView = listItemView.findViewById(R.id.message_text); + messageTextView.setText(currentMessage.getText()); + + if (currentMessage.getTokensPerSecond() > 0) { + TextView tokensView = listItemView.findViewById(R.id.tokens_per_second); + tokensView.setText("" + currentMessage.getTokensPerSecond() + " t/s"); + } + + return listItemView; + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/received_message.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/received_message.xml new file mode 100644 index 0000000000..ea2d1bbfa1 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/received_message.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/sent_message.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/sent_message.xml new file mode 100644 index 0000000000..e8d13ca4e1 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/sent_message.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/three_dots.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/three_dots.xml new file mode 100644 index 0000000000..afbe22da80 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/three_dots.xml @@ -0,0 +1,5 @@ + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml index f769578d33..089acb572b 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml @@ -1,44 +1,44 @@ - - - + -