diff --git a/.bazelrc b/.bazelrc index e9fc2d4eb20a55..9de6b6e0c2bd54 100644 --- a/.bazelrc +++ b/.bazelrc @@ -526,34 +526,9 @@ build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.16-clang_c build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_config_nccl" test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda_nvcc --config=cuda +build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 -build:rbe_linux_cuda_nvcc --@local_xla//xla/python:enable_gpu=true -build:rbe_linux_cuda_nvcc --@local_xla//xla/python:jax_cuda_pip_rpaths=true -build:rbe_linux_cuda_nvcc --define=xla_python_enable_gpu=true -build:rbe_linux_cuda_nvcc --config=tensorrt -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_75,compute_80" -build:rbe_linux_cuda_nvcc --action_env=TF_CUDA_VERSION="12" -build:rbe_linux_cuda_nvcc --action_env=TF_CUDNN_VERSION="8" -build:rbe_linux_cuda_nvcc --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2" -build:rbe_linux_cuda_nvcc --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" -build:rbe_linux_cuda_nvcc --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc --config=rbe_linux -build:rbe_linux_cuda_nvcc --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_python3.9" -build:rbe_linux_cuda_nvcc --python_path="/usr/bin/python3" -# These you may need to change for your own GCP project. 
-common:rbe_linux_cuda_nvcc --remote_instance_name=projects/tensorflow-testing/instances/default_instance -build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_cuda" -build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_tensorrt" -build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_nccl" -test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed build:rbe_win --config=rbe_base @@ -692,19 +667,39 @@ build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda build:release_cpu_macos --config=avx_linux test:release_cpu_macos --config=release_base -# Build configs for macOS ARM CPUs +# Base build configs for macOS +build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer +build:release_macos_base --define=no_nccl_support=true --output_filter=^$ + +# Build configs for macOS x86 +build:release_macos_x86 --config=release_macos_base +# Build with the AVX instruction set when on macOS x86 +build:release_macos_x86 --config=avx_linux +build:release_macos_x86 --cpu=darwin +# Target Catalina as the minimum compatible OS version +build:release_macos_x86 --macos_minimum_os=10.15 +build:release_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 + +# Build configs for macOS Arm64 +build:release_macos_arm64 --config=release_macos_base build:release_macos_arm64 --cpu=darwin_arm64 -# Set DEVELOPER_DIR to select a version of Xcode. -build:release_macos_arm64 --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer -build:release_macos_arm64 --define=no_nccl_support=true -# Suppress all warning messages -build:release_macos_arm64 --output_filter=^$ -# Disable MKL build:release_macos_arm64 --define=tensorflow_mkldnn_contraction_kernel=0 # Target Moneterey as the minimum compatible OS version build:release_macos_arm64 --macos_minimum_os=12.0 build:release_macos_arm64 --action_env MACOSX_DEPLOYMENT_TARGET=12.0 +# Base test configs for macOS +test:release_macos_base --verbose_failures=true --local_test_jobs=HOST_CPUS +test:release_macos_base --test_timeout=300,450,1200,3600 --test_output=errors +test:release_macos_base --build_tests_only --keep_going +test:release_macos_base --flaky_test_attempts=3 + +# Test configs for macOS x86 +test:release_macos_x86 --config=release_macos_base + +# Test configs for macOS Arm64 +test:release_macos_arm64 --config=release_macos_base + # TODO(kanglan): Update windows configs after b/289091160 is fixed build:release_cpu_windows --config=avx_win build:release_cpu_windows --define=no_tensorflow_py_deps=true @@ -723,10 +718,14 @@ build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compil # Use --config=tf_public_cache to try and use the TensorFlow public build cache # to build TensorFlow. Look at ci/official/envs to find which types of jobs -# push to the cache. +# push to the cache. For macOS, use --config=tf_public_macos_cache build:tf_public_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false # Cache pushes are limited to TF's CI system. 
build:tf_public_cache_push --config=tf_public_cache --remote_upload_local_results=true --google_default_credentials +# Public cache for macOS builds +build:tf_public_macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false +# Cache pushes are limited to TF's CI system. +build:tf_public_macos_cache_push --config=tf_public_macos_cache --remote_upload_local_results=true --google_default_credentials # END TF CACHE HELPER OPTIONS # BEGIN TF TEST SUITE OPTIONS @@ -743,22 +742,27 @@ build:linux_libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow. test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL -test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... 
-//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_wheel_test_filters --test_lang_filters=py -test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... -//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test +test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +# MACOS X86 WHEEL +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. 
These are usually run continuously or upon presubmit. @@ -766,21 +770,53 @@ test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorf test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA PYCPP: test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 PYCPP test:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... 
-//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test +# CROSS-COMPILE ARM64 PYCPP +test:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test +# Tests that fail only when cross-compiled +test:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py -test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... -//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test +test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium +test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test # END TF TEST SUITE OPTIONS + +# START LINUX AARCH64 CROSS-COMPILE CONFIGS +# Set execution platform to Linux x86 +# Note: Lot of the "host_" flags such as "host_cpu" and "host_crosstool_top" +# flags seem to be actually used to specify the execution platform details. It +# seems it is this way because these flags are old and predate the distinction +# between host and execution platform. 
+build:cross_compile_linux_arm64 --host_cpu=k8 +build:cross_compile_linux_arm64 --host_crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_linux_arm64 --extra_execution_platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_x86_64 + +# Set the target CPU to Aarch64 +build:cross_compile_linux_arm64 --platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_aarch64 +build:cross_compile_linux_arm64 --cpu=aarch64 +build:cross_compile_linux_arm64 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite + +# RBE configs +build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 +build:rbe_cross_compile_linux_arm64 --config=rbe_base +build:rbe_cross_compile_linux_arm64 --remote_instance_name=projects/tensorflow-testing/instances/default_instance + +# Test-related settings below this point +# We cannot run cross-compiled tests on the remote Linux x86 VMs so we need to +# force all tests to run locally on the Aarch64 host. +test:rbe_cross_compile_linux_arm64 --strategy=TestRunner=local +test:rbe_cross_compile_linux_arm64 --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors +test:rbe_cross_compile_linux_arm64 --flaky_test_attempts=3 --build_tests_only +# END LINUX AARCH64 CROSS-COMPILE CONFIGS diff --git a/.github/workflows/arm-ci.yml b/.github/workflows/arm-ci.yml index 96467ebaeb35a9..3b07683008391d 100644 --- a/.github/workflows/arm-ci.yml +++ b/.github/workflows/arm-ci.yml @@ -20,12 +20,6 @@ on: branches: - master - r2.** - pull_request: - types: [labeled, opened, synchronize, reopened] - branches: - - master - - r2.** - permissions: contents: read diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml index bb39d60168e08d..fb7366768436c5 100644 --- a/.github/workflows/osv-scanner-scheduled.yml +++ b/.github/workflows/osv-scanner-scheduled.yml @@ -27,6 +27,7 @@ permissions: jobs: scan-scheduled: + if: github.repository == 'tensorflow/tensorflow' uses: "google/osv-scanner/.github/workflows/osv-scanner-reusable.yml@main" with: scan-args: |- @@ -36,4 +37,4 @@ jobs: --lockfile=requirements.txt:./requirements_lock_3_12.txt --lockfile=requirements.txt:./ci/official/containers/linux_arm64/devel.requirements.txt --lockfile=requirements.txt:./ci/official/containers/linux_arm64/jax.requirements.txt - --lockfile=requirements.txt:./ci/official/containers/linux_arm64/devel.usertools/test.requirements.txt \ No newline at end of file + --lockfile=requirements.txt:./ci/official/containers/linux_arm64/devel.usertools/test.requirements.txt diff --git a/.github/workflows/stale-issues.yml b/.github/workflows/stale-issues.yml index 84118acca683fd..e439c0f180ed44 100644 --- a/.github/workflows/stale-issues.yml +++ b/.github/workflows/stale-issues.yml @@ -31,7 +31,7 @@ jobs: pull-requests: write steps: - name: Awaiting response issues - uses: actions/stale@v7 + uses: actions/stale@6f05e4244c9a0b2ed3401882b05d701dd0a7289b # v7.0.0 with: #Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale exempt-issue-labels: 'override-stale' @@ -59,7 +59,7 @@ jobs: close-pr-message: "This PR was closed because it has been inactive for 14 days since being marked as stale. Please reopen if you'd like to work on this further." 
repo-token: ${{ secrets.GITHUB_TOKEN }} - name: Contribution issues - uses: actions/stale@v7 + uses: actions/stale@6f05e4244c9a0b2ed3401882b05d701dd0a7289b # v7.0.0 with: #Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale exempt-issue-labels: 'override-stale' diff --git a/RELEASE.md b/RELEASE.md index 75350aeccc5542..6ee5c0ca55fa5f 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -9,6 +9,12 @@ * * +* `tf.summary.trace_on` now takes a `profiler_outdir` argument. This must be set + if `profiler` arg is set to `True`. + * `tf.summary.trace_export`'s `profiler_outdir` arg is now a no-op. Enabling + the profiler now requires setting `profiler_outdir` in `trace_on`. + + ### Known Caveats * @@ -32,6 +38,17 @@ * Added support for `stablehlo.multiply`. * Added support for `stablehlo.maximum`. * Added support for `stablehlo.minimum`. + * Added boolean parameter support for `tfl.gather_nd`. + +* `tf.CheckpointOptions` + * It now takes in a new argument called `experimental_sharding_callback`. + This is a callback function wrapper that will be executed to determine how + tensors will be split into shards when the saver writes the checkpoint + shards to disk. `tf.train.experimental.ShardByTaskPolicy` is the default + sharding behavior, but `tf.train.experimental.MaxShardSizePolicy` can be + used to shard the checkpoint with a maximum shard file size. Users with + advanced use cases can also write their own custom + `tf.train.experimental.ShardingCallback`s. ## Keras @@ -48,6 +65,9 @@ table maintained by the layer. If this layer is not used in conjunction with `UpdateEmbeddingCallback` the behavior of the layer would be same as `keras.layers.Embedding`. +* `keras.optimizers.Adam` + * Added the option to set adaptive epsilon to match implementations with Jax + and PyTorch equivalents. ### Breaking Changes @@ -77,6 +97,39 @@ This release contains contributions from many people at Google, as well as: , , , , , +# Release 2.15.0.post1 + +## TensorFlow + +### Bug Fixes and Other Changes + +* Hot-fix was needed for an issue affecting the TensorFlow installation + process. + * TensorFlow 2.15.0 Python package was requesting `tensorrt`-related + packages that cannot be found unless the user installs them beforehand + or provides additional installation flags. + * This dependency affected anyone installing TensorFlow 2.15 alongside + NVIDIA CUDA dependencies via `pip install tensorflow[and-cuda]`. + * Depending on the installation method, TensorFlow 2.14 would be installed + instead of 2.15, or users could receive an installation error due to + those missing dependencies. +* TensorFlow 2.15.0.post1 is being released for Linux x86_64 to resolve this + issue as quickly as possible. + * This version removes the `tensorrt` Python package dependencies from the + tensorflow[and-cuda] installation method to ensure `pip install + tensorflow[and-cuda]` works as originally intended for TensorFlow 2.15. + * Support for TensorRT is otherwise unaffected as long as TensorRT is + already installed on the system. +* Using .post1 instead of a full minor release allowed us to push this release + out quickly. However, please note the following caveat: + * For users wishing to pin their Python dependency in a requirements file + or other situation, under Python's version specification rules, + `tensorflow[and-cuda]==2.15.0` will not install this fixed version. 
Please use `==2.15.0.post1` to specify this exact version on Linux + platforms, or a fuzzy version specification, such as `==2.15.*`, to + specify the most recent compatible version of TensorFlow 2.15 on all + platforms. + # Release 2.15.0 ## TensorFlow @@ -164,29 +217,26 @@ This release contains contributions from many people at Google, as well as: * Provided a new `experimental_skip_saver` argument which, if specified, will suppress the addition of `SavedModel`-native save and restore ops to the `SavedModel`, for cases where users already build custom save/restore ops and checkpoint formats for the model being saved, and the creation of the SavedModel-native save/restore ops simply cause longer model serialization times. -* `tf.math.bincount` - * Updated documentation. Fixed "[Bincount doesn't check the tensor type](https://github.com/tensorflow/tensorflow/issues/56499)" and some other corner cases. - -## Keras - -### Breaking Changes - -### Known Caveats - -### Major Features and Improvements - -### Bug Fixes and Other Changes - * Add ops to `tensorflow.raw_ops` that were missing. + * `tf.CheckpointOptions` * It now takes in a new argument called `experimental_write_callbacks`. These are callbacks that will be executed after a saving event finishes writing the checkpoint file. + * Add an option `disable_eager_executer_streaming_enqueue` to `tensorflow.ConfigProto.Experimental` to control the eager runtime's behavior around parallel remote function invocations; when set to `True`, the eager runtime will be allowed to execute multiple function invocations in parallel. + * `tf.constant_initializer` - * It now takes a new argument called `support_partition`. If True, constant_initializers can create sharded variables. This is disabled by default similar to existing behavior. + * It now takes a new argument called `support_partition`. If True, constant_initializers can create sharded variables. This is disabled by default, similar to existing behavior. * `tf.lite` * Added support for `stablehlo.scatter`. +* `tf.estimator` + * The tf.estimator API removal is in progress and will be targeted for the 2.16 release. + +## Keras + +* This will be the final release before the launch of Keras 3.0, when Keras will become multi-backend. For the compatibility page and other info, please see: https://github.com/keras-team/keras-core + ## Thanks to our Contributors This release contains contributions from many people at Google, as well as: diff --git a/ci/official/any.sh b/ci/official/any.sh index 74e4caa666d259..8eae7cd85f3445 100755 --- a/ci/official/any.sh +++ b/ci/official/any.sh @@ -34,8 +34,10 @@ if [[ -n "${TF_ANY_SCRIPT:-}" ]]; then echo "source ci/official/envs/disable_all_uploads" >> any export TFCI=$(realpath any) "$TF_ANY_SCRIPT" -else +elif [[ -n "${TF_ANY_TARGETS:-}" ]]; then source "${BASH_SOURCE%/*}/utilities/setup.sh" - read -ra TARGETS_AS_ARRAY <<<"$TF_ANY_TARGETS" - tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" "${TF_ANY_MODE:-test}" "${TFCI_BAZEL_COMMON_ARGS[@]}" "${TARGETS_AS_ARRAY[@]}" + tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS "${TF_ANY_MODE:-test}" $TFCI_BAZEL_COMMON_ARGS $TF_ANY_TARGETS +else + echo 'Looks like $TF_ANY_TARGETS and $TF_ANY_SCRIPT are both empty. That is an error.'
+ exit 1 fi diff --git a/ci/official/containers/linux_arm64/devel.usertools/aarch64.bazelrc b/ci/official/containers/linux_arm64/devel.usertools/aarch64.bazelrc index f41974b5b6ab7d..f2a08d60720f9a 100644 --- a/ci/official/containers/linux_arm64/devel.usertools/aarch64.bazelrc +++ b/ci/official/containers/linux_arm64/devel.usertools/aarch64.bazelrc @@ -49,7 +49,7 @@ test --test_summary=short test:nonpip_filters --test_tag_filters=-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:nonpip_filters --build_tag_filters=-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:nonpip_filters --test_lang_filters=py --flaky_test_attempts=3 --test_size_filters=small,medium -test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # "pip tests" run a similar suite of tests the "nonpip" tests, but do something # odd to attempt to validate the quality of the pip package. The wheel is @@ -70,10 +70,10 @@ test:pip_venv --python_path="/bazel_pip/bin/python3" test:pip_venv --define=no_tensorflow_py_deps=true test:pip --config=pip_venv # Yes, we don't exclude the gpu tests on pip for some reason. -test:pip_filters --test_tag_filters=-nopip,-no_pip,-no_oss,-oss_serial,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:pip_filters --build_tag_filters=-nopip,-no_pip,-no_oss,-oss_serial,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:pip_filters --test_tag_filters=-nopip,-no_pip,-no_oss,-oss_serial,-benchmark-test,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:pip_filters --build_tag_filters=-nopip,-no_pip,-no_oss,-oss_serial,-benchmark-test,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:pip_filters --test_lang_filters=py --flaky_test_attempts=3 --test_size_filters=small,medium -test:pip --config=pip_filters -- //bazel_pip/tensorflow/... -//bazel_pip/tensorflow/python/integration_testing/... -//bazel_pip/tensorflow/compiler/tf2tensorrt/... -//bazel_pip/tensorflow/compiler/xrt/... -//bazel_pip/tensorflow/core/tpu/... -//bazel_pip/tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:pip --config=pip_filters -- //bazel_pip/tensorflow/... -//bazel_pip/tensorflow/python/integration_testing/... -//bazel_pip/tensorflow/compiler/tf2tensorrt/... -//bazel_pip/tensorflow/core/tpu/... -//bazel_pip/tensorflow/lite/... -//tensorflow/tools/toolchains/... 
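For context, the "nonpip" and "pip" suite configs defined above in this aarch64 bazelrc are selected with --config. The lines below are an illustrative sketch only, not part of the patch; they assume a Linux aarch64 machine with Bazel on PATH and a TF_PYTHON_VERSION matching an installed interpreter, and the real drivers under ci/official may pass additional flags.
# Illustrative sketch, not part of the patch: load this file explicitly as the
# bazelrc and run the "nonpip" test suite it defines.
bazel --bazelrc=ci/official/containers/linux_arm64/devel.usertools/aarch64.bazelrc \
  test --config=nonpip --repo_env=TF_PYTHON_VERSION=3.11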
# For building libtensorflow archives test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test @@ -83,4 +83,4 @@ build:libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow.tar.gz test:pycpp_filters --test_tag_filters=-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:pycpp_filters --build_tag_filters=-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:pycpp_filters --test_lang_filters=cc,py --flaky_test_attempts=3 --test_size_filters=small,medium -test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/ci/official/containers/linux_arm64/devel.usertools/aarch64_clang.bazelrc b/ci/official/containers/linux_arm64/devel.usertools/aarch64_clang.bazelrc index 50b3851db88ea0..0cb20a89b4bd7f 100644 --- a/ci/official/containers/linux_arm64/devel.usertools/aarch64_clang.bazelrc +++ b/ci/official/containers/linux_arm64/devel.usertools/aarch64_clang.bazelrc @@ -60,7 +60,7 @@ test --test_summary=short test:nonpip_filters --test_tag_filters=-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:nonpip_filters --build_tag_filters=-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:nonpip_filters --test_lang_filters=py --flaky_test_attempts=3 --test_size_filters=small,medium -test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # "pip tests" run a similar suite of tests the "nonpip" tests, but do something # odd to attempt to validate the quality of the pip package. The wheel is @@ -81,10 +81,10 @@ test:pip_venv --python_path="/bazel_pip/bin/python3" test:pip_venv --define=no_tensorflow_py_deps=true test:pip --config=pip_venv # Yes, we don't exclude the gpu tests on pip for some reason. -test:pip_filters --test_tag_filters=-nopip,-no_pip,-no_oss,-oss_serial,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:pip_filters --build_tag_filters=-nopip,-no_pip,-no_oss,-oss_serial,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:pip_filters --test_tag_filters=-nopip,-no_pip,-no_oss,-oss_serial,-benchmark-test,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:pip_filters --build_tag_filters=-nopip,-no_pip,-no_oss,-oss_serial,-benchmark-test,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:pip_filters --test_lang_filters=py --flaky_test_attempts=3 --test_size_filters=small,medium -test:pip --config=pip_filters -- //bazel_pip/tensorflow/... 
-//bazel_pip/tensorflow/python/integration_testing/... -//bazel_pip/tensorflow/compiler/tf2tensorrt/... -//bazel_pip/tensorflow/compiler/xrt/... -//bazel_pip/tensorflow/core/tpu/... -//bazel_pip/tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:pip --config=pip_filters -- //bazel_pip/tensorflow/... -//bazel_pip/tensorflow/python/integration_testing/... -//bazel_pip/tensorflow/compiler/tf2tensorrt/... -//bazel_pip/tensorflow/core/tpu/... -//bazel_pip/tensorflow/lite/... -//tensorflow/tools/toolchains/... # For building libtensorflow archives test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test @@ -94,4 +94,4 @@ build:libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow.tar.gz test:pycpp_filters --test_tag_filters=-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:pycpp_filters --build_tag_filters=-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_aarch64,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:pycpp_filters --test_lang_filters=cc,py --flaky_test_attempts=3 --test_size_filters=small,medium -test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/ci/official/envs/ci_default b/ci/official/envs/ci_default index eb7938c8b3449d..a8e212ff47f005 100644 --- a/ci/official/envs/ci_default +++ b/ci/official/envs/ci_default @@ -1,9 +1,9 @@ -TFCI_BAZEL_BAZELRC_ARGS=() +# Note: this gets sourced in utilities/setup.sh +TFCI_BAZEL_BAZELRC_ARGS= +TFCI_BAZEL_COMMON_ARGS= TFCI_BAZEL_CONFIG_PREFIX= -TFCI_BAZEL_COMMON_ARGS=() -TFCI_PYTHON_VERSION= -TFCI_BUILD_PIP_PACKAGE_ARGS=() -TFCI_DOCKER_ARGS=() +TFCI_BUILD_PIP_PACKAGE_ARGS= +TFCI_DOCKER_ARGS= TFCI_DOCKER_ENABLE=1 TFCI_DOCKER_IMAGE= TFCI_DOCKER_PULL_ENABLE=1 @@ -15,15 +15,24 @@ TFCI_LIB_SUFFIX= TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= TFCI_NVIDIA_SMI_ENABLE= TFCI_OUTPUT_DIR=build_output -TFCI_LIBTPU_DOWNLOAD_ENABLE=0 -TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=0 -TFCI_LIBTPU_DOWNLOAD_URL= +TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS= +TFCI_PYTHON_VERSION= TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI= TFCI_UPLOAD_LIB_URI= TFCI_UPLOAD_WHL_GCS_ENABLE= TFCI_UPLOAD_WHL_GCS_URI= -TFCI_UPLOAD_WHL_PYPI_ARGS=() +TFCI_UPLOAD_WHL_PYPI_ARGS= TFCI_UPLOAD_WHL_PYPI_ENABLE= +TFCI_WHL_AUDIT_ENABLE=1 +TFCI_WHL_AUDIT_PLAT= TFCI_WHL_BAZEL_TEST_ENABLE=1 +TFCI_WHL_SIZE_LIMIT= +TFCI_WHL_SIZE_LIMIT_ENABLE=1 +TFCI_MACOS_UPGRADE_PYENV_ENABLE= +TFCI_MACOS_INSTALL_BAZELISK_ENABLE= +TFCI_MACOS_INSTALL_BAZELISK_URL= +TFCI_MACOS_PYENV_INSTALL_ENABLE= +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE= +TFCI_MACOS_BAZEL_TEST_DIR_PATH= diff --git a/ci/official/envs/ci_nightly_uploads b/ci/official/envs/ci_nightly_uploads index ca6671f5ea3c59..7f62baf903c7e6 100644 --- a/ci/official/envs/ci_nightly_uploads +++ b/ci/official/envs/ci_nightly_uploads @@ -4,5 +4,5 @@ TFCI_UPLOAD_LIB_LATEST_ENABLE=1 TFCI_UPLOAD_LIB_LATEST_GCS_URI="gs://libtensorflow-nightly/latest" TFCI_UPLOAD_WHL_GCS_ENABLE=0 TFCI_UPLOAD_WHL_GCS_URI= 
-TFCI_UPLOAD_WHL_PYPI_ARGS=(--config-file="$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository pypi-warehouse) +TFCI_UPLOAD_WHL_PYPI_ARGS="--config-file=$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token --repository pypi-warehouse" TFCI_UPLOAD_WHL_PYPI_ENABLE=1 diff --git a/ci/official/envs/continuous_linux_arm64_cpu_py310 b/ci/official/envs/continuous_linux_arm64_cpu_py310 index b8d7e5c3228356..5f8d16be1aaa6a 100644 --- a/ci/official/envs/continuous_linux_arm64_cpu_py310 +++ b/ci/official/envs/continuous_linux_arm64_cpu_py310 @@ -1,7 +1,6 @@ # This envrionment is experimental and should not yet be used for production jobs -source ci/official/envs/ci_default -TFCI_PYTHON_VERSION=3.10 +TFCI_BAZEL_COMMON_ARGS="--config release_arm64_linux --config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64 -TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python -TFCI_DOCKER_REBUILD_ARGS=(--target=tf ci/official/containers/linux_arm64) +TFCI_DOCKER_REBUILD_ARGS="--target=tf ci/official/containers/linux_arm64" +TFCI_PYTHON_VERSION=3.10 diff --git a/ci/official/envs/continuous_linux_arm64_cpu_py311 b/ci/official/envs/continuous_linux_arm64_cpu_py311 index 7a0ae9e84a1134..410fecc1d7be39 100644 --- a/ci/official/envs/continuous_linux_arm64_cpu_py311 +++ b/ci/official/envs/continuous_linux_arm64_cpu_py311 @@ -1,7 +1,6 @@ # This envrionment is experimental and should not yet be used for production jobs -source ci/official/envs/ci_default -TFCI_PYTHON_VERSION=3.11 +TFCI_BAZEL_COMMON_ARGS="--config release_arm64_linux --config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64 -TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python -TFCI_DOCKER_REBUILD_ARGS=(--target=tf ci/official/containers/linux_arm64) +TFCI_DOCKER_REBUILD_ARGS="--target=tf ci/official/containers/linux_arm64" +TFCI_PYTHON_VERSION=3.11 diff --git a/ci/official/envs/continuous_linux_arm64_cpu_py311_cross_compile b/ci/official/envs/continuous_linux_arm64_cpu_py311_cross_compile new file mode 100644 index 00000000000000..d506aca9441b98 --- /dev/null +++ b/ci/official/envs/continuous_linux_arm64_cpu_py311_cross_compile @@ -0,0 +1,6 @@ +# This envrionment is experimental and should not yet be used for production jobs +TFCI_BAZEL_COMMON_ARGS="--config rbe_cross_compile_linux_arm64 --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=cross_compile_linux_arm64 +TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python +TFCI_DOCKER_REBUILD_ARGS="--target=tf ci/official/containers/linux_arm64" +TFCI_PYTHON_VERSION=3.11 diff --git a/ci/official/envs/continuous_linux_arm64_cpu_py39 b/ci/official/envs/continuous_linux_arm64_cpu_py39 index 53aee870f9c66a..7b98c0b838d000 100644 --- a/ci/official/envs/continuous_linux_arm64_cpu_py39 +++ b/ci/official/envs/continuous_linux_arm64_cpu_py39 @@ -1,7 +1,6 @@ # This envrionment is experimental and should not yet be used for production jobs -source ci/official/envs/ci_default -TFCI_PYTHON_VERSION=3.9 +TFCI_BAZEL_COMMON_ARGS="--config release_arm64_linux 
--config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64 -TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python -TFCI_DOCKER_REBUILD_ARGS=(--target=tf ci/official/containers/linux_arm64) +TFCI_DOCKER_REBUILD_ARGS="--target=tf ci/official/containers/linux_arm64" +TFCI_PYTHON_VERSION=3.9 diff --git a/ci/official/envs/continuous_linux_arm64_cpu_py39_cross_compile b/ci/official/envs/continuous_linux_arm64_cpu_py39_cross_compile new file mode 100644 index 00000000000000..23870d6c181bd3 --- /dev/null +++ b/ci/official/envs/continuous_linux_arm64_cpu_py39_cross_compile @@ -0,0 +1,6 @@ +# This envrionment is experimental and should not yet be used for production jobs +TFCI_BAZEL_COMMON_ARGS="--config rbe_cross_compile_linux_arm64 --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=cross_compile_linux_arm64 +TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python +TFCI_DOCKER_REBUILD_ARGS="--target=tf ci/official/containers/linux_arm64" +TFCI_PYTHON_VERSION=3.9 diff --git a/ci/official/envs/continuous_linux_x86_cpu_py310 b/ci/official/envs/continuous_linux_x86_cpu_py310 index 13b2730a609d4b..5297dd60604781 100644 --- a/ci/official/envs/continuous_linux_x86_cpu_py310 +++ b/ci/official/envs/continuous_linux_x86_cpu_py310 @@ -1,6 +1,5 @@ -source ci/official/envs/ci_default -TFCI_PYTHON_VERSION=3.10 +TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu -TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) -TFCI_DOCKER_IMAGE=tensorflow/build:latest-pythonlatest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" +TFCI_PYTHON_VERSION=3.10 diff --git a/ci/official/envs/continuous_linux_x86_cpu_py311 b/ci/official/envs/continuous_linux_x86_cpu_py311 index 3f92c5c2513257..4a306e19f97258 100644 --- a/ci/official/envs/continuous_linux_x86_cpu_py311 +++ b/ci/official/envs/continuous_linux_x86_cpu_py311 @@ -1,6 +1,5 @@ -source ci/official/envs/ci_default -TFCI_PYTHON_VERSION=3.11 +TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu -TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" +TFCI_PYTHON_VERSION=3.11 diff --git a/ci/official/envs/continuous_linux_x86_cpu_py39 b/ci/official/envs/continuous_linux_x86_cpu_py39 index 4ca275cf32a943..6b225c4e8f3170 100644 --- 
a/ci/official/envs/continuous_linux_x86_cpu_py39 +++ b/ci/official/envs/continuous_linux_x86_cpu_py39 @@ -1,6 +1,5 @@ -source ci/official/envs/ci_default -TFCI_PYTHON_VERSION=3.9 +TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu -TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" +TFCI_PYTHON_VERSION=3.9 diff --git a/ci/official/envs/continuous_linux_x86_cuda_py310 b/ci/official/envs/continuous_linux_x86_cuda_py310 index f09a5d55110948..95e30867ced0ed 100644 --- a/ci/official/envs/continuous_linux_x86_cuda_py310 +++ b/ci/official/envs/continuous_linux_x86_cuda_py310 @@ -1,8 +1,7 @@ -source ci/official/envs/ci_default -TFCI_PYTHON_VERSION=3.10 -TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_BAZEL_COMMON_ARGS="--config release_gpu_linux --config rbe_linux_cuda --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda -TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config rbe_linux_cuda --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) -TFCI_DOCKER_ARGS=(--gpus all) +TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_PYTHON_VERSION=3.10 diff --git a/ci/official/envs/continuous_linux_x86_cuda_py311 b/ci/official/envs/continuous_linux_x86_cuda_py311 index cd834c2acfbde1..8bc69dc0ed514c 100644 --- a/ci/official/envs/continuous_linux_x86_cuda_py311 +++ b/ci/official/envs/continuous_linux_x86_cuda_py311 @@ -1,8 +1,7 @@ -source ci/official/envs/ci_default -TFCI_PYTHON_VERSION=3.11 -TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_BAZEL_COMMON_ARGS="--config release_gpu_linux --config rbe_linux_cuda --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda -TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config rbe_linux_cuda --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) -TFCI_DOCKER_ARGS=(--gpus all) +TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_PYTHON_VERSION=3.11 diff --git a/ci/official/envs/continuous_linux_x86_cuda_py39 b/ci/official/envs/continuous_linux_x86_cuda_py39 index 798dfdf25109d4..3899fed43065ba 100644 --- a/ci/official/envs/continuous_linux_x86_cuda_py39 +++ b/ci/official/envs/continuous_linux_x86_cuda_py39 @@ -1,8 +1,7 @@ -source ci/official/envs/ci_default -TFCI_PYTHON_VERSION=3.9 -TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_BAZEL_COMMON_ARGS="--config release_gpu_linux --config rbe_linux_cuda 
--repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda -TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config rbe_linux_cuda --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) -TFCI_DOCKER_ARGS=(--gpus all) +TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_PYTHON_VERSION=3.9 diff --git a/ci/official/envs/continuous_macos_arm64_py310 b/ci/official/envs/continuous_macos_arm64_py310 index a08a3350534751..81e98e74ea4c80 100644 --- a/ci/official/envs/continuous_macos_arm64_py310 +++ b/ci/official/envs/continuous_macos_arm64_py310 @@ -1,5 +1,6 @@ -source ci/official/envs/ci_default -TFCI_PYTHON_VERSION=3.10 +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 -TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) TFCI_DOCKER_ENABLE=0 +TFCI_PYTHON_VERSION=3.10 +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/continuous_macos_arm64_py311 b/ci/official/envs/continuous_macos_arm64_py311 index 230d18d7c2b2f6..f4e7ce7120a858 100644 --- a/ci/official/envs/continuous_macos_arm64_py311 +++ b/ci/official/envs/continuous_macos_arm64_py311 @@ -1,5 +1,6 @@ -source ci/official/envs/ci_default -TFCI_PYTHON_VERSION=3.11 +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 -TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) TFCI_DOCKER_ENABLE=0 +TFCI_PYTHON_VERSION=3.11 +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/continuous_macos_arm64_py39 b/ci/official/envs/continuous_macos_arm64_py39 index 59585ff1b37857..66ca0b11dfb918 100644 --- a/ci/official/envs/continuous_macos_arm64_py39 +++ b/ci/official/envs/continuous_macos_arm64_py39 @@ -1,5 +1,6 @@ -source ci/official/envs/ci_default -TFCI_PYTHON_VERSION=3.9 +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 -TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) TFCI_DOCKER_ENABLE=0 +TFCI_PYTHON_VERSION=3.9 +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/nightly_libtensorflow_linux_x86_cpu b/ci/official/envs/nightly_libtensorflow_linux_x86_cpu index 9fbd23ae501601..d5e7b0b634f0ef 100644 --- a/ci/official/envs/nightly_libtensorflow_linux_x86_cpu +++ b/ci/official/envs/nightly_libtensorflow_linux_x86_cpu @@ -1,8 +1,7 @@ -source ci/official/envs/ci_default source ci/official/envs/ci_nightly_uploads -TFCI_PYTHON_VERSION=3.10 -TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore 
--repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.10 diff --git a/ci/official/envs/nightly_libtensorflow_linux_x86_cuda b/ci/official/envs/nightly_libtensorflow_linux_x86_cuda index 0b35c0e67f78ee..adb557c7845196 100644 --- a/ci/official/envs/nightly_libtensorflow_linux_x86_cuda +++ b/ci/official/envs/nightly_libtensorflow_linux_x86_cuda @@ -1,10 +1,9 @@ -source ci/official/envs/ci_default source ci/official/envs/ci_nightly_uploads -TFCI_PYTHON_VERSION=3.10 -TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) -TFCI_DOCKER_ARGS=(--gpus all) +TFCI_BAZEL_COMMON_ARGS="--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_LIB_SUFFIX="-gpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_PYTHON_VERSION=3.10 diff --git a/ci/official/envs/nightly_libtensorflow_macos_arm64 b/ci/official/envs/nightly_libtensorflow_macos_arm64 index d29447dc50415c..195563aaa1f79c 100644 --- a/ci/official/envs/nightly_libtensorflow_macos_arm64 +++ b/ci/official/envs/nightly_libtensorflow_macos_arm64 @@ -1,8 +1,7 @@ -source ci/official/envs/ci_default -source ci/official/envs/ci_nightly_uploads -TFCI_PYTHON_VERSION=3.10 -TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +# Disable arm64 uploads while being worked on +source ci/official/envs/disable_all_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_DOCKER_ENABLE=0 -TFCI_LIB_SUFFIX="-cpu-macos-arm64" +TFCI_LIB_SUFFIX="-cpu-darwin-arm64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 -TFCI_UPLOAD_WHL_GCS_URI=1 \ No newline at end of file +TFCI_PYTHON_VERSION=3.10 diff --git a/ci/official/envs/nightly_libtensorflow_macos_x86 b/ci/official/envs/nightly_libtensorflow_macos_x86 new file mode 100644 index 00000000000000..113111468bfb67 --- /dev/null +++ b/ci/official/envs/nightly_libtensorflow_macos_x86 @@ -0,0 +1,7 @@ +# Disable macOS x86 uploads while being worked on +source ci/official/envs/disable_all_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_macos_x86 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_DOCKER_ENABLE=0 +TFCI_LIB_SUFFIX="-cpu-darwin-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.10 diff --git a/ci/official/envs/nightly_linux_arm64_cpu_py310 
b/ci/official/envs/nightly_linux_arm64_cpu_py310 index 5b7900c43423b2..99abd33e228d06 100644 --- a/ci/official/envs/nightly_linux_arm64_cpu_py310 +++ b/ci/official/envs/nightly_linux_arm64_cpu_py310 @@ -1,10 +1,11 @@ -source ci/official/envs/ci_default # Disable arm64 uploads while being worked on source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.10 -TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +TFCI_BAZEL_COMMON_ARGS="--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64 -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python -TFCI_DOCKER_REBUILD_ARGS=(--target=tf ci/official/containers/linux_arm64) +TFCI_DOCKER_REBUILD_ARGS="--target=tf ci/official/containers/linux_arm64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.10 +TFCI_WHL_AUDIT_PLAT=manylinux2014_aarch64 +TFCI_WHL_SIZE_LIMIT_ENABLE= diff --git a/ci/official/envs/nightly_linux_arm64_cpu_py311 b/ci/official/envs/nightly_linux_arm64_cpu_py311 index 6edb93ba0bdf73..5ce6b38552bee4 100644 --- a/ci/official/envs/nightly_linux_arm64_cpu_py311 +++ b/ci/official/envs/nightly_linux_arm64_cpu_py311 @@ -1,10 +1,11 @@ -source ci/official/envs/ci_default # Disable arm64 uploads while being worked on source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.11 -TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +TFCI_BAZEL_COMMON_ARGS="--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64 -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python -TFCI_DOCKER_REBUILD_ARGS=(--target=tf ci/official/containers/linux_arm64) +TFCI_DOCKER_REBUILD_ARGS="--target=tf ci/official/containers/linux_arm64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.11 +TFCI_WHL_AUDIT_PLAT=manylinux2014_aarch64 +TFCI_WHL_SIZE_LIMIT_ENABLE= diff --git a/ci/official/envs/nightly_linux_arm64_cpu_py312 b/ci/official/envs/nightly_linux_arm64_cpu_py312 index dfe96fafb5568e..59ac34a405b3cb 100644 --- a/ci/official/envs/nightly_linux_arm64_cpu_py312 +++ b/ci/official/envs/nightly_linux_arm64_cpu_py312 @@ -1,10 +1,11 @@ -source ci/official/envs/ci_default # Disable arm64 uploads while being worked on source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.12 -TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +TFCI_BAZEL_COMMON_ARGS="--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64 -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python -TFCI_DOCKER_REBUILD_ARGS=(--target=tf ci/official/containers/linux_arm64) +TFCI_DOCKER_REBUILD_ARGS="--target=tf 
ci/official/containers/linux_arm64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.12 +TFCI_WHL_AUDIT_PLAT=manylinux2014_aarch64 +TFCI_WHL_SIZE_LIMIT_ENABLE= diff --git a/ci/official/envs/nightly_linux_arm64_cpu_py39 b/ci/official/envs/nightly_linux_arm64_cpu_py39 index e3b516111fdc85..e707083e020661 100644 --- a/ci/official/envs/nightly_linux_arm64_cpu_py39 +++ b/ci/official/envs/nightly_linux_arm64_cpu_py39 @@ -1,10 +1,11 @@ -source ci/official/envs/ci_default # Disable arm64 uploads while being worked on source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.9 -TFCI_BAZEL_COMMON_ARGS=(--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +TFCI_BAZEL_COMMON_ARGS="--config release_arm64_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64 -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python -TFCI_DOCKER_REBUILD_ARGS=(--target=tf ci/official/containers/linux_arm64) +TFCI_DOCKER_REBUILD_ARGS="--target=tf ci/official/containers/linux_arm64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.9 +TFCI_WHL_AUDIT_PLAT=manylinux2014_aarch64 +TFCI_WHL_SIZE_LIMIT_ENABLE= diff --git a/ci/official/envs/nightly_linux_x86_cpu_py310 b/ci/official/envs/nightly_linux_x86_cpu_py310 index 574ac7bee1f004..6576b8ab239593 100644 --- a/ci/official/envs/nightly_linux_x86_cpu_py310 +++ b/ci/official/envs/nightly_linux_x86_cpu_py310 @@ -1,9 +1,10 @@ -source ci/official/envs/ci_default source ci/official/envs/ci_nightly_uploads -TFCI_PYTHON_VERSION=3.10 -TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.10 +TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_SIZE_LIMIT=240M diff --git a/ci/official/envs/nightly_linux_x86_cpu_py311 b/ci/official/envs/nightly_linux_x86_cpu_py311 index d1b8bfea93cc74..544fff21a905fd 100644 --- a/ci/official/envs/nightly_linux_x86_cpu_py311 +++ b/ci/official/envs/nightly_linux_x86_cpu_py311 @@ -1,9 +1,10 @@ -source ci/official/envs/ci_default source ci/official/envs/ci_nightly_uploads -TFCI_PYTHON_VERSION=3.11 -TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu 
--nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.11 +TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_SIZE_LIMIT=240M diff --git a/ci/official/envs/nightly_linux_x86_cpu_py312 b/ci/official/envs/nightly_linux_x86_cpu_py312 index 586fd92e5d703c..b8442d9e03cb4a 100644 --- a/ci/official/envs/nightly_linux_x86_cpu_py312 +++ b/ci/official/envs/nightly_linux_x86_cpu_py312 @@ -1,10 +1,10 @@ -source ci/official/envs/ci_default -# Disable 3.12 uploads while being worked on -source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.12 -TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +source ci/official/envs/ci_nightly_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.12 +TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_SIZE_LIMIT=240M diff --git a/ci/official/envs/nightly_linux_x86_cpu_py39 b/ci/official/envs/nightly_linux_x86_cpu_py39 index 2c3e1183a37171..69696ee814f77e 100644 --- a/ci/official/envs/nightly_linux_x86_cpu_py39 +++ b/ci/official/envs/nightly_linux_x86_cpu_py39 @@ -1,9 +1,10 @@ -source ci/official/envs/ci_default source ci/official/envs/ci_nightly_uploads -TFCI_PYTHON_VERSION=3.9 -TFCI_BAZEL_COMMON_ARGS=(--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.9 +TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_SIZE_LIMIT=240M diff --git a/ci/official/envs/nightly_linux_x86_cuda_py310 b/ci/official/envs/nightly_linux_x86_cuda_py310 index 16038d62bd646d..ec26fb1cb14905 100644 --- a/ci/official/envs/nightly_linux_x86_cuda_py310 +++ b/ci/official/envs/nightly_linux_x86_cuda_py310 @@ -1,10 +1,11 @@ -source ci/official/envs/ci_default source ci/official/envs/ci_nightly_uploads -TFCI_PYTHON_VERSION=3.10 
-TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +TFCI_BAZEL_COMMON_ARGS="--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda -TFCI_BUILD_PIP_PACKAGE_ARGS=(--nightly_flag) -TFCI_DOCKER_ARGS=(--gpus all) +TFCI_BUILD_PIP_PACKAGE_ARGS="--nightly_flag" +TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.10 +TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_SIZE_LIMIT=580M diff --git a/ci/official/envs/nightly_linux_x86_cuda_py311 b/ci/official/envs/nightly_linux_x86_cuda_py311 index 1d0d931477a686..e7101efa94cb57 100644 --- a/ci/official/envs/nightly_linux_x86_cuda_py311 +++ b/ci/official/envs/nightly_linux_x86_cuda_py311 @@ -1,10 +1,11 @@ -source ci/official/envs/ci_default source ci/official/envs/ci_nightly_uploads -TFCI_PYTHON_VERSION=3.11 -TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +TFCI_BAZEL_COMMON_ARGS="--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda -TFCI_BUILD_PIP_PACKAGE_ARGS=(--nightly_flag) -TFCI_DOCKER_ARGS=(--gpus all) +TFCI_BUILD_PIP_PACKAGE_ARGS="--nightly_flag" +TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.11 +TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_SIZE_LIMIT=580M diff --git a/ci/official/envs/nightly_linux_x86_cuda_py312 b/ci/official/envs/nightly_linux_x86_cuda_py312 index 4767f6dbdd6483..4b9e371ae26ed3 100644 --- a/ci/official/envs/nightly_linux_x86_cuda_py312 +++ b/ci/official/envs/nightly_linux_x86_cuda_py312 @@ -1,11 +1,11 @@ -source ci/official/envs/ci_default -# Disable 3.12 uploads while being worked on -source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.12 -TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +source ci/official/envs/ci_nightly_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda -TFCI_BUILD_PIP_PACKAGE_ARGS=(--nightly_flag) -TFCI_DOCKER_ARGS=(--gpus all) +TFCI_BUILD_PIP_PACKAGE_ARGS="--nightly_flag" +TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) -TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 \ No 
newline at end of file +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.12 +TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_SIZE_LIMIT=580M diff --git a/ci/official/envs/nightly_linux_x86_cuda_py39 b/ci/official/envs/nightly_linux_x86_cuda_py39 index e3a5d3f8c8d1a8..63ee868a8db0b3 100644 --- a/ci/official/envs/nightly_linux_x86_cuda_py39 +++ b/ci/official/envs/nightly_linux_x86_cuda_py39 @@ -1,10 +1,11 @@ -source ci/official/envs/ci_default source ci/official/envs/ci_nightly_uploads -TFCI_PYTHON_VERSION=3.9 -TFCI_BAZEL_COMMON_ARGS=(--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) +TFCI_BAZEL_COMMON_ARGS="--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda -TFCI_BUILD_PIP_PACKAGE_ARGS=(--nightly_flag) -TFCI_DOCKER_ARGS=(--gpus all) +TFCI_BUILD_PIP_PACKAGE_ARGS="--nightly_flag" +TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.9 +TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_SIZE_LIMIT=580M diff --git a/ci/official/envs/nightly_linux_x86_tpu_py310 b/ci/official/envs/nightly_linux_x86_tpu_py310 index 4e8014120f3762..8367da6b55b456 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py310 +++ b/ci/official/envs/nightly_linux_x86_tpu_py310 @@ -1,11 +1,13 @@ -source ci/official/envs/ci_default # Disable tpu uploads while being worked on -source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.10 -TFCI_BAZEL_COMMON_ARGS=(--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=tpu) +source ci/official/envs/ci_nightly_uploads +TFCI_BAZEL_COMMON_ARGS="--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=release_cpu_linux --config=tpu" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu -TFCI_BUILD_PIP_PACKAGE_ARGS=(--tpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--tpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) -TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=1 +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-releases/index.html" +TFCI_PYTHON_VERSION=3.10 +TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_BAZEL_TEST_ENABLE=0 +TFCI_WHL_SIZE_LIMIT=580M diff --git a/ci/official/envs/nightly_linux_x86_tpu_py311 b/ci/official/envs/nightly_linux_x86_tpu_py311 index e4ae8cccf4fd46..8a186aad7dcce0 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py311 +++ b/ci/official/envs/nightly_linux_x86_tpu_py311 @@ -1,11 +1,13 @@ -source ci/official/envs/ci_default # Disable tpu uploads while being worked on 
-source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.11 -TFCI_BAZEL_COMMON_ARGS=(--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=tpu) +source ci/official/envs/ci_nightly_uploads +TFCI_BAZEL_COMMON_ARGS="--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=release_cpu_linux --config=tpu" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu -TFCI_BUILD_PIP_PACKAGE_ARGS=(--tpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--tpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) -TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=1 +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-releases/index.html" +TFCI_PYTHON_VERSION=3.11 +TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_BAZEL_TEST_ENABLE=0 +TFCI_WHL_SIZE_LIMIT=580M diff --git a/ci/official/envs/nightly_linux_x86_tpu_py312 b/ci/official/envs/nightly_linux_x86_tpu_py312 index 54d96b16548a4a..0f8c73bd601e26 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py312 +++ b/ci/official/envs/nightly_linux_x86_tpu_py312 @@ -1,11 +1,13 @@ -source ci/official/envs/ci_default # Disable tpu uploads while being worked on -source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.12 -TFCI_BAZEL_COMMON_ARGS=(--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=tpu) +source ci/official/envs/ci_nightly_uploads +TFCI_BAZEL_COMMON_ARGS="--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=release_cpu_linux --config=tpu" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu -TFCI_BUILD_PIP_PACKAGE_ARGS=(--tpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--tpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) -TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=1 +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-releases/index.html" +TFCI_PYTHON_VERSION=3.12 +TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_BAZEL_TEST_ENABLE=0 +TFCI_WHL_SIZE_LIMIT=580M diff --git a/ci/official/envs/nightly_linux_x86_tpu_py39 b/ci/official/envs/nightly_linux_x86_tpu_py39 index 4adaa8b216fbba..aa413f939ee5fd 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py39 +++ b/ci/official/envs/nightly_linux_x86_tpu_py39 @@ -1,11 +1,13 @@ -source ci/official/envs/ci_default # Disable tpu uploads while being worked on -source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.9 -TFCI_BAZEL_COMMON_ARGS=(--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=tpu) +source ci/official/envs/ci_nightly_uploads +TFCI_BAZEL_COMMON_ARGS="--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=release_cpu_linux --config=tpu" 
TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu -TFCI_BUILD_PIP_PACKAGE_ARGS=(--tpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--tpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) -TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=1 +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-releases/index.html" +TFCI_PYTHON_VERSION=3.9 +TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_BAZEL_TEST_ENABLE=0 +TFCI_WHL_SIZE_LIMIT=580M diff --git a/ci/official/envs/nightly_macos_arm64_py310 b/ci/official/envs/nightly_macos_arm64_py310 index 81fa2c977d6944..6c007ce1c318d7 100644 --- a/ci/official/envs/nightly_macos_arm64_py310 +++ b/ci/official/envs/nightly_macos_arm64_py310 @@ -1,9 +1,12 @@ -source ci/official/envs/ci_default source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.10 +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 -TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 -TFCI_UPLOAD_WHL_GCS_ENABLE=1 +TFCI_PYTHON_VERSION=3.10 +TFCI_WHL_AUDIT_ENABLE= +TFCI_WHL_SIZE_LIMIT=240M +TFCI_MACOS_PYENV_INSTALL_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/nightly_macos_arm64_py311 b/ci/official/envs/nightly_macos_arm64_py311 index e8046a3b5951b1..a3dfd672843273 100644 --- a/ci/official/envs/nightly_macos_arm64_py311 +++ b/ci/official/envs/nightly_macos_arm64_py311 @@ -1,9 +1,11 @@ -source ci/official/envs/ci_default source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.11 +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 -TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 -TFCI_UPLOAD_WHL_GCS_ENABLE=1 +TFCI_PYTHON_VERSION=3.11 +TFCI_WHL_AUDIT_ENABLE= +TFCI_WHL_SIZE_LIMIT=240M +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/nightly_macos_arm64_py312 b/ci/official/envs/nightly_macos_arm64_py312 index 21432f076f6283..3da9c1040956da 100644 --- a/ci/official/envs/nightly_macos_arm64_py312 +++ b/ci/official/envs/nightly_macos_arm64_py312 @@ -1,9 +1,12 @@ -source ci/official/envs/ci_default source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.12 +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" 
TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 -TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 -TFCI_UPLOAD_WHL_GCS_ENABLE=1 +TFCI_PYTHON_VERSION=3.12 +TFCI_WHL_AUDIT_ENABLE= +TFCI_WHL_SIZE_LIMIT=240M +TFCI_MACOS_PYENV_INSTALL_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/nightly_macos_arm64_py39 b/ci/official/envs/nightly_macos_arm64_py39 index ee58e84c6624ca..36682a1e08421b 100644 --- a/ci/official/envs/nightly_macos_arm64_py39 +++ b/ci/official/envs/nightly_macos_arm64_py39 @@ -1,9 +1,14 @@ -source ci/official/envs/ci_default source ci/official/envs/disable_all_uploads -TFCI_PYTHON_VERSION=3.9 +# TODO(srnitin): Add resultstore config once the macOS builds have the right +# permissions +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 -TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION) -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 -TFCI_UPLOAD_WHL_GCS_ENABLE=1 +TFCI_PYTHON_VERSION=3.9 +TFCI_WHL_AUDIT_ENABLE= +TFCI_WHL_SIZE_LIMIT=240M +TFCI_MACOS_PYENV_INSTALL_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/nightly_macos_x86_py310 b/ci/official/envs/nightly_macos_x86_py310 new file mode 100644 index 00000000000000..9577841dea84ec --- /dev/null +++ b/ci/official/envs/nightly_macos_x86_py310 @@ -0,0 +1,16 @@ +# Disable macOS x86 uploads while being worked on +source ci/official/envs/disable_all_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_macos_x86 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_x86 +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" +TFCI_DOCKER_ENABLE=0 +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.10 +TFCI_WHL_AUDIT_ENABLE= +TFCI_WHL_SIZE_LIMIT=255M +TFCI_MACOS_INSTALL_BAZELISK_ENABLE=1 +TFCI_MACOS_INSTALL_BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" +TFCI_MACOS_UPGRADE_PYENV_ENABLE=1 +TFCI_MACOS_PYENV_INSTALL_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/nightly_macos_x86_py311 b/ci/official/envs/nightly_macos_x86_py311 new file mode 100644 index 00000000000000..4fe9bad43f89f6 --- /dev/null +++ b/ci/official/envs/nightly_macos_x86_py311 @@ -0,0 +1,16 @@ +# Disable macOS x86 uploads while being worked on +source ci/official/envs/disable_all_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_macos_x86 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_x86 +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" +TFCI_DOCKER_ENABLE=0 +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.11 +TFCI_WHL_AUDIT_ENABLE= 
+TFCI_WHL_SIZE_LIMIT=255M +TFCI_MACOS_INSTALL_BAZELISK_ENABLE=1 +TFCI_MACOS_INSTALL_BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" +TFCI_MACOS_UPGRADE_PYENV_ENABLE=1 +TFCI_MACOS_PYENV_INSTALL_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" diff --git a/ci/official/envs/nightly_macos_x86_py312 b/ci/official/envs/nightly_macos_x86_py312 new file mode 100644 index 00000000000000..a4397de120d90c --- /dev/null +++ b/ci/official/envs/nightly_macos_x86_py312 @@ -0,0 +1,16 @@ +# Disable macOS x86 uploads while being worked on +source ci/official/envs/disable_all_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_macos_x86 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_x86 +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" +TFCI_DOCKER_ENABLE=0 +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.12 +TFCI_WHL_AUDIT_ENABLE= +TFCI_WHL_SIZE_LIMIT=255M +TFCI_MACOS_INSTALL_BAZELISK_ENABLE=1 +TFCI_MACOS_INSTALL_BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" +TFCI_MACOS_UPGRADE_PYENV_ENABLE=1 +TFCI_MACOS_PYENV_INSTALL_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" \ No newline at end of file diff --git a/ci/official/envs/nightly_macos_x86_py39 b/ci/official/envs/nightly_macos_x86_py39 new file mode 100644 index 00000000000000..58c570c5d10507 --- /dev/null +++ b/ci/official/envs/nightly_macos_x86_py39 @@ -0,0 +1,14 @@ +# Disable macOS x86 uploads while being worked on +source ci/official/envs/disable_all_uploads +TFCI_BAZEL_COMMON_ARGS="--config release_macos_x86 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_x86 +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" +TFCI_DOCKER_ENABLE=0 +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_PYTHON_VERSION=3.9 +TFCI_WHL_AUDIT_ENABLE= +TFCI_WHL_SIZE_LIMIT=255M +TFCI_MACOS_INSTALL_BAZELISK_ENABLE=1 +TFCI_MACOS_INSTALL_BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" +TFCI_MACOS_BAZEL_TEST_DIR_ENABLE=1 +TFCI_MACOS_BAZEL_TEST_DIR_PATH="/Volumes/BuildData/bazel_output" \ No newline at end of file diff --git a/ci/official/envs/sample b/ci/official/envs/sample index 1e01d6ae93b877..e7717e0b25fcae 100644 --- a/ci/official/envs/sample +++ b/ci/official/envs/sample @@ -16,7 +16,7 @@ set +u; source ci/official/envs/your_choice_here; set -u # different Python versions. You can add e.g. "--repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" # to change the Python version to anything available (including the default) in # tensorflow/tools/toolchains/python/python_repo.bzl. -TFCI_BAZEL_COMMON_ARGS=(--config tf_public_cache --disk_cache=build_output/cache) +TFCI_BAZEL_COMMON_ARGS='--config tf_public_cache --disk_cache=build_output/cache' # Disable all CI-specific behavior. You never need any of these if you are # running a script locally. 
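Because the TFCI_*_ARGS settings in these env files are now plain strings rather than bash arrays, the CI scripts consume them through ordinary word splitting. A minimal sketch of local use, under the assumptions that you are at the root of a TensorFlow checkout and that ci/official/envs/sample is a reasonable stand-in for your own env file (setup.sh performs the same sourcing with extra ordering logic):

    # Export everything the env file defines, following the
    # "set -a; source ...; set +a" pattern that setup.sh suggests.
    set -a
    source ci/official/envs/sample
    set +a
    # Expand the string unquoted so each "--config ..." token becomes a
    # separate bazel argument, as the scripts below do via $TFCI_BAZEL_COMMON_ARGS.
    bazel build $TFCI_BAZEL_COMMON_ARGS //tensorflow/tools/pip_package:build_pip_package
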
diff --git a/ci/official/libtensorflow.sh b/ci/official/libtensorflow.sh index 402de63ebc97a2..e6b8ff4dd865b3 100755 --- a/ci/official/libtensorflow.sh +++ b/ci/official/libtensorflow.sh @@ -25,8 +25,8 @@ if [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]]; then tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly fi -tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_COMMON_ARGS[@]}" --config=linux_libtensorflow_test -tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_COMMON_ARGS[@]}" --config=linux_libtensorflow_build +tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS test $TFCI_BAZEL_COMMON_ARGS --config=linux_libtensorflow_test +tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS build $TFCI_BAZEL_COMMON_ARGS --config=linux_libtensorflow_build tfrun ./ci/official/utilities/repack_libtensorflow.sh "$TFCI_OUTPUT_DIR" "$TFCI_LIB_SUFFIX" diff --git a/ci/official/pycpp.sh b/ci/official/pycpp.sh index 6a4bd8821bbefb..3c83fd5772e1b8 100755 --- a/ci/official/pycpp.sh +++ b/ci/official/pycpp.sh @@ -15,7 +15,7 @@ # ============================================================================== source "${BASH_SOURCE%/*}/utilities/setup.sh" -tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_COMMON_ARGS[@]}" --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" +tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS test $TFCI_BAZEL_COMMON_ARGS --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" # Note: the profile can be viewed by visiting chrome://tracing in a Chrome browser. # See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling diff --git a/ci/official/utilities/cleanup_summary.sh b/ci/official/utilities/cleanup_summary.sh index 046e1d79014953..dbe2203fa130af 100755 --- a/ci/official/utilities/cleanup_summary.sh +++ b/ci/official/utilities/cleanup_summary.sh @@ -14,6 +14,8 @@ # limitations under the License. # ============================================================================== +set -euxo pipefail + function resultstore_extract_fallback { # In case the main script fails somehow. cat < $BATS_TEST_TMPDIR/pip_deps + bazel cquery --keep_going 'deps(//tensorflow/tools/pip_package:build_pip_package)' | sort -u > $BATS_TEST_TMPDIR/pip_deps # Find all Python py_test targets not tagged "no_pip" or "manual", excluding # any targets in ignored packages. Combine this list of targets into a bazel # query list (e.g. the list becomes "target+target2+target3") - bazel query --keep_going 'kind(py_test, //tensorflow/python/...) - attr("tags", "no_pip|manual", //tensorflow/python/...)' | grep -v -f $BATS_TEST_TMPDIR/ignore_deps_for_these_packages | paste -sd "+" - > $BATS_TEST_TMPDIR/deps + bazel cquery --keep_going 'kind(py_test, //tensorflow/python/...) - attr("tags", "no_pip|manual", //tensorflow/python/...)' | grep -v -f $BATS_TEST_TMPDIR/ignore_deps_for_these_packages | paste -sd "+" - > $BATS_TEST_TMPDIR/deps # Find all one-step dependencies of those tests which are from //tensorflow # (since external deps will come from Python-level pip dependencies), # excluding dependencies and files that are known to be unneccessary. # This creates a list of targets under //tensorflow that are required for # TensorFlow python tests. 
- bazel query --keep_going "deps($(cat $BATS_TEST_TMPDIR/deps), 1)" | grep "^//tensorflow" | grep -v -f $BATS_TEST_TMPDIR/ignore_these_deps | sort -u > $BATS_TEST_TMPDIR/required_deps + bazel cquery --keep_going "deps($(cat $BATS_TEST_TMPDIR/deps), 1)" | grep "^//tensorflow" | grep -v -f $BATS_TEST_TMPDIR/ignore_these_deps | sort -u > $BATS_TEST_TMPDIR/required_deps # Find if any required dependencies are missing from the list of dependencies @@ -203,7 +204,7 @@ EOF # For every missing dependency, find the tests which directly depend on # it, and print that list for debugging. Not really clear if this is # helpful since the only examples I've seen are enormous. - bazel query "rdeps(kind(py_test, $(cat $BATS_TEST_TMPDIR/deps)), $dep, 1)" + bazel cquery "rdeps(kind(py_test, $(cat $BATS_TEST_TMPDIR/deps)), $dep, 1)" done < $BATS_TEST_TMPDIR/missing_deps exit 1 fi diff --git a/ci/official/utilities/docker.sh b/ci/official/utilities/docker.sh index ea1ecc267a4fe8..c50ea618cfea6c 100755 --- a/ci/official/utilities/docker.sh +++ b/ci/official/utilities/docker.sh @@ -18,7 +18,7 @@ if [[ "$TFCI_DOCKER_PULL_ENABLE" == 1 ]]; then fi if [[ "$TFCI_DOCKER_REBUILD_ENABLE" == 1 ]]; then - DOCKER_BUILDKIT=1 docker build --cache-from "$TFCI_DOCKER_IMAGE" -t "$TFCI_DOCKER_IMAGE" "${TFCI_DOCKER_REBUILD_ARGS[@]}" + DOCKER_BUILDKIT=1 docker build --cache-from "$TFCI_DOCKER_IMAGE" -t "$TFCI_DOCKER_IMAGE" $TFCI_DOCKER_REBUILD_ARGS if [[ "$TFCI_DOCKER_REBUILD_UPLOAD_ENABLE" == 1 ]]; then docker push "$TFCI_DOCKER_IMAGE" fi @@ -28,9 +28,12 @@ fi # The container is not cleaned up automatically! Remove it with: # docker rm tf if ! docker container inspect tf >/dev/null 2>&1 ; then - docker run "${TFCI_DOCKER_ARGS[@]}" --name tf -w "$TFCI_GIT_DIR" -itd --rm \ + # Pass all existing TFCI_ variables into the Docker container + env_file=$(mktemp) + env | grep ^TFCI_ > "$env_file" + docker run $TFCI_DOCKER_ARGS --name tf -w "$TFCI_GIT_DIR" -itd --rm \ -v "$TFCI_GIT_DIR:$TFCI_GIT_DIR" \ - --env TFCI_PYTHON_VERSION \ + --env-file "$env_file" \ "$TFCI_DOCKER_IMAGE" \ bash fi diff --git a/ci/official/utilities/extract_resultstore_links.py b/ci/official/utilities/extract_resultstore_links.py index a8013974f20e56..da04f5473c505b 100644 --- a/ci/official/utilities/extract_resultstore_links.py +++ b/ci/official/utilities/extract_resultstore_links.py @@ -248,11 +248,12 @@ def create_xml_file(result_store_dict: ResultDictType, f.write(b'\n') tree.write(f) if verbose: - print(f'Wrote to {file_path}') + print(f'\nWrote XML with Bazel invocation results to {file_path}') def print_invocation_results(result_store_dict: ResultDictType): """Prints out a short summary of the found ResultStore links (if any).""" + print() if not result_store_dict: print('Found no ResultStore links for Bazel build/test invocations.') else: diff --git a/ci/official/utilities/rename_and_verify_wheels.sh b/ci/official/utilities/rename_and_verify_wheels.sh index 4388329ae6edd7..5d02a96f7de7a9 100755 --- a/ci/official/utilities/rename_and_verify_wheels.sh +++ b/ci/official/utilities/rename_and_verify_wheels.sh @@ -15,28 +15,51 @@ # limitations under the License. # ============================================================================== # -# Check and rename wheels with auditwheel. Inserts the platform tags like -# "manylinux_xyz" into the wheel filename. +# Usage: rename_and_verify_wheels.sh +# This script is aware of TFCI_ variables, so it doesn't need any arguments. 
+# Puts new wheel through auditwheel to rename and verify it, deletes the old +# one, checks the filesize, and then ensures the new wheel is installable. set -euxo pipefail -DIR=$1 -find "$DIR" -iname "*.whl" | while read wheel; do - echo "Checking and renaming $wheel..." - wheel=$(realpath "$wheel") - # Repair wheel based upon name/architecture, fallback to x86 - if [[ $wheel == *"aarch64.whl" ]]; then - time python3 -m auditwheel repair --plat manylinux2014_aarch64 "$wheel" --wheel-dir "$DIR" 2>&1 | tee check.txt - else - time python3 -m auditwheel repair --plat manylinux2014_x86_64 "$wheel" --wheel-dir "$DIR" 2>&1 | tee check.txt - fi +cd "$TFCI_OUTPUT_DIR" - # We don't need the original wheel if it was renamed - new_wheel=$(awk '/Fixed-up wheel written to/ {print $NF}' check.txt) - if [[ "$new_wheel" != "$wheel" ]]; then - rm "$wheel" - wheel="$new_wheel" - fi - rm check.txt +if [[ "$(ls *.whl | wc -l | tr -d ' ')" != "1" ]]; then + echo "Error: $TFCI_OUTPUT_DIR should contain exactly one .whl file." + exit 1 +fi - TF_WHEEL="$wheel" BUILD_DIR="$DIR" bats ./ci/official/utilities/wheel_verification.bats --timing -done +# Repair wheels with auditwheel and delete the old one. +if [[ "$TFCI_WHL_AUDIT_ENABLE" == "1" ]]; then + python3 -m auditwheel repair --plat "$TFCI_WHL_AUDIT_PLAT" --wheel-dir . *.whl + # if the wheel is already named correctly, auditwheel won't rename it. so we + # list all .whl files by their modification time (ls -t) and delete anything + # other than the most recently-modified one (the new one). + ls -t *.whl | tail -n +2 | xargs rm +fi + +# Check if size is too big. TFCI_WHL_SIZE_LIMIT is in find's format, which can be +# 'k' for kilobytes, 'M' for megabytes, or 'G' for gigabytes, and the + to indicate +# "anything greater than" is added by the script. +if [[ "$TFCI_WHL_SIZE_LIMIT_ENABLE" == "1" ]] && [[ -n "$(find . -iname "*.whl" -size "+$TFCI_WHL_SIZE_LIMIT")" ]]; then + echo "Error: Generated wheel is too big! Limit is $TFCI_WHL_SIZE_LIMIT" + echo '(search for TFCI_WHL_SIZE_LIMIT to change it)' + ls -sh *.whl + exit 2 +fi + +# Quick install checks +venv=$(mktemp -d) +"python${TFCI_PYTHON_VERSION}" -m venv "$venv" +python="$venv/bin/python3" +"$python" -m pip install *.whl $TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS +"$python" -c 'import tensorflow as tf; t1=tf.constant([1,2,3,4]); t2=tf.constant([5,6,7,8]); print(tf.add(t1,t2).shape)' +"$python" -c 'import sys; import tensorflow as tf; sys.exit(0 if "keras" in tf.keras.__name__ else 1)' +# VERY basic check to ensure the [and-cuda] package variant is installable. +# Checks TFCI_BAZEL_COMMON_ARGS for "gpu" or "cuda", implying that the test is +# relevant. All of the GPU test machines have CUDA installed via other means, +# so I am not sure how to verify that the dependencies themselves are valid for +# the moment. +if [[ "$TFCI_BAZEL_COMMON_ARGS" =~ gpu|cuda ]]; then + echo "Checking to make sure tensorflow[and-cuda] is installable..." + "$python" -m pip install "$(echo *.whl)[and-cuda]" $TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS +fi diff --git a/ci/official/utilities/setup.sh b/ci/official/utilities/setup.sh index aa4e7838ba0b8d..8c004aee8b8141 100755 --- a/ci/official/utilities/setup.sh +++ b/ci/official/utilities/setup.sh @@ -51,11 +51,18 @@ cd "$TFCI_GIT_DIR" # even works for arrays; e.g. TFCI_SOME_ARRAY="(--array --contents)" ends up # as TFCI_SOME_ARRAY=(--array --contents) in the storage file and is thus # loaded as an array when sourced. 
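A side note on the TFCI_WHL_SIZE_LIMIT convention used by rename_and_verify_wheels.sh above: the value is handed straight to find, so the check behaves like the sketch below (the 240M figure is just one of the limits set in the env files, not a new constant):

    # Prints the wheel only if it is larger than 240 megabytes; an empty
    # result means the wheel is within the limit, matching the script's test.
    find . -iname "*.whl" -size +240M
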
-if [[ -n "${TFCI:-}" ]]; then +if [[ -z "${TFCI:-}" ]]; then + echo '==TFCI==: The $TFCI variable is not set. This is fine as long as you' + echo 'already sourced a TFCI env file with "set -a; source ; set +a".' + echo 'If you have not, you will see a lot of undefined variable errors.' +else FROM_ENV=$(mktemp) # Piping into cat means grep won't abort the process if no errors are found. env | grep TFCI_ | cat > "$FROM_ENV" + # Source the default ci values + source ./ci/official/envs/ci_default + # Sourcing TFCI twice, the first time with "-u" unset, means that variable # order does not matter. i.e. "TFCI_BAR=$TFCI_FOO; TFCI_FOO=true" will work. # TFCI_FOO is only valid the second time through. @@ -73,10 +80,11 @@ if [[ -n "${TFCI:-}" ]]; then source "$FROM_ENV" rm "$FROM_ENV" fi -else - echo '==TFCI==: The $TFCI variable is not set. This is fine as long as you' - echo 'already sourced a TFCI env file with "set -a; source ; set +a".' - echo 'If you have not, you will see a lot of undefined variable errors.' +fi + +# Mac builds have some specific setup needs. See setup_macos.sh for details +if [[ "${OSTYPE}" =~ darwin* ]]; then + source ./ci/official/utilities/setup_macos.sh fi # Force-disable uploads if the job initiator is not Kokoro diff --git a/ci/official/utilities/setup_macos.sh b/ci/official/utilities/setup_macos.sh new file mode 100644 index 00000000000000..a6bd223402490f --- /dev/null +++ b/ci/official/utilities/setup_macos.sh @@ -0,0 +1,94 @@ +#!/bin/bash +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# macOS specific setup for all TF scripts. +# + +# Mac version of Core utilities differ in usage. Since our scripts are written +# with the GNU style, we need to set GNU utilities to be default on Mac. +if [[ -n "$(which grealpath)" ]] && [[ -n "$(which gstat)" ]]; then + alias realpath=grealpath + alias stat=gstat + # By default, aliases are only expanded in interactive shells, which means + # that they are not substituted for their corresponding commands in shell + # scripts. By setting "expand_aliases", we enable alias expansion in + # non-interactive shells as well. + shopt -s expand_aliases +else + echo '==TFCI==: Error: Cannot find path to grealpath or gstat' + echo 'TF CI scripts require GNU core utilties to be installed. Please make' + echo 'sure they are present on your system and try again.' + exit 1 +fi + +# "TFCI_MACOS_BAZEL_TEST_DIR_PATH" specifies the directory that Bazel should use +# when running tests. Each test will be executed in a separate subdirectory +# inside this directory. TF Mac builds need ~150 GB of disk space to be able to +# run all the tests. Since TFCI Mac VMs execute Bazel test commands in a +# partition with insufficient storage, we specify the +# 'TFCI_MACOS_BAZEL_TEST_DIR_PATH' environment variable to point to a partition +# with ample storage. 
When this variable is empty (i.e by default), Bazel will +# use the output base directory to run tests. +if [[ "${TFCI_MACOS_BAZEL_TEST_DIR_ENABLE}" == 1 ]]; then + mkdir -p "${TFCI_MACOS_BAZEL_TEST_DIR_PATH}" + export TEST_TMPDIR="${TFCI_MACOS_BAZEL_TEST_DIR_PATH}" +fi + +# "TFCI_MACOS_INSTALL_BAZELISK_ENABLE" is used to decide if we need to install +# Bazelisk manually. We enable this for macOS x86 builds as those VMs do not +# have Bazelisk pre-installed. "TFCI_MACOS_INSTALL_BAZELISK_URL" contains the +# link to the Bazelisk binary which needs to be downloaded. +if [[ "${TFCI_MACOS_INSTALL_BAZELISK_ENABLE}" == 1 ]]; then + sudo wget --no-verbose -O "/usr/local/bin/bazel" "${TFCI_MACOS_INSTALL_BAZELISK_URL}" + chmod +x "/usr/local/bin/bazel" +fi + +# "TFCI_MACOS_UPGRADE_PYENV_ENABLE" is used to decide if we need to upgrade the +# Pyenv version. We enable this for macOS x86 builds as the default Pyenv on +# those VMs does not support installing Python 3.10 and above which we need +# for running smoke tests in nightly/release wheel builds. +if [[ "${TFCI_MACOS_UPGRADE_PYENV_ENABLE}" == 1 ]]; then + brew upgrade pyenv +fi + +# "TFCI_MACOS_PYENV_INSTALL_ENABLE" controls whether to use Pyenv to install +# the Python version set in "TFCI_PYTHON_VERSION" and use it as default. +# We enable this in the nightly and release builds because before uploading the +# wheels, we install them in a virtual environment and run some smoke tests on +# it. TFCI Mac VMs only have one Python version installed so we need to install +# the other versions manually. +if [[ "${TFCI_MACOS_PYENV_INSTALL_ENABLE}" == 1 ]]; then + pyenv install "$TFCI_PYTHON_VERSION" + pyenv local "$TFCI_PYTHON_VERSION" + # Do a sanity check to make sure that we using the correct Python version + python --version +fi + +if [[ "$TFCI_PYTHON_VERSION" == "3.12" ]]; then + # dm-tree (Keras v3 dependency) doesn't have pre-built wheels for 3.12 yet. + # Having CMake allows building them. + # Once the wheels are added, this should be removed - b/308399490. + brew install cmake +fi + +# Scheduled nightly and release builds upload build artifacts (Pip packages, +# Libtensorflow archives) to GCS buckets. TFCI Mac VMs need to authenticate as +# a service account that has the right permissions to be able to do so. +set +x +if [[ -n "${GOOGLE_APPLICATION_CREDENTIALS:-}" ]]; then + gcloud auth activate-service-account --key-file="${GOOGLE_APPLICATION_CREDENTIALS}" +fi +set -x \ No newline at end of file diff --git a/ci/official/utilities/wheel_verification.bats b/ci/official/utilities/wheel_verification.bats deleted file mode 100644 index 99d0f32e35162e..00000000000000 --- a/ci/official/utilities/wheel_verification.bats +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -# Suite of verification tests for the SINGLE TensorFlow wheel in the -# $BUILD_DIR directory, or whatever path is set as $TF_WHEEL. - -setup_file() { - cd "$BUILD_DIR" - if [[ -z "$TF_WHEEL" ]]; then - export TF_WHEEL=$(find "$BUILD_DIR" -iname "*.whl") - fi - - # Setup the env for the python import testing - if [[ $TF_WHEEL == *"aarch64.whl" ]]; then - python${TFCI_PYTHON_VERSION} -m venv "$BATS_FILE_TMPDIR/venv" - else - python3 -m venv "$BATS_FILE_TMPDIR/venv" - fi -} - -teardown_file() { - rm -rf "$BATS_FILE_TMPDIR/venv" -} - -@test "Wheel is manylinux2014 (manylinux_2_17) compliant" { - python3 -m auditwheel show "$TF_WHEEL" > audit.txt - # Verify wheel based upon name/architecture, fallback to x86 - if [[ $TF_WHEEL == *"aarch64.whl" ]]; then - grep --quiet -zoP 'is consistent with the following platform tag:\n"manylinux_2_17_aarch64"\.' audit.txt - else - grep --quiet 'This constrains the platform tag to "manylinux_2_17_x86_64"' audit.txt - fi -} - -@test "Wheel conforms to upstream size limitations" { - WHEEL_MEGABYTES=$(stat --format %s "$TF_WHEEL" | awk '{print int($1/(1024*1024))}') - # Googlers: search for "test_tf_whl_size" - case "$TF_WHEEL" in - # CPU: - *cpu*manylinux*) LARGEST_OK_SIZE=240 ;; - # GPU: - *manylinux*) LARGEST_OK_SIZE=580 ;; - # Unknown: - *) - echo "The wheel's name is in an unknown format." - exit 1 - ;; - esac - # >&3 forces output in bats even if the test passes. See - # https://bats-core.readthedocs.io/en/stable/writing-tests.html#printing-to-the-terminal - echo "# Size of $TF_WHEEL is $WHEEL_MEGABYTES / $LARGEST_OK_SIZE megabytes." >&3 - test "$WHEEL_MEGABYTES" -le "$LARGEST_OK_SIZE" -} - -# Note: this runs before the tests further down the file, so TF is installed in -# the venv and the venv is active when those tests run. The venv gets cleaned -# up in teardown_file() above. -@test "Wheel is installable" { - source "$BATS_FILE_TMPDIR/venv/bin/activate" - python3 -m pip install "$TF_WHEEL" -} - -@test "TensorFlow is importable" { - source "$BATS_FILE_TMPDIR/venv/bin/activate" - python3 -c 'import tensorflow as tf; t1=tf.constant([1,2,3,4]); t2=tf.constant([5,6,7,8]); print(tf.add(t1,t2).shape)' -} - -# Is this still useful? -@test "TensorFlow has Keras" { - source "$BATS_FILE_TMPDIR/venv/bin/activate" - python3 -c 'import sys; import tensorflow as tf; sys.exit(0 if "keras" in tf.keras.__name__ else 1)' -} - -# Is this still useful? -@test "TensorFlow has Estimator" { - source "$BATS_FILE_TMPDIR/venv/bin/activate" - python3 -c 'import sys; import tensorflow as tf; sys.exit(0 if "_v2.estimator" in tf.estimator.__name__ else 1)' -} diff --git a/ci/official/wheel.sh b/ci/official/wheel.sh index 20c6f2637d7e12..5789e58703a18b 100755 --- a/ci/official/wheel.sh +++ b/ci/official/wheel.sh @@ -25,30 +25,17 @@ if [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]]; then tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly fi -# Download libtpu.so for tensorflow-tpu builds only. -if [[ "$TFCI_LIBTPU_DOWNLOAD_ENABLE" == 1 ]]; then - wget -P ./tensorflow/lib/ "$TFCI_LIBTPU_DOWNLOAD_URL" -fi -if [[ "$TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE" == 1 ]]; then - # For nightly jobs, libtpu.so comes from the latest nightly libtpu build. 
- # Note: expects a working wheel for today - DATE=$(TZ='America/Los_Angeles' date '+%Y%m%d') - tfrun wget "https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev${DATE}-py3-none-any.whl" -O libtpu.whl - # -j to discard intermediate directories; -o to overwrite if exists; -d to set output dir - tfrun unzip libtpu.whl libtpu/libtpu.so -j -o -d ./tensorflow/lib -fi - -tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_COMMON_ARGS[@]}" //tensorflow/tools/pip_package:build_pip_package -tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$TFCI_OUTPUT_DIR" "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" -tfrun ./ci/official/utilities/rename_and_verify_wheels.sh "$TFCI_OUTPUT_DIR" +tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS build $TFCI_BAZEL_COMMON_ARGS //tensorflow/tools/pip_package:build_pip_package +tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$TFCI_OUTPUT_DIR" $TFCI_BUILD_PIP_PACKAGE_ARGS +tfrun ./ci/official/utilities/rename_and_verify_wheels.sh if [[ "$TFCI_UPLOAD_WHL_PYPI_ENABLE" == 1 ]]; then - twine upload "${TFCI_UPLOAD_WHL_PYPI_ARGS[@]}" "$TFCI_OUTPUT_DIR"/*.whl + twine upload $TFCI_UPLOAD_WHL_PYPI_ARGS "$TFCI_OUTPUT_DIR"/*.whl fi if [[ "$TFCI_UPLOAD_WHL_GCS_ENABLE" == 1 ]]; then gsutil cp "$TFCI_OUTPUT_DIR"/*.whl "$TFCI_UPLOAD_WHL_GCS_URI" fi if [[ "$TFCI_WHL_BAZEL_TEST_ENABLE" == 1 ]]; then - tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" test "${TFCI_BAZEL_COMMON_ARGS[@]}" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_wheel_test" + tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS test $TFCI_BAZEL_COMMON_ARGS --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_wheel_test" fi diff --git a/tensorflow/BUILD b/tensorflow/BUILD index ef01b603800a71..289f37ef902c63 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -41,8 +41,11 @@ load( ) # copybara:uncomment_begin +# # buildifier: disable=out-of-order-load +# load("//devtools/build_cleaner/skylark:action_config_test.bzl", "action_config_test") # load("//devtools/copybara/rules:copybara.bzl", "copybara_config_test") # load("//tools/build_defs/license:license.bzl", "license") +# # buildifier: enable=out-of-order-load # copybara:uncomment_end # copybara:comment_begin(oss-only) @@ -183,6 +186,11 @@ package( # ], # deps = [":copybara_config"], # ) +# +# action_config_test( +# name = "build_cleaner_spec_test", +# src = "build_cleaner_spec.textproto", +# ) # copybara:uncomment_end licenses(["notice"]) @@ -1366,7 +1374,6 @@ tf_cc_shared_library( "//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/sparsity:sparsify_model", - "//tensorflow/compiler/mlir/python:mlir", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op", "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", @@ -1449,9 +1456,14 @@ tf_cc_shared_library( "//tensorflow/lite:util", "//tensorflow/python/grappler:cost_analyzer_lib", "//tensorflow/tools/graph_transforms:transform_graph_lib", - ] + (tf_monitoring_python_deps() + - tf_additional_plugin_deps() + - tf_additional_profiler_deps()) + if_xla_available([ + ] + select({ + "//tensorflow/compiler/mlir/python:disable_mlir_config": [], + "//conditions:default": [ + "//tensorflow/compiler/mlir/python:mlir", + ], + }) + (tf_monitoring_python_deps() + + tf_additional_plugin_deps() + + tf_additional_profiler_deps()) + 
if_xla_available([ "//tensorflow/compiler/aot:tfcompile_lib", ]) + if_static(extra_deps = [ "//tensorflow/core/platform:tensor_float_32_utils", diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 321738084016a7..1ccf2fe07f0af9 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -30,7 +30,6 @@ import distutils as _distutils import importlib import inspect as _inspect -import logging as _logging import os as _os import site as _site import sys as _sys @@ -62,18 +61,6 @@ __path__.append(_tf_api_dir) # Hook external TensorFlow modules. -# Import compat before trying to import summary from tensorboard, so that -# reexport_tf_summary can get compat from sys.modules. Only needed if using -# lazy loading. -_current_module.compat.v2 # pylint: disable=pointless-statement -try: - from tensorboard.summary._tf import summary - _current_module.__path__ = ( - [_module_util.get_parent_dir(summary)] + _current_module.__path__) - setattr(_current_module, "summary", summary) -except ImportError: - _logging.warning( - "Limited tf.summary API due to missing TensorBoard installation.") # Load tensorflow-io-gcs-filesystem if enabled if (_os.getenv("TF_USE_MODULAR_FILESYSTEM", "0") == "true" or diff --git a/tensorflow/build_cleaner_spec.textproto b/tensorflow/build_cleaner_spec.textproto new file mode 100644 index 00000000000000..bea7e8ac36462a --- /dev/null +++ b/tensorflow/build_cleaner_spec.textproto @@ -0,0 +1,14 @@ +# proto-file: devtools/build_cleaner/proto/actions.proto +# proto-message: ActionSpecs + +# Python rules should not have more than one source file. +action_spec { + action: CHECK_FILE_COUNT + file_count_params { + rule_selector { + rule_kind_regex: "^.*py(type)?(_strict)?_(binary|library|test).*$" + generator_function_regex: "^(?!boq_header)$" + } + max_source_count: 1 + } +} \ No newline at end of file diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD index 5c7bbddc3af6f2..d6ad4fe8d5d244 100644 --- a/tensorflow/c/experimental/next_pluggable_device/BUILD +++ b/tensorflow/c/experimental/next_pluggable_device/BUILD @@ -16,7 +16,7 @@ cc_library( "//tensorflow/c:c_api_macros_hdrs", "//tensorflow/c:kernels_experimental_hdrs", "//tensorflow/c:kernels_hdrs", - "//tensorflow/c:tf_buffer_internal", + "//tensorflow/c:tf_buffer", "//tensorflow/c:tf_status_internal", "//tensorflow/c:tf_tensor_internal", "//tensorflow/compiler/jit:variable_info", diff --git a/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.cc.golden b/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.cc.golden index 490514f80e18a4..54a45cb23ed110 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.cc.golden +++ b/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.cc.golden @@ -45,7 +45,7 @@ Status Neg(AbstractContext* ctx, AbstractTensorHandle* const x, AbstractTensorHa // Summary: // // Description: -Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTensorHandle* const b, AbstractTensorHandle** product, bool transpose_a, bool transpose_b, const char* name, const char* raw_device_name) { +Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTensorHandle* const b, AbstractTensorHandle** product, bool transpose_a, bool transpose_b, bool grad_a, bool grad_b, const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("MatMul", 
raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -53,6 +53,8 @@ Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTenso TF_RETURN_IF_ERROR(op_ptr->AddInput(b)); TF_RETURN_IF_ERROR(op_ptr->SetAttrBool("transpose_a", transpose_a)); TF_RETURN_IF_ERROR(op_ptr->SetAttrBool("transpose_b", transpose_b)); + TF_RETURN_IF_ERROR(op_ptr->SetAttrBool("grad_a", grad_a)); + TF_RETURN_IF_ERROR(op_ptr->SetAttrBool("grad_b", grad_b)); int num_retvals = 1; return op_ptr->Execute(absl::MakeSpan(product, 1), &num_retvals); } diff --git a/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.h.golden b/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.h.golden index 4b24a4f55ecff1..1d1255a20d6aa9 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.h.golden +++ b/tensorflow/c/experimental/ops/gen/cpp/golden/testing_ops.h.golden @@ -28,7 +28,7 @@ namespace ops { Status Neg(AbstractContext* ctx, AbstractTensorHandle* const x, AbstractTensorHandle** y, const char* name = nullptr, const char* raw_device_name = nullptr); // -Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTensorHandle* const b, AbstractTensorHandle** product, bool transpose_a = false, bool transpose_b = false, const char* name = nullptr, const char* raw_device_name = nullptr); +Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, AbstractTensorHandle* const b, AbstractTensorHandle** product, bool transpose_a = false, bool transpose_b = false, bool grad_a = false, bool grad_b = false, const char* name = nullptr, const char* raw_device_name = nullptr); // Status IdentityN(AbstractContext* ctx, absl::Span input, absl::Span output, const char* name = nullptr, const char* raw_device_name = nullptr); diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index 60f6d6e0250e75..af37ab0cb19011 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -13,7 +13,6 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - # copybara:uncomment() "//learning/brain/tfrt/aot:__pkg__", "//tensorflow/c:__subpackages__", "//tensorflow/c/experimental/saved_model/internal:__pkg__", ], diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index 3fcd255a2248ab..12391143a4d9e0 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -230,10 +230,6 @@ class CStreamExecutor : public internal::StreamExecutorInterface { DeviceMemoryBase Allocate(uint64 size) { return Allocate(size, /*memory_space=*/0); } - void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset, - uint64 size) override { - LOG(FATAL) << "GetSubBuffer is not supported by pluggable device."; - } void Deallocate(DeviceMemoryBase* mem) override { SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(mem); diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc index 7eda58471c2f57..0f3e2e76aa4ebe 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "xla/stream_executor/event.h" #include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor_pimpl.h" +#include "xla/stream_executor/stream_executor.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/error_codes.pb.h" diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 09ebb300969aef..1d22fa18cba53a 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -249,7 +249,8 @@ class CAsyncOpKernel : public AsyncOpKernel { n.WaitForNotification(); } - void ComputeAsync(OpKernelContext* ctx, AsyncOpKernelDoneCallback done) { + void ComputeAsync(OpKernelContext* ctx, + AsyncOpKernelDoneCallback done) override { (*compute_async_func_)( c_kernel_, reinterpret_cast(ctx), reinterpret_cast(&done)); diff --git a/tensorflow/c/kernels_experimental.cc b/tensorflow/c/kernels_experimental.cc index 7e6f818be47b39..09ce84d42f7392 100644 --- a/tensorflow/c/kernels_experimental.cc +++ b/tensorflow/c/kernels_experimental.cc @@ -292,7 +292,7 @@ struct TmpVar : public ResourceBase { tensorflow::mutex mu; Tensor val; std::string name; - std::string DebugString() const { return name; } + std::string DebugString() const override { return name; } ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; } }; @@ -626,7 +626,7 @@ static Status CCBinaryAddFunc( binary_add_func(ctx, a, b, out); return cc_ctx->status(); } -}; +} static Status VariantBinaryAddFunc( ::tensorflow::OpKernelContext* cc_ctx, const Variant& a, const Variant& b, diff --git a/tensorflow/cc/saved_model/bundle_v2.cc b/tensorflow/cc/saved_model/bundle_v2.cc index dcf0b5c5443187..d059c5d0c5729d 100644 --- a/tensorflow/cc/saved_model/bundle_v2.cc +++ b/tensorflow/cc/saved_model/bundle_v2.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/bundle_v2.h" +#include #include #include @@ -113,8 +114,8 @@ absl::Status SavedModelV2Bundle::Load(const std::string& export_dir, // Load the variables checkpoint reader. const std::string variables_prefix = io::JoinPath(variables_dir, kSavedModelVariablesFilename); - bundle->variable_reader_.reset( - new BundleReader(Env::Default(), variables_prefix)); + bundle->variable_reader_ = + std::make_unique(Env::Default(), variables_prefix); TF_RETURN_WITH_CONTEXT_IF_ERROR( bundle->variable_reader_->status(), "Unable to load SavedModel variables checkpoint from ", diff --git a/tensorflow/cc/saved_model/image_format/BUILD b/tensorflow/cc/saved_model/image_format/BUILD index 10a35871a708be..7fd743cf9c8356 100644 --- a/tensorflow/cc/saved_model/image_format/BUILD +++ b/tensorflow/cc/saved_model/image_format/BUILD @@ -32,7 +32,9 @@ cc_library( "//tensorflow/tools/proto_splitter/cc:max_size", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", ] + if_not_windows_or_mac([ "//tensorflow/tools/proto_splitter:merge", "//tensorflow/tools/proto_splitter/cc:saved_model_splitter", diff --git a/tensorflow/cc/saved_model/image_format/internal_api.cc b/tensorflow/cc/saved_model/image_format/internal_api.cc index b959602ba445c9..db38d1786e59ea 100644 --- a/tensorflow/cc/saved_model/image_format/internal_api.cc +++ b/tensorflow/cc/saved_model/image_format/internal_api.cc @@ -16,9 +16,11 @@ limitations under the License. 
#include "tensorflow/cc/saved_model/image_format/internal_api.h" #include +#include #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "tensorflow/cc/saved_model/metrics.h" #include "tensorflow/cc/saved_model/util.h" @@ -31,7 +33,7 @@ limitations under the License. #include "tensorflow/tools/proto_splitter/cc/saved_model_splitter.h" #include "tensorflow/tools/proto_splitter/merge.h" #endif - +#define IS_OSS false namespace tensorflow { namespace image_format { @@ -104,6 +106,27 @@ absl::Status WriteSavedModel(SavedModel* saved_model_proto, #endif } +absl::StatusOr> WriteSavedModelToString( + SavedModel* saved_model_proto) { +#if !defined(PLATFORM_WINDOWS) && !defined(__APPLE__) + tools::proto_splitter::SavedModelSplitter splitter(saved_model_proto); + return splitter.WriteToString(); +#else + return absl::UnimplementedError( + "WriteSavedModelToString not implemented for Windows or MacOS."); +#endif +} + +#if !IS_OSS +// TODO(b/311769337): Define the function unconditionally after tf oss +// dependency is updated to protobuf v22.x. +absl::StatusOr> WriteSavedModelToCord( + SavedModel* saved_model_proto) { + tools::proto_splitter::SavedModelSplitter splitter(saved_model_proto); + return splitter.WriteToCord(); +} +#endif + absl::Status WriteSavedModel(SavedModel* saved_model_proto, const std::string& file_prefix, int debug_max_size) { diff --git a/tensorflow/cc/saved_model/image_format/internal_api.h b/tensorflow/cc/saved_model/image_format/internal_api.h index 465b00a74bfada..5c9b13d0f97364 100644 --- a/tensorflow/cc/saved_model/image_format/internal_api.h +++ b/tensorflow/cc/saved_model/image_format/internal_api.h @@ -17,10 +17,15 @@ limitations under the License. #define TENSORFLOW_CC_SAVED_MODEL_IMAGE_FORMAT_INTERNAL_API_H_ #include +#include #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "tensorflow/core/protobuf/saved_model.pb.h" +#define IS_OSS false + namespace tensorflow { namespace image_format { @@ -29,13 +34,24 @@ namespace image_format { absl::Status ReadSavedModel(const std::string& file_prefix, SavedModel* saved_model_proto); -// Writes the SavedModel proto to {file_prefix}{.pb|.cpb}. -// If the proto is < the protobuf maximum size, then it will be serialized -// as a `.pb` proto binary. When larger than the maximum size, the SavedModel -// proto is destructively separated into chunks and written to +// Writes the SavedModel proto to a file or to string. If the proto is < the +// protobuf maximum size, then it will be serialized as a `.pb` proto binary. +// When larger than the maximum size, the SavedModel proto is destructively +// separated into chunks and written to // `.cpb` (chunked proto). +// +// Write SavedModel to {file_prefix}{.pb|.cpb}. absl::Status WriteSavedModel(SavedModel* saved_model_proto, const std::string& file_prefix); +// Writes the SavedModel proto to std::string +// The bool field record whether it's saved as a chunked protobuf (true) or +// regular protobuf (false) +absl::StatusOr> WriteSavedModelToString( + SavedModel* saved_model_proto); +#if !IS_OSS +absl::StatusOr> WriteSavedModelToCord( + SavedModel* saved_model_proto); +#endif // See above. The `debug_max_size` argument can be used to the maximum size to // less than 2GB for testing purposes. 
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 399d4cf37fef4c..a245bf59a1f187 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader.h" +#include #include #include @@ -267,7 +268,7 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, } // namespace -SavedModelBundleInterface::~SavedModelBundleInterface() {} +SavedModelBundleInterface::~SavedModelBundleInterface() = default; Status LoadMetagraphIntoSession(const SessionOptions& session_options, const MetaGraphDef& meta_graph, @@ -491,7 +492,7 @@ Status LoadSavedModel(const SessionOptions& session_options, TF_RETURN_IF_ERROR(LoadSavedModel(rewritten_options, run_options, export_dir, tags, &legacy_bundle)); *bundle = SavedModelBundleLite( - absl::make_unique(std::move(legacy_bundle.session)), + std::make_unique(std::move(legacy_bundle.session)), std::move(*legacy_bundle.meta_graph_def.mutable_signature_def())); return OkStatus(); } diff --git a/tensorflow/cc/saved_model/loader.h b/tensorflow/cc/saved_model/loader.h index f2d318a25b7274..1dcd951d92b5ed 100644 --- a/tensorflow/cc/saved_model/loader.h +++ b/tensorflow/cc/saved_model/loader.h @@ -121,7 +121,7 @@ Status LoadMetagraphIntoSession(const SessionOptions& session_options, Status LoadSavedModel(const SessionOptions& session_options, const RunOptions& run_options, const string& export_dir, const std::unordered_set& tags, - SavedModelBundle* const bundle); + SavedModelBundle* bundle); /// Loads a SavedModel from the specified export directory. The MetaGraphDef /// to be loaded is identified by the supplied tags, corresponding exactly to @@ -133,7 +133,7 @@ Status LoadSavedModel(const SessionOptions& session_options, Status LoadSavedModel(const SessionOptions& session_options, const RunOptions& run_options, const string& export_dir, const std::unordered_set& tags, - SavedModelBundleLite* const bundle); + SavedModelBundleLite* bundle); /// Checks whether the provided directory could contain a SavedModel. Note that /// the method does not load any data by itself. If the method returns `false`, diff --git a/tensorflow/compat_template.__init__.py b/tensorflow/compat_template.__init__.py index 9d2f954293eddc..701623c328081e 100644 --- a/tensorflow/compat_template.__init__.py +++ b/tensorflow/compat_template.__init__.py @@ -16,7 +16,6 @@ # pylint: disable=g-bad-import-order,g-import-not-at-top,protected-access -import logging as _logging import os as _os import sys as _sys import typing as _typing @@ -31,15 +30,6 @@ # Hook external TensorFlow modules. _current_module = _sys.modules[__name__] -try: - from tensorboard.summary._tf import summary - _current_module.__path__ = ( - [_module_util.get_parent_dir(summary)] + _current_module.__path__) - setattr(_current_module, "summary", summary) -except ImportError: - _logging.warning( - "Limited tf.compat.v2.summary API due to missing TensorBoard " - "installation.") # Lazy-load estimator. 
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator" diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 6c276dbedef1f2..92d62b34be8bf9 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -1,8 +1,8 @@ load("//tensorflow:strict.default.bzl", "py_strict_binary") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "filegroup", "genrule") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -131,7 +131,6 @@ genrule( tfcompile_test_dep_configs = [ ("", "None"), ("_mlir_bridge", "Bridge"), - ("_mhlo_lowering", "HloLowering"), ] [ @@ -473,42 +472,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "tfcompile_test_mhlo_lowering", - srcs = ["tfcompile_test.cc"], - extra_copts = ["-DMHLO_LOWERING_TEST"], - tags = [ - "manual", - "no_mac", # TODO(b/228273415) - ], - deps = [ - ":test_graph_tfadd_mhlo_lowering", - ":test_graph_tfadd_with_ckpt_mhlo_lowering", - ":test_graph_tfadd_with_ckpt_saver_mhlo_lowering", - ":test_graph_tfassert_eq_mhlo_lowering", - ":test_graph_tfcond_mhlo_lowering", - ":test_graph_tffunction_mhlo_lowering", - ":test_graph_tfgather_mhlo_lowering", - ":test_graph_tfmatmul_mhlo_lowering", - ":test_graph_tfmatmulandadd_mhlo_lowering", - ":test_graph_tfsplits_mhlo_lowering", - ":test_graph_tftop_k_mhlo_lowering", - ":test_graph_tfvariable_mhlo_lowering", - ":test_graph_tfvariable_readonly_mhlo_lowering", - ":test_graph_tfvariable_sequential_updates_mhlo_lowering", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:regexp", - "@com_google_absl//absl/strings", - "@eigen_archive//:eigen3", - "@local_xla//xla:shape_util", - "@local_xla//xla:test", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/service:hlo_profile_printer", - ], -) - tf_cc_test( name = "tfcompile_test_mlir_bridge", srcs = ["tfcompile_test.cc"], diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index f056533d1b21e6..a543aae5b92997 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -329,16 +329,7 @@ def _tf_library( "@local_xla//xla/service/cpu:runtime_single_threaded_conv2d", "@local_xla//xla/service/cpu:runtime_single_threaded_matmul", "@eigen_archive//:eigen3", - ] or []) + ( - mlir_components.count("HloLowering") > 0 and [ - "@local_xla//xla/runtime:aot_ffi_c_symbols", - "@local_xla//xla/service/cpu:runtime_mlir_utils", - ] or [] - ) + ( - include_standard_runtime_deps and mlir_components == "HloLowering" and [ - "@local_xla//xla/service/cpu/runtime:retain", - ] or [] - ) + (deps or []), + ] or []) + (deps or []), tags = tags, copts = copts, ) @@ -559,31 +550,6 @@ def tf_library( copts, xla_flags, ) - if mlir_components == "None": - _tf_library( - name + "_mlir", - graph, - config, - debug_info, - freeze_checkpoint, - freeze_saver, - cpp_class, - gen_test, - gen_benchmark, - gen_compiler_log, - visibility, - testonly, - tfcompile_flags, - tfcompile_tool, - include_standard_runtime_deps, - enable_xla_hlo_profiling, - enable_tracemes, - "HloLowering", - deps, - tags + ["notap", "local", "manual"], - copts, - xla_flags, - ) def 
target_llvm_triple(): """Returns the target LLVM triple to be used for compiling the target.""" diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 276ab1786a8260..82ed25767b90de 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -283,6 +283,7 @@ void AllocateAndParseFlags() { bool enable_mlir_merge_control_flow_pass = true; bool enable_mlir_convert_control_to_data_outputs_pass = false; bool enable_mlir_strict_clusters = false; + bool enable_mlir_multiple_local_cpu_devices = false; // Dump graphs in TFG dialect. bool use_tfg_graph_dumper = false; bool enable_mlir_generic_outside_compilation = false; @@ -377,6 +378,11 @@ void AllocateAndParseFlags() { "MLIR-Based TensorFlow Compiler Bridge."), Flag("tf_mlir_enable_strict_clusters", &enable_mlir_strict_clusters, "Do not allow clusters that have cyclic control dependencies."), + Flag("tf_mlir_enable_multiple_local_cpu_devices", + &enable_mlir_multiple_local_cpu_devices, + "Enable multiple local CPU devices. CPU ops which are outside " + "compiled inside the tpu cluster will also be replicated across " + "multiple cpu devices."), Flag("tf_dump_graphs_in_tfg", &use_tfg_graph_dumper, "When tf_dump_graphs_in_tfg is true, graphs after transformations " "are dumped in MLIR TFG dialect and not in GraphDef"), @@ -413,6 +419,8 @@ void AllocateAndParseFlags() { enable_mlir_generic_outside_compilation; mlir_flags->tf_mlir_enable_tpu_variable_runtime_reformatting_pass = enable_tpu_variable_runtime_reformatting_pass; + mlir_flags->tf_mlir_enable_multiple_local_cpu_devices = + enable_mlir_multiple_local_cpu_devices; if (use_tfg_graph_dumper) { UseMlirForGraphDump(MlirDumpConfig{}.elide_large_attributes().emit_dialect( diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 04a15136d43072..45a4c83a614afd 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -291,6 +291,9 @@ struct MlirCommonFlags { bool tf_mlir_enable_strict_clusters; bool tf_mlir_enable_generic_outside_compilation; bool tf_mlir_enable_tpu_variable_runtime_reformatting_pass; + // TODO(pineapplejuice233): Revisit this flag once the performance impact is verified + // with different local CPU devices settings. + bool tf_mlir_enable_multiple_local_cpu_devices; }; // Flags for the JitRt pipeline -- see tf_jitrt_pipeline.h for details. 
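For context on the flag added above, here is a hedged sketch of how tf_mlir_enable_multiple_local_cpu_devices would typically be queried from C++ once flags have been parsed. It assumes the existing GetMlirCommonFlags() accessor declared in tensorflow/compiler/jit/flags.h and the usual TF_XLA_FLAGS environment-variable parsing path used by the neighboring tf_mlir_* flags; neither is introduced by this patch.

#include "tensorflow/compiler/jit/flags.h"

// Sketch: read the flag after flag parsing has run. In practice the value is
// usually toggled from the environment, e.g.
//   TF_XLA_FLAGS=--tf_mlir_enable_multiple_local_cpu_devices=true
// (assumption: the flag is registered in the same flag list as the other
// tf_mlir_enable_* flags shown in the hunk above).
bool MultipleLocalCpuDevicesEnabled() {
  return tensorflow::GetMlirCommonFlags()
      ->tf_mlir_enable_multiple_local_cpu_devices;
}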
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index f3272f81fe6182..02c9f486e8e000 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -61,6 +61,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", "//tensorflow/compiler/mlir/tf2xla/internal/passes:clustering_passes", + "//tensorflow/compiler/mlir/tf2xla/internal/passes:mlir_to_graph_passes", "//tensorflow/compiler/mlir/tf2xla/transforms:tf_xla_passes", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", "//tensorflow/compiler/mlir/tosa:tf_passes", diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 8117705b0fac2b..b6d406f040a296 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -319,9 +319,24 @@ cc_library( ], ) +gentbl_cc_library( + name = "tensorflow_lite_canonicalize_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-rewriters"], + "ir/tfl_canonicalize.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ir/tfl_canonicalize.td", + deps = [":tensorflow_lite_patterns_td_files"], +) + cc_library( name = "tensorflow_lite", srcs = [ + "ir/tfl_canonicalize.inc", "ir/tfl_ops.cc", "ir/tfl_ops.cc.inc", "ir/tfl_ops.h.inc", @@ -343,8 +358,10 @@ cc_library( "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", ], deps = [ + ":converter_inc", ":cost_estimators", ":size_utils", + ":tensorflow_lite_canonicalize_inc_gen", ":tensorflow_lite_op_enums_inc_gen", ":tensorflow_lite_op_interfaces_inc_gen", ":tensorflow_lite_ops_inc_gen", @@ -360,6 +377,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:framework", "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", @@ -1314,18 +1332,18 @@ cc_library( ":common", ":fake_quant_utils", ":tensorflow_lite_d2s", - ":tensorflow_lite_legalize_tf", - ":tensorflow_lite_optimize", - ":tensorflow_lite_optimize_batch_matmul", - ":tensorflow_lite_quantize", + ":tensorflow_lite_legalize_tf", # buildcleaner: keep + ":tensorflow_lite_optimize", # buildcleaner: keep + ":tensorflow_lite_optimize_batch_matmul", # buildcleaner: keep + ":tensorflow_lite_quantize", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_quantization_passes", "//tensorflow/compiler/mlir/lite/stablehlo:compose_uniform_quantized_type_pass", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", "//tensorflow/compiler/mlir/lite/stablehlo:rename_entrypoint_to_main", - "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", - "//tensorflow/compiler/mlir/lite/stablehlo:tfl_legalize_hlo", + "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", # buildcleaner: keep + "//tensorflow/compiler/mlir/lite/stablehlo:tfl_legalize_hlo", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/stablehlo:transforms", "//tensorflow/compiler/mlir/lite/stablehlo:uniform_quantized_stablehlo_to_tfl_pass", "//tensorflow/compiler/mlir/tensorflow", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc 
b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 4f044e153c68bb..81f69bfa87e940 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -165,12 +165,15 @@ constexpr size_t kInitialBufferSize = 10240; // `isSigned` is set to false for other types. static StatusOr GetTFLiteType(Type type, bool is_signed = true) { - if (!is_signed && type.isSignlessInteger(8)) { - return tflite::TensorType_UINT8; - } if (!is_signed) { - return Status(absl::StatusCode::kInvalidArgument, - "'isSigned' can only be set for 8-bits integer type"); + if (type.isSignlessInteger(8)) { + return tflite::TensorType_UINT8; + } else if (type.isSignlessInteger(16)) { + return tflite::TensorType_UINT16; + } else { + return Status(absl::StatusCode::kInvalidArgument, + "'isSigned' can only be set for 8/16-bits integer type"); + } } if (type.isF32()) { @@ -535,14 +538,16 @@ class Translator { const std::unordered_set& tags, OpOrArgNameMapper* op_or_arg_name_mapper, const std::map& metadata, - bool serialize_stablehlo_ops); + bool serialize_stablehlo_ops, + std::optional custom_option_alignment); private: enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp }; explicit Translator(ModuleOp module, const toco::TocoFlags& toco_flags, const std::unordered_set& saved_model_tags, OpOrArgNameMapper* op_or_arg_name_mapper, - const std::map& metadata) + const std::map& metadata, + std::optional custom_option_alignment) : module_(module), name_mapper_(*op_or_arg_name_mapper), builder_(kInitialBufferSize), @@ -553,7 +558,8 @@ class Translator { metadata_(metadata), supported_backends_(toco_flags.supported_backends().begin(), toco_flags.supported_backends().end()), - use_buffer_offset_(toco_flags.use_buffer_offset()) { + use_buffer_offset_(toco_flags.use_buffer_offset()), + custom_option_alignment_(custom_option_alignment) { // The first buffer must be empty according to the schema definition. empty_buffer_ = tflite::CreateBuffer(builder_); buffers_.push_back(empty_buffer_); @@ -582,9 +588,10 @@ class Translator { // Returns TFLite buffer populated with constant value if the operation is // TFLite constant operation. Otherwise, returns an empty buffer. Emits error - // and returns std::nullopt on failure. - std::optional> BuildBuffer(Value value, - int index); + // and returns std::nullopt on failure. The buffer index may be changed if + // duplicated buffer is found. + std::optional> BuildBuffer( + Value value, bool can_be_deduplicated, int& index); // Build TFLite tensor from the given type. This function is for tfl.lstm // intermediates, which should have UniformQuantizedType. @@ -675,11 +682,6 @@ class Translator { std::optional>> CreateMetadataVector(); - // Encodes the `tfl.metadata_buffer` array attribute of the module to the - // metadata_buffer section in the final model. Returns empty if there isn't - // such attribute in the mlir module. - VectorBufferOffset CreateMetadataBufferVector(); - // Builds and returns list of tfl.SignatureDef sections in the model. 
std::optional>> CreateSignatureDefs(const std::vector& signature_defs); @@ -751,6 +753,10 @@ class Translator { const std::vector& operands, const std::vector& results); + std::optional> BuildStablehloPadOp( + mlir::stablehlo::PadOp pad_op, const std::vector& operands, + const std::vector& results); + // create a subgraph given a unnamed mlir region, return the corresponding // subgraph index int32_t UnnamedRegionToSubgraph(mlir::Region* region, @@ -837,6 +843,12 @@ class Translator { bool use_buffer_offset_ = false; bool require_use_buffer_offset_ = false; + + std::optional custom_option_alignment_ = std::nullopt; + + // Map from mlir constant attribute to the buffer index. This is used to + // deduplicate the buffers in the flatbuffer. + llvm::DenseMap const_attribute_to_buffer_map_; }; bool Translator::EstimateArithmeticCount(int64_t* count) { @@ -860,7 +872,7 @@ std::string Translator::UniqueName(mlir::Value val) { } std::optional> Translator::BuildBuffer( - mlir::Value value, int index) { + mlir::Value value, bool can_be_deduplicated, int& index) { auto inst = value.getDefiningOp(); ElementsAttr attr; if (auto cst = dyn_cast(inst)) { @@ -883,6 +895,15 @@ std::optional> Translator::BuildBuffer( return empty_buffer_; } + if (can_be_deduplicated) { + if (const_attribute_to_buffer_map_.find(attr) != + const_attribute_to_buffer_map_.end()) { + index = const_attribute_to_buffer_map_[attr]; + return empty_buffer_; + } + const_attribute_to_buffer_map_[attr] = index; + } + // TF doesn't currently support 4-bit types (DT_INT4), so we'll run into // trouble calling ConvertToTensor(). For now, extract the tensor data from // ElementsAttr directly in this and read type from tflite::TensorType instead @@ -1168,6 +1189,13 @@ std::optional> Translator::BuildTensor( break; } } + // The value is used as a variable if produced by an op with "tfl.is_variable" + // attribute. This provides a hook for the user to represent the variable + // tensor in the MLIR level. 
+ if (auto* inst = value.getDefiningOp(); + inst && inst->hasAttr("tfl.is_variable")) { + is_variable = true; + } bool has_rank = type.hasRank(); @@ -1296,11 +1324,16 @@ BufferOffset Translator::BuildCustomOperator( /*builtin_options=*/0, /*custom_options=*/0, tflite::CustomOptionsFormat_FLEXBUFFERS); } + if (custom_option_alignment_.has_value()) { + builder_.ForceVectorAlignment(custom_option_vector.size(), sizeof(uint8_t), + custom_option_alignment_.value()); + } + auto custom_option_fbs_vector = + builder_.CreateVector(custom_option_vector); return tflite::CreateOperator( builder_, opcode_index, builder_.CreateVector(operands), builder_.CreateVector(results), tflite::BuiltinOptions_NONE, - /*builtin_options=*/0, - builder_.CreateVector(custom_option_vector), + /*builtin_options=*/0, custom_option_fbs_vector, tflite::CustomOptionsFormat_FLEXBUFFERS); } @@ -1603,6 +1636,30 @@ Translator::BuildStablehloRngBitGeneratorOp( rng_options.Union()); } +std::optional> Translator::BuildStablehloPadOp( + mlir::stablehlo::PadOp pad_op, const std::vector& operands, + const std::vector& results) { + std::string op_name = pad_op->getName().getStringRef().str(); + uint32_t opcode_index = + GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_PAD); + + auto edge_padding_low = + builder_.CreateVector(pad_op.getEdgePaddingLow().vec()); + auto edge_padding_high = + builder_.CreateVector(pad_op.getEdgePaddingHigh().vec()); + auto interior_padding = + builder_.CreateVector(pad_op.getInteriorPadding().vec()); + + auto pad_option = tflite::CreateStablehloPadOptions( + builder_, edge_padding_low, edge_padding_high, interior_padding); + + return tflite::CreateOperator( + builder_, opcode_index, builder_.CreateVector(operands), + builder_.CreateVector(results), tflite::BuiltinOptions_NONE, 0, 0, + tflite::CustomOptionsFormat_FLEXBUFFERS, 0, 0, 0, 0, + tflite::BuiltinOptions2_StablehloPadOptions, pad_option.Union()); +} + std::optional> Translator::BuildOperator( Operation* inst, std::vector operands, const std::vector& results, @@ -1704,6 +1761,9 @@ std::optional> Translator::BuildOperator( return BuildStablehloOperatorwithoutOptions( inst, operands, results, tflite::BuiltinOperator_STABLEHLO_MINIMUM); } + if (auto shlo_op = llvm::dyn_cast(inst)) { + return BuildStablehloPadOp(shlo_op, operands, results); + } // for ops don't have kernels, only serialize when conversion is set to true if (convert_stablehlo_) { if (auto shlo_op = llvm::dyn_cast(inst)) { @@ -1817,8 +1877,7 @@ std::optional> Translator::BuildOperator( uint32_t opcode_index = GetOpcodeIndex( op_name, tflite::BuiltinOperator_STABLEHLO_DYNAMIC_SLICE); - auto slice_sizes = builder_.CreateVector( - mlir::GetOptionalVector(shlo_op.getSliceSizes())); + auto slice_sizes = builder_.CreateVector(shlo_op.getSliceSizes().vec()); auto dynamic_slice_option = tflite::CreateStablehloDynamicSliceOptions(builder_, slice_sizes); @@ -1854,27 +1913,6 @@ std::optional> Translator::BuildOperator( tflite::BuiltinOptions2_StablehloCompareOptions, compare_option.Union()); } - if (auto shlo_op = llvm::dyn_cast(inst)) { - std::string op_name = inst->getName().getStringRef().str(); - uint32_t opcode_index = - GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_PAD); - - auto edge_padding_low = builder_.CreateVector( - mlir::GetOptionalVector(shlo_op.getEdgePaddingLowAttr())); - auto edge_padding_high = builder_.CreateVector( - mlir::GetOptionalVector(shlo_op.getEdgePaddingHighAttr())); - auto interior_padding = builder_.CreateVector( - 
mlir::GetOptionalVector(shlo_op.getInteriorPaddingAttr())); - - auto pad_option = tflite::CreateStablehloPadOptions( - builder_, edge_padding_low, edge_padding_high, interior_padding); - - return tflite::CreateOperator( - builder_, opcode_index, builder_.CreateVector(operands), - builder_.CreateVector(results), tflite::BuiltinOptions_NONE, 0, 0, - tflite::CustomOptionsFormat_FLEXBUFFERS, 0, 0, 0, 0, - tflite::BuiltinOptions2_StablehloPadOptions, pad_option.Union()); - } if (auto shlo_op = llvm::dyn_cast(inst)) { std::string op_name = inst->getName().getStringRef().str(); uint32_t opcode_index = GetOpcodeIndex( @@ -1895,12 +1933,11 @@ std::optional> Translator::BuildOperator( uint32_t opcode_index = GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_SLICE); - auto start_indices = builder_.CreateVector( - mlir::GetOptionalVector(shlo_op.getStartIndicesAttr())); - auto limit_indices = builder_.CreateVector( - mlir::GetOptionalVector(shlo_op.getLimitIndicesAttr())); - auto strides = builder_.CreateVector( - mlir::GetOptionalVector(shlo_op.getStridesAttr())); + auto start_indices = + builder_.CreateVector(shlo_op.getStartIndices().vec()); + auto limit_indices = + builder_.CreateVector(shlo_op.getLimitIndices().vec()); + auto strides = builder_.CreateVector(shlo_op.getStrides().vec()); auto slice_option = tflite::CreateStablehloSliceOptions( builder_, start_indices, limit_indices, strides); @@ -2172,8 +2209,7 @@ std::optional> Translator::BuildOperator( op_name, tflite::BuiltinOperator_STABLEHLO_TRANSPOSE); auto transpose_option = tflite::CreateStablehloTransposeOptions( - builder_, builder_.CreateVector(mlir::GetOptionalVector( - shlo_op.getPermutation()))); + builder_, builder_.CreateVector(shlo_op.getPermutation().vec())); return tflite::CreateOperator( builder_, opcode_index, builder_.CreateVector(operands), @@ -2394,22 +2430,31 @@ std::optional> Translator::BuildSubGraph( quant_parameters = GetQuantizationForQuantStatsOpOutput(stats_op); } } - auto tensor_or = - BuildTensor(value, tensor_name, buffers_.size(), quant_parameters); - if (!tensor_or) return false; - tensors.push_back(*tensor_or); + int buffer_index = buffers_.size(); + // If a constant is returned as subgraph's output, this constant cannot be + // deduplicated. + const bool not_returned_by_subgraph = llvm::none_of( + value.getUsers(), + [](Operation* user) { return llvm::isa(user); }); // TODO(ashwinm): Check if for stateful tensors, if it is also needed to // make the Buffer empty apart from setting the buffer_idx=0 in the // Tensor. This does not seem to affect runtime behavior for RNN/LSTM, // but would be good for reducing memory footprint. 
if (value.getDefiningOp()) { - auto buffer_or = BuildBuffer(value, buffers_.size()); + auto buffer_or = + BuildBuffer(value, not_returned_by_subgraph, buffer_index); if (!buffer_or) return false; buffers_.push_back(*buffer_or); } else { buffers_.push_back(empty_buffer_); } + + auto tensor_or = + BuildTensor(value, tensor_name, buffer_index, quant_parameters); + if (!tensor_or) return false; + tensors.push_back(*tensor_or); + return true; }; @@ -2625,18 +2670,6 @@ Translator::CreateMetadataVector() { return builder_.CreateVector(metadata); } -VectorBufferOffset Translator::CreateMetadataBufferVector() { - auto array_attr = - module_->getAttrOfType("tfl.metadata_buffer"); - std::vector metadata_buffer; - if (!array_attr) return 0; - for (auto value : array_attr.getAsValueRange()) { - metadata_buffer.push_back(value.getSExtValue()); - } - - return builder_.CreateVector(metadata_buffer); -} - // Helper method that returns list of all strings in a StringAttr identified // by 'attr_key' and values are separated by a comma. llvm::SmallVector GetStringsFromAttrWithSeparator( @@ -2824,21 +2857,23 @@ std::optional Translator::Translate( const std::unordered_set& tags, OpOrArgNameMapper* op_or_arg_name_mapper, const std::map& metadata, - bool serialize_stablehlo_ops) { + bool serialize_stablehlo_ops, + std::optional custom_option_alignment) { OpOrArgLocNameMapper default_op_or_arg_name_mapper; if (!op_or_arg_name_mapper) op_or_arg_name_mapper = &default_op_or_arg_name_mapper; if (!UpdateEntryFunction(module)) return std::nullopt; if (!IsValidTFLiteMlirModule(module)) return std::nullopt; Translator translator(module, toco_flags, tags, op_or_arg_name_mapper, - metadata); + metadata, custom_option_alignment); translator.convert_stablehlo_ = serialize_stablehlo_ops; auto ret = translator.TranslateInternal(); if (translator.require_use_buffer_offset_) { auto new_toco_flags = toco_flags; new_toco_flags.set_use_buffer_offset(true); Translator new_translator(module, new_toco_flags, tags, - op_or_arg_name_mapper, metadata); + op_or_arg_name_mapper, metadata, + custom_option_alignment); return new_translator.TranslateInternal(); } return ret; @@ -3039,8 +3074,7 @@ std::optional Translator::TranslateInternal() { // Build the model and finish the model building process. 
auto description = builder_.CreateString(model_description.data()); - VectorBufferOffset metadata_buffer = - CreateMetadataBufferVector(); // Deprecated + VectorBufferOffset metadata_buffer = 0; // Deprecated auto metadata = CreateMetadataVector(); if (!metadata) return std::nullopt; @@ -3131,6 +3165,10 @@ void Translator::AppendBufferData(std::string& result) { for (auto& it : custom_op_data_map_) { while (result.size() % 16 != 0) result += '\0'; + if (custom_option_alignment_.has_value()) { + while (result.size() % custom_option_alignment_.value() != 0) + result += '\0'; + } auto buffer = std::string(it.second.begin(), it.second.end()); int64_t offset = result.size(); int64_t size = it.second.size(); @@ -3345,7 +3383,8 @@ bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module, bool serialize_stablehlo_ops) { auto maybe_translated = Translator::Translate( module, options.toco_flags, options.saved_model_tags, - options.op_or_arg_name_mapper, options.metadata, serialize_stablehlo_ops); + options.op_or_arg_name_mapper, options.metadata, serialize_stablehlo_ops, + options.custom_option_alignment); if (!maybe_translated) return false; *serialized_flatbuffer = std::move(*maybe_translated); return true; diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.h b/tensorflow/compiler/mlir/lite/flatbuffer_export.h index b279c113c94a2a..cd461c96115375 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ +#include #include +#include #include #include @@ -42,6 +44,10 @@ struct FlatbufferExportOptions { // OpOrArgNameMapper to convert location of the op to name in flatbuffer. // If not set, a default mapper will be used. tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper = nullptr; + // User-specified value of flatbuffer alignment requirement for custom + // options. If specified, the value should be multiplier of 16 (default + // alignment for TFL flatbuffer). + std::optional custom_option_alignment = std::nullopt; }; // Translates the given MLIR `module` into a FlatBuffer and stores the diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 6eb2aee99aacb9..69dd8ad342cfe1 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -604,10 +604,10 @@ static mlir::ElementsAttr GetSplat(RankedTensorType type, int unique_index, } // TODO(b/172664358): Creates a new op instead of reusing constant op. -// Creates a constant op to represent stateful variable. The function static -// variable `stateful_variable_idx` is used as a unique value for each constant -// to avoid CSEed. `tensor` is the data structure of flatbuffer. `shaped_type` -// is the ShapedType for the const op. +// Creates a constant op with "tfl.is_variable" attribute to represent stateful +// variable. The function static variable `stateful_variable_idx` is used as a +// unique value for each constant to avoid CSEed. `tensor` is the data structure +// of flatbuffer. `shaped_type` is the ShapedType for the const op. 
StatusOr BuildVariableOp(const tflite::TensorT& tensor, OpBuilder builder, Location loc) { TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder, @@ -626,6 +626,7 @@ StatusOr BuildVariableOp(const tflite::TensorT& tensor, return op.getOperation(); } auto op = builder.create(loc, value); + op->setAttr("tfl.is_variable", builder.getUnitAttr()); if (tensor.quantization && !tensor.quantization->min.empty()) { if (auto stats_op = ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) { @@ -1904,11 +1905,6 @@ OwningOpRef tflite::FlatBufferToMlir( mlir::UnitAttr::get(builder.getContext())); } - if (!model->metadata_buffer.empty()) { - module->setAttr("tfl.metadata_buffer", - builder.getI32ArrayAttr(model->metadata_buffer)); - } - if (use_stablehlo_constant) { module->setAttr("tfl.metadata", builder.getDictionaryAttr(builder.getNamedAttr( diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index 8d97a1e1f2b349..b51d1b1d7019c5 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -285,6 +285,14 @@ static mlir::Attribute BuildRankedTensorAttr(std::vector shape, return mlir::DenseIntElementsAttr::get(ty, value); } +static mlir::Attribute BuildI64ArrayAttr(std::vector shape, + std::vector value, + mlir::Builder builder) { + // Expand splats. BuildI64ArrayAttr assumes shape.size() == 1. + if (value.size() == 1) value.resize(shape[0], value[0]); + return builder.getDenseI64ArrayAttr(value); +} + static mlir::Attribute BuildF32ArrayAttr(std::vector value, mlir::Builder builder) { std::vector typecast(value.begin(), value.end()); @@ -400,13 +408,11 @@ void BuiltinOptions2ToAttributesManual( std::vector shape = { static_cast(op->start_indices.size())}; attributes.emplace_back(builder.getNamedAttr( - "start_indices", - BuildRankedTensorAttr(shape, op->start_indices, builder))); + "start_indices", BuildI64ArrayAttr(shape, op->start_indices, builder))); attributes.emplace_back(builder.getNamedAttr( - "limit_indices", - BuildRankedTensorAttr(shape, op->limit_indices, builder))); + "limit_indices", BuildI64ArrayAttr(shape, op->limit_indices, builder))); attributes.emplace_back(builder.getNamedAttr( - "strides", BuildRankedTensorAttr(shape, op->strides, builder))); + "strides", BuildI64ArrayAttr(shape, op->strides, builder))); return; } if (const auto* op = op_union.AsStablehloConvolutionOptions()) { @@ -496,20 +502,20 @@ void BuiltinOptions2ToAttributesManual( static_cast(op->edge_padding_low.size())}; attributes.emplace_back(builder.getNamedAttr( "edge_padding_low", - BuildRankedTensorAttr(shape, op->edge_padding_low, builder))); + BuildI64ArrayAttr(shape, op->edge_padding_low, builder))); attributes.emplace_back(builder.getNamedAttr( "edge_padding_high", - BuildRankedTensorAttr(shape, op->edge_padding_high, builder))); + BuildI64ArrayAttr(shape, op->edge_padding_high, builder))); attributes.emplace_back(builder.getNamedAttr( "interior_padding", - BuildRankedTensorAttr(shape, op->interior_padding, builder))); + BuildI64ArrayAttr(shape, op->interior_padding, builder))); return; } if (const auto* op = op_union.AsStablehloDynamicSliceOptions()) { attributes.emplace_back(builder.getNamedAttr( "slice_sizes", - BuildRankedTensorAttr({static_cast(op->slice_sizes.size())}, - op->slice_sizes, builder))); + BuildI64ArrayAttr({static_cast(op->slice_sizes.size())}, + op->slice_sizes, builder))); return; } if (const auto* op = op_union.AsStablehloCompareOptions()) { @@ 
-623,8 +629,8 @@ void BuiltinOptions2ToAttributesManual( if (!op->permutation.empty()) { attributes.emplace_back(builder.getNamedAttr( "permutation", - BuildRankedTensorAttr({static_cast(op->permutation.size())}, - op->permutation, builder))); + BuildI64ArrayAttr({static_cast(op->permutation.size())}, + op->permutation, builder))); } return; diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td b/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td new file mode 100644 index 00000000000000..d9200ddc70f112 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td @@ -0,0 +1,56 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the optimization pattern definition file for TensorFlow Lite. + +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" +include "tensorflow/compiler/mlir/lite/utils/utils.td" + +// Returns Squeezed shape of a ranked-tensor. +// Squeezed, here, means eliminating any 1s' in the +// dimensions of the tensor +def GetSqueezedShape: NativeCodeCall<"GetSqueezedShape($0)">; + +// This is a utility function to deduct the effective permutation to apply on +// TFL_TransposeOp when the tensor has some dimensions with value==1 +def GetSqueezedPermutation: NativeCodeCall<"GetSqueezedPermutation($0, $1)">; + +// Check to see if the tensor dimensions can be Squeezed by eliminating 1s' +def CanSqueezeTensor : Constraint GetSqueezedShape($0).getNumElements()">>; + + +// Pattern to convert TFL_TransposeOp with rank>6 to rank<=6 if there are +// redundant dimensions in the tensor. For example- [2x1x3] == [2x3] and 1 is +// not contributing to the dimentionality. 
This will run if the rank>6 // Pattern will convert- // %0 = "tfl.transpose"(%arg0, %cst) : (tensor<56x8x56x1x1x1x7xf32>, tensor<7xi32>) -> tensor<1x1x8x56x56x7x1xf32> // to- // %0 = "tfl.reshape"(%arg0, %cst) : (tensor<56x8x56x1x1x1x7xf32>, tensor<4xi32>) -> tensor<56x8x56x7xf32> // %1 = "tfl.transpose"(%0, %cst_0) : (tensor<56x8x56x7xf32>, tensor<4xi32>) -> tensor<8x56x56x7xf32> // %2 = "tfl.reshape"(%1, %cst_1) : (tensor<8x56x56x7xf32>, tensor<7xi32>) -> tensor<1x1x8x56x56x7x1xf32> def ConvertTransposeToDecreaseRank : Pat< (TFL_TransposeOp:$output_transpose $input, (Arith_ConstantOp:$permutation $_)), (TFL_ReshapeOp (TFL_TransposeOp (TFL_ReshapeOp $input, (Arith_ConstantOp (GetSqueezedShape $input))), (Arith_ConstantOp (GetSqueezedPermutation $input, $permutation))), (Arith_ConstantOp (GetShape $output_transpose))), [(AnyStaticShapeTensor $input), (HasRankAtLeast<7> $input), (CanSqueezeTensor $input)]>; diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 939d840f404445..779f8580c7144a 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/escaping.h" #include "Eigen/Core" // from @eigen_archive #include "llvm/ADT/APFloat.h" @@ -146,6 +147,74 @@ Operation* getDefiningBroadcastArgsOp(Value operand) { } return parent_of_defining_op; } + +// Returns the shape of a ranked tensor. +// Precondition: value_tensor is a ranked tensor. +// Returns a squeezed shape. Squeezing here means eliminating the redundant +// dimensions of size 1. +DenseElementsAttr GetSqueezedShape(Value value_tensor) { + auto value_shape_type = value_tensor.getType().dyn_cast(); + assert(value_shape_type.hasRank() && "value_tensor should be ranked tensor"); + + auto value_shape = value_shape_type.getShape(); + SmallVector return_squeeze_shape; + return_squeeze_shape.reserve(value_shape.size()); + + for (size_t dim_idx = 0; dim_idx < value_shape.size(); ++dim_idx) { + int64_t dim = value_shape[dim_idx]; + if (dim == 1) { + continue; + } + return_squeeze_shape.push_back( + ShapedType::isDynamic(dim) ? -1 : static_cast(dim)); + } + + return mlir::DenseElementsAttr::get( + RankedTensorType::get( + {static_cast(return_squeeze_shape.size())}, + mlir::IntegerType::get(value_tensor.getContext(), 32)), + llvm::ArrayRef(return_squeeze_shape)); +} + +// This is a utility function to deduce the effective permutation to apply on +// TFL_TransposeOp when the tensor has some dimensions with value==1 +// Example- "tfl.transpose"(tensor<56x8x56x1x1x1x7xf32>, [4, 5, 1, 2, 0, 6, 3]) +// The permutation [4, 5, 1, 2, 0, 6, 3] before the squeeze becomes [1, 2, 0, 3] +// after the squeeze is performed, to retain the relative ordering of the non-1 dims.
+DenseElementsAttr GetSqueezedPermutation(Value input_value, + Value input_permutation) { + auto input_shape = input_value.getType().dyn_cast().getShape(); + absl::flat_hash_map permutation_map; + + for (size_t before_dim_idx = 0, after_dim_idx = 0; + before_dim_idx < input_shape.size(); ++before_dim_idx) { + if (input_shape[before_dim_idx] == 1) { + continue; + } + permutation_map.insert({before_dim_idx, after_dim_idx++}); + } + + SmallVector squeezed_permutation; + DenseElementsAttr input_perm_const; + if (matchPattern(input_permutation, m_Constant(&input_perm_const))) { + for (int32_t idx = 0; idx < input_perm_const.getNumElements(); ++idx) { + size_t perm = input_perm_const.getValues()[idx].getSExtValue(); + if (input_shape[perm] == 1) { + continue; + } + squeezed_permutation.push_back(permutation_map[perm]); + } + } + + return mlir::DenseElementsAttr::get( + RankedTensorType::get( + {static_cast(squeezed_permutation.size())}, + mlir::IntegerType::get(input_permutation.getContext(), 32)), + llvm::ArrayRef(squeezed_permutation)); +} + +#include "tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.inc" + } // namespace // Returns true when the given type lists contain a single element of shaped @@ -3447,6 +3516,11 @@ void ComputePermutation(ArrayRef perms, ArrayRef output_shape, } // namespace +void TransposeOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +} + OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index 73740be2310ef7..380301a9cbee40 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -50,6 +50,7 @@ typedef TFLDialect TensorFlowLiteDialect; class ControlType : public Type::TypeBase { public: using Base::Base; + static constexpr StringLiteral name = "tfl.control"; }; #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc" diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 4b915afcb0603d..45a1e3b25e1335 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -305,6 +305,14 @@ class TFL_OperandHasRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", TFL_OperandHasRankAtMostPred>; +// Not all dimentions in the tensor will contribute to the data move in a +// TransposeOp. 
Effective rank is the number of dimentions != 1 +class TFL_TransposeOperandHasEffectiveRankAtMost : + PredOpTrait<"operand " # n # " is at most " # m # "-D", + Or<[TFL_OperandIsUnrankedPred, + CPred<"GetSqueezedShape($_op.getOperand(" # n # + ")).cast().size() <= " # m>]>>; + class TFL_OperandHasRankAtLeast : PredOpTrait<"operand " # n # " is at least " # m # "-D", Or<[TFL_OperandIsUnrankedPred, @@ -1211,12 +1219,12 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I16, I64, I32, UI8, TFL_Str]>:$params, + TFL_TensorOf<[F32, I1, I8, I16, I64, I32, UI8, TFL_Str]>:$params, TFL_TensorOf<[I16, I32, I64]>:$indices ); let results = (outs - TFL_TensorOf<[F32, I8, I16, I64, I32, UI8, TFL_Str]>:$output + TFL_TensorOf<[F32, I1, I8, I16, I64, I32, UI8, TFL_Str]>:$output ); } @@ -3488,7 +3496,7 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [ def TFL_TransposeOp : TFL_Op<"transpose", [ Pure, QuantizableResult, - TFL_OperandHasRankAtMost<0, 6>, + TFL_TransposeOperandHasEffectiveRankAtMost<0, 6>, TFL_OperandHasRank<1, 1>, PredOpTrait<"input and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, @@ -3512,6 +3520,8 @@ def TFL_TransposeOp : TFL_Op<"transpose", [ let hasFolder = 1; + let hasCanonicalizer = 1; + let builders = [ OpBuilder<(ins "Value":$input, "Value":$perm), [{ BuildTransposeOp(&$_builder, $_state, input, perm); }]> diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 57a5c93556c4cf..62c2733d2b510c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -24,10 +25,12 @@ limitations under the License. #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -777,6 +780,14 @@ void QuantizationDriver::PreprocessConstantOps() { auto type = cst.getType().dyn_cast(); if (!type || !type.getElementType().isa()) return; + // Skip if the value is NaN or INF. + // Otherwise the illegal scale/zp will be calculated. 
+ auto float_attr = cst.getValueAttr().dyn_cast(); + if (float_attr) { + auto cst_float_falue = float_attr.getValues()[0]; + if (!cst_float_falue.isFinite()) return; + } + Value value = cst.getResult(); builder_.setInsertionPoint(cst); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 2459f3d214d13a..152b48b1f9043a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -454,6 +454,7 @@ cc_library( deps = [ ":passes_inc_gen", "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/quantization/stablehlo:uniform_quantized_types", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", @@ -523,7 +524,6 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:custom_call", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:dot_general", - "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:pad", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:util", "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//mlir:ArithDialect", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-pad.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-pad.mlir index f5f69b1cf18340..bffb1da2b07117 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-pad.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-pad.mlir @@ -3,9 +3,9 @@ module { func.func @main(%arg0: tensor<8x128xf32>, %arg1: tensor) -> tensor<11x131xf32> { %0 = "stablehlo.pad"(%arg0, %arg1) { - edge_padding_low = dense<[1, 0]> : tensor<2xi64>, - edge_padding_high = dense<[2, 3]> : tensor<2xi64>, - interior_padding = dense<0> : tensor<2xi64> + edge_padding_low = array, + edge_padding_high = array, + interior_padding = array } : (tensor<8x128xf32>, tensor) -> tensor<11x131xf32> func.return %0 : tensor<11x131xf32> } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-pad.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-pad.mlir index 482a7f9e176977..1d47c5c6382837 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-pad.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-pad.mlir @@ -9,8 +9,8 @@ module { // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<8x128xf32>, %arg1: tensor) -> tensor<11x131xf32> { -// CHECK-NEXT: %0 = stablehlo.pad %arg0, %arg1, low = [1, 0], high = [2, 3], interior = [0, 0] : (tensor<8x128xf32>, tensor) -> tensor<11x131xf32> -// CHECK-NEXT: return %0 : tensor<11x131xf32> -// CHECK-NEXT: } +// CHECK-NEXT: %0 = stablehlo.pad %arg0, %arg1, low = [1, 0], high = [2, 3], interior = [0, 0] : (tensor<8x128xf32>, tensor) -> tensor<11x131xf32> +// CHECK-NEXT: return %0 : tensor<11x131xf32> // CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir index 593cdbf4fa8b4d..b2948e59fae0b7 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir @@ -1786,7 +1786,7 @@ func.func @round_nearest_even(%arg0: tensor<2xf32>) -> tensor<2xf32> { // 
CHECK-SAME: %[[VAL_1:.*]]: tensor<256xf32>) -> tensor<1xf32> { // CHECK: %[[VAL_2:.*]] = arith.constant dense<[256, 1]> : tensor<2xi64> // CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_2]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32> -// CHECK: %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_0]], %[[VAL_3]]) <{adj_x = false, adj_y = false}> : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_0]], %[[VAL_3]]) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> // CHECK: %[[VAL_5:.*]] = arith.constant dense<1> : tensor<1xi64> // CHECK: %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> // CHECK: return %[[VAL_6]] : tensor<1xf32> @@ -1803,7 +1803,7 @@ func.func @convert_dot_2d_1d(%arg0: tensor<1x256xf32>, %arg1: tensor<256xf32>) - // CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> // CHECK: %[[VAL_4:.*]] = arith.constant dense<[256, 1]> : tensor<2xi64> // CHECK: %[[VAL_5:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32> -// CHECK: %[[VAL_6:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_5]]) <{adj_x = false, adj_y = false}> : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_5]]) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> // CHECK: %[[VAL_7:.*]] = arith.constant dense<> : tensor<0xi64> // CHECK: %[[VAL_8:.*]] = "tf.Reshape"(%[[VAL_6]], %[[VAL_7]]) : (tensor<1x1xf32>, tensor<0xi64>) -> tensor // CHECK: return %[[VAL_8]] : tensor @@ -1816,7 +1816,7 @@ func.func @convert_dot_1d_1d(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>) -> // CHECK-LABEL: func @convert_dot_2d_2d( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x256xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<256x1xf32>) -> tensor<1x1xf32> { -// CHECK: %[[VAL_2:.*]] = "tf.BatchMatMulV3"(%[[VAL_0]], %[[VAL_1]]) <{adj_x = false, adj_y = false}> : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_2:.*]] = "tf.BatchMatMulV3"(%[[VAL_0]], %[[VAL_1]]) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> // CHECK: return %[[VAL_2]] : tensor<1x1xf32> // CHECK: } func.func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1x1xf32> { @@ -1895,7 +1895,7 @@ func.func @dynamic_broadcast_in_dim_general_case_expand_middle_dim(%arg0: tensor // CHECK: %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_3]], %[[VAL_6]]) : (tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32> // CHECK: %[[VAL_8:.*]] = arith.constant dense<[3, 12, 4]> : tensor<3xi64> // CHECK: %[[VAL_9:.*]] = "tf.Reshape"(%[[VAL_5]], %[[VAL_8]]) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32> -// CHECK: %[[VAL_10:.*]] = "tf.BatchMatMulV3"(%[[VAL_7]], %[[VAL_9]]) <{adj_x = false, adj_y = false}> : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> +// CHECK: %[[VAL_10:.*]] = "tf.BatchMatMulV3"(%[[VAL_7]], %[[VAL_9]]) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> // CHECK: %[[VAL_11:.*]] = arith.constant dense<[3, 5, 1, 4]> : tensor<4xi64> // CHECK: %[[VAL_12:.*]] = "tf.Reshape"(%[[VAL_10]], %[[VAL_11]]) : 
(tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32> // CHECK: return %[[VAL_12]] : tensor<3x5x1x4xf32> @@ -1929,7 +1929,7 @@ func.func @quantized_dot_general_not_converted(%arg0: tensor<1x1x512xf32>, %arg1 // CHECK-SAME: %[[VAL_1:.*]]: tensor<1024x1024xf32>) -> tensor<1x1x1024xf32> { // CHECK: %[[VAL_2:.*]] = arith.constant dense<[1, 1024]> : tensor<2xi64> // CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : {{.*}} -> tensor<1x1024xf32> -// CHECK: %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_1]]) <{adj_x = false, adj_y = false}> : {{.*}} -> tensor<1x1024xf32> +// CHECK: %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_1]]) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : {{.*}} -> tensor<1x1024xf32> // CHECK: %[[VAL_5:.*]] = arith.constant dense<[1, 1, 1024]> : tensor<3xi64> // CHECK: %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : {{.*}} -> tensor<1x1x1024xf32> // CHECK: return %[[VAL_6]] : tensor<1x1x1024xf32> @@ -1952,7 +1952,7 @@ func.func @convert_dot_general_repeated(%arg0: tensor<1x1x1024xf32>, %arg1: tens // CHECK-SAME: %[[VAL_1:.*]]: tensor<256x8xi8>) -> tensor<8xi32> { // CHECK: %[[VAL_2:.*]] = arith.constant dense<[1, 256]> : tensor<2xi64> // CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<256xi8>, tensor<2xi64>) -> tensor<1x256xi8> -// CHECK: %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_1]]) <{adj_x = false, adj_y = false}> : (tensor<1x256xi8>, tensor<256x8xi8>) -> tensor<1x8xi32> +// CHECK: %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_1]]) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<1x256xi8>, tensor<256x8xi8>) -> tensor<1x8xi32> // CHECK: %[[VAL_5:.*]] = arith.constant dense<8> : tensor<1xi64> // CHECK: %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x8xi32>, tensor<1xi64>) -> tensor<8xi32> // CHECK: return %[[VAL_6]] : tensor<8xi32> @@ -1982,7 +1982,7 @@ func.func @convert_dot_general_int8(%arg0: tensor<256xi8>, %arg1: tensor<256x8xi // CHECK-DAG: %cst_4 = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: %4 = "tf.Concat"(%cst_4, %cst_3, %3, %2) : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> // CHECK: %5 = "tf.Reshape"(%0, %4) : (tensor<4x256x?xf32>, tensor<3xi32>) -> tensor<4x256x?xf32> -// CHECK: %6 = "tf.BatchMatMulV3"(%arg0, %5) <{adj_x = false, adj_y = false}> : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32> +// CHECK: %6 = "tf.BatchMatMulV3"(%arg0, %5) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32> // CHECK: %7 = "tf.Shape"(%arg0) : (tensor<4x4x256xf32>) -> tensor<3xi32> // CHECK: %8 = "tf.Shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> // CHECK-DAG: %cst_5 = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> @@ -2032,7 +2032,7 @@ func.return %0 : tensor<4x4x?xf32> // CHECK-DAG: %cst_9 = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: %11 = "tf.Concat"(%cst_9, %10, %9, %8) : (tensor, tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> // CHECK: %12 = "tf.Reshape"(%0, %11) : (tensor<2x?x3x4xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32> -// CHECK: %13 = "tf.BatchMatMulV3"(%6, %12) <{adj_x = false, adj_y = false}> : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32> +// CHECK: %13 = "tf.BatchMatMulV3"(%6, %12) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<2x?x2x3xf32>, 
tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32> // CHECK: %14 = "tf.Shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32> // CHECK: %15 = "tf.Shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32> // CHECK-DAG: %cst_10 = "tf.Const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> @@ -2080,7 +2080,7 @@ func.return %0 : tensor<2x?x2x4xf32> // CHECK-DAG: %cst_9 = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: %9 = "tf.Concat"(%cst_9, %cst_8, %8, %7) : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> // CHECK: %10 = "tf.Reshape"(%0, %9) : (tensor<2x3x4x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32> -// CHECK: %11 = "tf.BatchMatMulV3"(%5, %10) <{adj_x = false, adj_y = false}> : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32> +// CHECK: %11 = "tf.BatchMatMulV3"(%5, %10) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32> // CHECK: %12 = "tf.Shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32> // CHECK: %13 = "tf.Shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32> // CHECK-DAG: %cst_10 = "tf.Const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> @@ -2126,7 +2126,7 @@ func.return %0 : tensor<2x2x?x4x?xf32> // CHECK-DAG: %cst_8 = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: %8 = "tf.Concat"(%cst_8, %cst_7, %7, %6) : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> // CHECK: %9 = "tf.Reshape"(%arg1, %8) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x?x256xf32> -// CHECK: %10 = "tf.BatchMatMulV3"(%4, %9) <{adj_x = false, adj_y = false}> : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32> +// CHECK: %10 = "tf.BatchMatMulV3"(%4, %9) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32> // CHECK: return %10 : tensor<4x4x256xf32> // CHECK: } func.func @convert_dot_general_dynamic_contracting_dim(%arg0: tensor<4x4x?xf32>, %arg1: tensor<4x?x256xf32>) -> tensor<4x4x256xf32> { @@ -3742,6 +3742,26 @@ func.func @convert_gather(%arg0: tensor<147456xf16>, %arg1: tensor<192x256x1xi32 func.return %0 : tensor<192x256xf16> } +// CHECK-LABEL: func @convert_gather_with_ui32indices( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<147456xf16>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<192x256x1xui32>) +// CHECK: %[[INDICES:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<192x256x1xui32>) -> tensor<192x256x1xi64> +// CHECK: %[[VAL_0:.*]] = "tf.GatherNd"(%[[ARG_0]], %[[INDICES]]) : {{.*}} -> tensor<192x256xf16> +// CHECK: return %[[VAL_0]] +// CHECK: } +func.func @convert_gather_with_ui32indices(%arg0: tensor<147456xf16>, %arg1: tensor<192x256x1xui32>) -> tensor<192x256xf16> { + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + collapsed_slice_dims = [0], + index_vector_dim = 2, + start_index_map = [0], + >, + indices_are_sorted = false, + slice_sizes = dense<1> : tensor<1xi64> + } : (tensor<147456xf16>, tensor<192x256x1xui32>) -> tensor<192x256xf16> + func.return %0 : tensor<192x256xf16> +} + // CHECK-LABEL: func @convert_gather_nd( // CHECK-SAME: %[[VAL_0:.*]]: tensor<98x128xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<4x64xi32>) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo_pad.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo_pad.mlir deleted file mode 100644 index b72b4296c000ff..00000000000000 --- 
a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo_pad.mlir +++ /dev/null @@ -1,175 +0,0 @@ -// RUN: odml-to-stablehlo-opt %s -tfl-legalize-hlo -split-input-file | FileCheck %s --dump-input=fail - -func.func @mhlo_pad_test__noop(%input: tensor<5x7xf32>, %padding_value: tensor) -> tensor<5x7xf32> { - %0 = "mhlo.pad"(%input, %padding_value) { - edge_padding_low = dense<[0, 0]> : tensor<2xi64>, - edge_padding_high = dense<[0, 0]> : tensor<2xi64>, - interior_padding = dense<[0, 0]> : tensor<2xi64> - } : (tensor<5x7xf32>, tensor) -> tensor<5x7xf32> - func.return %0: tensor<5x7xf32> - -// CHECK-LABEL: mhlo_pad_test__noop -// CHECK: return %arg0 : tensor<5x7xf32> -} - -func.func @mhlo_pad_test__pad_all(%input: tensor<5x7xf32>, %padding_value: tensor) -> tensor<9x10xf32> { - %0 = "mhlo.pad"(%input, %padding_value) { - edge_padding_low = dense<[3, 2]> : tensor<2xi64>, - edge_padding_high = dense<[1, 1]> : tensor<2xi64>, - interior_padding = dense<[0, 0]> : tensor<2xi64> - } : (tensor<5x7xf32>, tensor) -> tensor<9x10xf32> - func.return %0: tensor<9x10xf32> - -// CHECK-LABEL: mhlo_pad_test__pad_all -// CHECK: %cst = arith.constant dense<{{\[}}[3, 1], [2, 1]]> : tensor<2x2xi64> -// CHECK: %0 = "tfl.padv2"(%arg0, %cst, %arg1) : (tensor<5x7xf32>, tensor<2x2xi64>, tensor) -> tensor<9x10xf32> -// CHECK: return %0 : tensor<9x10xf32> -} - -func.func @mhlo_pad_test__crop_all(%input: tensor<5x7xf32>, %padding_value: tensor) -> tensor<3x5xf32> { - %0 = "mhlo.pad"(%input, %padding_value) { - edge_padding_low = dense<[-1, -1]> : tensor<2xi64>, - edge_padding_high = dense<[-1, -1]> : tensor<2xi64>, - interior_padding = dense<[0, 0]> : tensor<2xi64> - } : (tensor<5x7xf32>, tensor) -> tensor<3x5xf32> - func.return %0: tensor<3x5xf32> - -// CHECK-LABEL: mhlo_pad_test__crop_all -// CHECK: %cst = arith.constant dense<1> : tensor<2xi64> -// CHECK: %cst_0 = arith.constant dense<-1> : tensor<2xi64> -// CHECK: %cst_1 = arith.constant dense<1> : tensor<2xi64> -// CHECK: %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<5x7xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x5xf32> -// CHECK: return %0 : tensor<3x5xf32> -} - -func.func @mhlo_pad_test__interior_pad_all(%input: tensor<5x7xf32>, %padding_value: tensor) -> tensor<9x13xf32> { - %0 = "mhlo.pad"(%input, %padding_value) { - edge_padding_low = dense<[0, 0]> : tensor<2xi64>, - edge_padding_high = dense<[0, 0]> : tensor<2xi64>, - interior_padding = dense<[1, 1]> : tensor<2xi64> - } : (tensor<5x7xf32>, tensor) -> tensor<9x13xf32> - func.return %0: tensor<9x13xf32> - -// CHECK-LABEL: mhlo_pad_test__interior_pad_all -// CHECK: %cst = arith.constant dense<2> : tensor<2xi32> -// CHECK: %0 = "tfl.dilate"(%arg0, %cst, %arg1) : (tensor<5x7xf32>, tensor<2xi32>, tensor) -> tensor<9x13xf32> -// CHECK: return %0 : tensor<9x13xf32> -} - -func.func @mhlo_pad_test__pad_and_crop(%input: tensor<5x7xf32>, %padding_value: tensor) -> tensor<5x7xf32> { - %0 = "mhlo.pad"(%input, %padding_value) { - edge_padding_low = dense<[-1, 1]> : tensor<2xi64>, - edge_padding_high = dense<[1, -1]> : tensor<2xi64>, - interior_padding = dense<[0, 0]> : tensor<2xi64> - } : (tensor<5x7xf32>, tensor) -> tensor<5x7xf32> - func.return %0: tensor<5x7xf32> - -// CHECK-LABEL: mhlo_pad_test__pad_and_crop -// CHECK: %cst = arith.constant dense<{{\[}}[0, 1], [1, 0]]> : tensor<2x2xi64> -// CHECK: %0 = "tfl.padv2"(%arg0, %cst, %arg1) : 
(tensor<5x7xf32>, tensor<2x2xi64>, tensor) -> tensor<6x8xf32> -// CHECK: %cst_0 = arith.constant dense<[1, 0]> : tensor<2xi64> -// CHECK: %cst_1 = arith.constant dense<[0, -1]> : tensor<2xi64> -// CHECK: %cst_2 = arith.constant dense<1> : tensor<2xi64> -// CHECK: %1 = "tfl.strided_slice"(%0, %cst_0, %cst_1, %cst_2) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 1 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<6x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x7xf32> -// CHECK: return %1 : tensor<5x7xf32> -} - -func.func @mhlo_pad_test__pad_and_crop_and_interior_pad(%input: tensor<5x7xf32>, %padding_value: tensor) -> tensor<13x25xf32> { - %0 = "mhlo.pad"(%input, %padding_value) { - edge_padding_low = dense<[-1, 1]> : tensor<2xi64>, - edge_padding_high = dense<[1, -1]> : tensor<2xi64>, - interior_padding = dense<[2, 3]> : tensor<2xi64> - } : (tensor<5x7xf32>, tensor) -> tensor<13x25xf32> - func.return %0: tensor<13x25xf32> - -// CHECK-LABEL: mhlo_pad_test__pad_and_crop_and_interior_pad -// CHECK: %cst = arith.constant dense<[3, 4]> : tensor<2xi32> -// CHECK: %0 = "tfl.dilate"(%arg0, %cst, %arg1) : (tensor<5x7xf32>, tensor<2xi32>, tensor) -> tensor<13x25xf32> -// CHECK: %cst_0 = arith.constant dense<{{\[}}[0, 1], [1, 0]]> : tensor<2x2xi64> -// CHECK: %1 = "tfl.padv2"(%0, %cst_0, %arg1) : (tensor<13x25xf32>, tensor<2x2xi64>, tensor) -> tensor<14x26xf32> -// CHECK: %cst_1 = arith.constant dense<[1, 0]> : tensor<2xi64> -// CHECK: %cst_2 = arith.constant dense<[0, -1]> : tensor<2xi64> -// CHECK: %cst_3 = arith.constant dense<1> : tensor<2xi64> -// CHECK: %2 = "tfl.strided_slice"(%1, %cst_1, %cst_2, %cst_3) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 1 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<14x26xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<13x25xf32> -// CHECK: return %2 : tensor<13x25xf32> -} - -func.func @mhlo_pad_test__pad_all_unknown_shape(%input: tensor, %padding_value: tensor) -> tensor { - %0 = "mhlo.pad"(%input, %padding_value) { - edge_padding_low = dense<[1, 1, 1, 1]> : tensor<4xi64>, - edge_padding_high = dense<[1, 1, 1, 1]> : tensor<4xi64>, - interior_padding = dense<[0, 0, 0, 0]> : tensor<4xi64> - } : (tensor, tensor) -> tensor - func.return %0: tensor - -// CHECK-LABEL: mhlo_pad_test__pad_all_unknown_shape -// CHECK: %cst = arith.constant dense<1> : tensor<4x2xi64> -// CHECK: %0 = "tfl.padv2"(%arg0, %cst, %arg1) : (tensor, tensor<4x2xi64>, tensor) -> tensor -// CHECK: return %0 : tensor -} - -func.func @mhlo_pad_test__crop_all_unknown_shape(%input: tensor, %padding_value: tensor) -> tensor { - %0 = "mhlo.pad"(%input, %padding_value) { - edge_padding_low = dense<[-1, -1, -1, -1]> : tensor<4xi64>, - edge_padding_high = dense<[-1, -1, -1, -1]> : tensor<4xi64>, - interior_padding = dense<[0, 0, 0, 0]> : tensor<4xi64> - } : (tensor, tensor) -> tensor - func.return %0: tensor - -// CHECK-LABEL: mhlo_pad_test__crop_all_unknown_shape -// CHECK: %cst = arith.constant dense<1> : tensor<4xi64> -// CHECK: %cst_0 = arith.constant dense<-1> : tensor<4xi64> -// CHECK: %cst_1 = arith.constant dense<1> : tensor<4xi64> -// CHECK: %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor -// CHECK: return %0 : tensor -} - -func.func 
@mhlo_pad_test__pad_all_unknown_dim0(%input: tensor, %padding_value: tensor) -> tensor { - %0 = "mhlo.pad"(%input, %padding_value) { - edge_padding_low = dense<[1, 1, 1, 1]> : tensor<4xi64>, - edge_padding_high = dense<[1, 1, 1, 1]> : tensor<4xi64>, - interior_padding = dense<[0, 0, 0, 0]> : tensor<4xi64> - } : (tensor, tensor) -> tensor - func.return %0: tensor - -// CHECK-LABEL: mhlo_pad_test__pad_all_unknown_dim0 -// CHECK: %cst = arith.constant dense<1> : tensor<4x2xi64> -// CHECK: %0 = "tfl.padv2"(%arg0, %cst, %arg1) : (tensor, tensor<4x2xi64>, tensor) -> tensor -// CHECK: return %0 : tensor -} - -func.func @mhlo_pad_test__crop_all_unknown_dim0(%input: tensor, %padding_value: tensor) -> tensor { - %0 = "mhlo.pad"(%input, %padding_value) { - edge_padding_low = dense<[-1, -1, -1, -1]> : tensor<4xi64>, - edge_padding_high = dense<[-1, -1, -1, -1]> : tensor<4xi64>, - interior_padding = dense<[0, 0, 0, 0]> : tensor<4xi64> - } : (tensor, tensor) -> tensor - func.return %0: tensor - -// CHECK-LABEL: mhlo_pad_test__crop_all_unknown_dim0 -// CHECK: %cst = arith.constant dense<1> : tensor<4xi64> -// CHECK: %cst_0 = arith.constant dense<-1> : tensor<4xi64> -// CHECK: %cst_1 = arith.constant dense<1> : tensor<4xi64> -// CHECK: %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor -// CHECK: return %0 : tensor -} - -func.func @mhlo_pad_test__pad_and_crop_and_interior_pad_unknown_dim0(%input: tensor, %padding_value: tensor) -> tensor { - %0 = "mhlo.pad"(%input, %padding_value) { - edge_padding_low = dense<[-2, -1, 0, 1]> : tensor<4xi64>, - edge_padding_high = dense<[1, 0, -1, -2]> : tensor<4xi64>, - interior_padding = dense<[1, 2, 3, 4]> : tensor<4xi64> - } : (tensor, tensor) -> tensor - func.return %0: tensor - -// CHECK-LABEL: mhlo_pad_test__pad_and_crop_and_interior_pad_unknown_dim0 -// CHECK: %cst = arith.constant dense<[2, 3, 4, 5]> : tensor<4xi32> -// CHECK: %0 = "tfl.dilate"(%arg0, %cst, %arg1) : (tensor, tensor<4xi32>, tensor) -> tensor -// CHECK: %cst_0 = arith.constant dense<{{\[}}[0, 1], [0, 0], [0, 0], [1, 0]]> : tensor<4x2xi64> -// CHECK: %1 = "tfl.padv2"(%0, %cst_0, %arg1) : (tensor, tensor<4x2xi64>, tensor) -> tensor -// CHECK: %cst_1 = arith.constant dense<[2, 1, 0, 0]> : tensor<4xi64> -// CHECK: %cst_2 = arith.constant dense<[0, 0, -1, -2]> : tensor<4xi64> -// CHECK: %cst_3 = arith.constant dense<1> : tensor<4xi64> -// CHECK: %2 = "tfl.strided_slice"(%1, %cst_1, %cst_2, %cst_3) {begin_mask = 12 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor -// CHECK: return %2 : tensor -} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir index 073f31e39786d9..ef637a848461d9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir @@ -133,7 +133,7 @@ func.func @batchNormTraining_4D_middle_features( %x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>) -> (tensor<3x4x256x6xf32>) { // CHECK-DAG: %[[CST_AXIS:.+]] = "tf.Const"() <{value = dense<[0, 1, 3]> : tensor<3xi32>}> : () -> tensor<3xi32> - // CHECK-DAG: 
%[[X_SHAPE:.+]] = shape.const_shape [3, 4, 256, 6] : tensor<4xindex> + // CHECK-DAG: %[[X_SHAPE:.+]] = shape.shape_of %[[X]] : tensor<3x4x256x6xf32> -> tensor<4xindex> // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<256xf32> // CHECK-DAG: %[[MEAN:.+]] = "tf.Mean"(%arg0, %[[CST_AXIS]]) <{keep_dims = false}> : (tensor<3x4x256x6xf32>, tensor<3xi32>) -> tensor<256xf32> // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[X_SHAPE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>, tensor<4xindex>) -> tensor<3x4x256x6xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir index 15b3e37326cfe0..7272b5f17301cb 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir @@ -24,7 +24,7 @@ func.func @uniform_quantize_op_quantized_input(%arg: tensor<2x2x!quant.uniform) -> tensor<2x // ----- -// Tests that the pattern doesn't match when the output tensor's sotrage type +// Tests that the pattern doesn't match when the output tensor's storage type // is i32. i32 storage type for quantized type is not compatible with // `tfl.quantize`. @@ -104,8 +104,8 @@ func.func @uniform_dequantize_op_return_f64(%arg: tensor<2x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { +// CHECK-LABEL: convolution_upstream_full_integer +func.func @convolution_upstream_full_integer(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> return %1 : tensor<1x3x3x2x!quant.uniform> @@ -123,8 +123,8 @@ func.func @convolution_op(%arg0: tensor<1x3x3x4x!quant.uniform>, %arg1: tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { +// CHECK-LABEL: convolution_upstream_full_integer_non_const_filter +func.func @convolution_upstream_full_integer_non_const_filter(%arg0: tensor<1x3x3x4x!quant.uniform>, %arg1: tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> return %0 : tensor<1x3x3x2x!quant.uniform> } @@ -139,8 +139,8 @@ func.func @convolution_op_non_const_filter(%arg0: tensor<1x3x3x4x!quant.uniform< // Test that if the window padding contains values of 0, tfl.pad op is not // created and the `padding` attribute is set as "VALID". 
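The test comment above spells out the padding rule for the quantized convolution legalization: an all-zero window padding legalizes directly with `padding = "VALID"`, while a non-zero entry is materialized as a separate tfl.pad in front of the convolution. A minimal sketch of that decision, using a hypothetical helper name rather than the pass's actual predicate:

```cpp
#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

// Hypothetical helper: `pads` holds [low, high] pairs per spatial dimension,
// e.g. {1, 1, 1, 1} for window pad = [[1, 1], [1, 1]].
inline bool NeedsExplicitPadOp(llvm::ArrayRef<int64_t> pads) {
  // All-zero (or absent) window padding maps straight to padding = "VALID";
  // any non-zero entry needs a tfl.pad inserted before the convolution.
  return llvm::any_of(pads, [](int64_t p) { return p != 0; });
}
```

The renamed test for the zero-padding case follows below.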
-// CHECK-LABEL: convolution_op_valid_padding -func.func @convolution_op_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { +// CHECK-LABEL: convolution_upstream_full_integer_valid_padding +func.func @convolution_upstream_full_integer_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 0], [0, 0]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> return %1 : tensor<1x1x1x2x!quant.uniform> @@ -157,8 +157,8 @@ func.func @convolution_op_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { +// CHECK-LABEL: convolution_upstream_full_integer_valid_padding +func.func @convolution_upstream_full_integer_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> // The `window` attribute is empty. %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> @@ -175,8 +175,8 @@ func.func @convolution_op_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> { +// CHECK-LABEL: convolution_upstream_full_integer_strides +func.func @convolution_upstream_full_integer_strides(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> // The stride value is explicitly set to [1, 2]. %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 2], pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> @@ -195,8 +195,8 @@ func.func @convolution_strides(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_asym_input +func.func @dot_general_upstream_full_integer_asym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -216,8 +216,8 @@ func.func @dot_general_full_integer_asym_input(%arg0: tensor<1x2x3x4x!quant.unif // Test full integer quantized dot_general with symmetric quantized input. 
-// CHECK-LABEL: dot_general_full_integer_sym_input -func.func @dot_general_full_integer_sym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_sym_input +func.func @dot_general_upstream_full_integer_sym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -237,10 +237,34 @@ func.func @dot_general_full_integer_sym_input(%arg0: tensor<1x2x3x4x!quant.unifo // ----- +// Tests that the pattern does not match when the output tensor's storage +// type is i32. Currently we support qi8, qi8 -> qi8 only for GEMM ops that +// are quantized upstream. Other cases should be handled by regular quantized +// stablehlo.dot_general case. + +// CHECK-LABEL: dot_general_upstream_full_integer_i32_output +func.func @dot_general_upstream_full_integer_i32_output(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { + %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> + %1 = "stablehlo.dot_general"(%arg0, %0) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0, 1], + rhs_batching_dimensions = [0, 1], + lhs_contracting_dimensions = [3], + rhs_contracting_dimensions = [2] + >, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> + return %1 : tensor<1x2x3x5x!quant.uniform> +} +// CHECK: stablehlo.dot_general +// CHECK-NOT: tfl.quantize + +// ----- + // Test full integer quantized dot_general with activation as RHS -// CHECK-LABEL: dot_general_full_integer_activation_rhs -func.func @dot_general_full_integer_activation_rhs(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_activation_rhs +func.func @dot_general_upstream_full_integer_activation_rhs(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0, 1], @@ -258,8 +282,8 @@ func.func @dot_general_full_integer_activation_rhs(%arg0: tensor<1x2x3x4x!quant. 
// Test full integer quantized dot_general with adj_x -// CHECK-LABEL: dot_general_full_integer_adj_x -func.func @dot_general_full_integer_adj_x(%arg0: tensor<1x2x4x3x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_adj_x +func.func @dot_general_upstream_full_integer_adj_x(%arg0: tensor<1x2x4x3x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -282,8 +306,8 @@ func.func @dot_general_full_integer_adj_x(%arg0: tensor<1x2x4x3x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_adj_y +func.func @dot_general_upstream_full_integer_adj_y(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x5x4xi8>} : () -> tensor<1x2x5x4x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -306,8 +330,8 @@ func.func @dot_general_full_integer_adj_y(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x1x1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_too_many_batches +func.func @dot_general_upstream_full_integer_too_many_batches(%arg0: tensor<1x1x1x2x3x4x!quant.uniform>) -> tensor<1x1x1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x1x1x2x4x5xi8>} : () -> tensor<1x1x1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -328,8 +352,8 @@ func.func @dot_general_full_integer_too_many_batches(%arg0: tensor<1x1x1x2x3x4x! // Test full integer quantized dot_general with too many contracting dimension -// CHECK-LABEL: dot_general_full_integer_too_many_contractions -func.func @dot_general_full_integer_too_many_contractions(%arg0: tensor<1x2x3x4x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_too_many_contractions +func.func @dot_general_upstream_full_integer_too_many_contractions(%arg0: tensor<1x2x3x4x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x4x5xi8>} : () -> tensor<1x2x4x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -350,8 +374,8 @@ func.func @dot_general_full_integer_too_many_contractions(%arg0: tensor<1x2x3x4x // Test full integer quantized dot_general with unsupported contracting dim -// CHECK-LABEL: dot_general_full_integer_wrong_contracting -func.func @dot_general_full_integer_wrong_contracting(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x4x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_wrong_contracting +func.func @dot_general_upstream_full_integer_wrong_contracting(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x4x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -373,8 +397,8 @@ func.func @dot_general_full_integer_wrong_contracting(%arg0: tensor<1x2x3x4x!qua // Test full integer quantized dot_general with float operands -// CHECK-LABEL: dot_general_full_integer_float_operands -func.func @dot_general_full_integer_float_operands(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32> { 
+// CHECK-LABEL: dot_general_upstream_full_integer_float_operands +func.func @dot_general_upstream_full_integer_float_operands(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0, 1], @@ -394,8 +418,8 @@ func.func @dot_general_full_integer_float_operands(%arg0: tensor<1x2x3x4xf32>, % // Test full integer quantized dot_general with asymmetric weight (rhs). -// CHECK-LABEL: dot_general_full_integer_asym_weight -func.func @dot_general_full_integer_asym_weight(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_asym_weight +func.func @dot_general_upstream_full_integer_asym_weight(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) {dot_dimension_numbers = #stablehlo.dot, precision_config = [#stablehlo, #stablehlo]} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> @@ -409,8 +433,8 @@ func.func @dot_general_full_integer_asym_weight(%arg0: tensor<1x2x3x4x!quant.uni // Test that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized, it is converted to `tfl.fully_connected` op. -// CHECK-LABEL: dot_general_per_axis_quantized_filter -func.func @dot_general_per_axis_quantized_filter(%arg0: tensor<1x3x!quant.uniform>) -> tensor<1x2x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter +func.func @dot_general_upstream_full_integer_per_axis_quantized_filter(%arg0: tensor<1x3x!quant.uniform>) -> tensor<1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<3x2xi8>} : () -> tensor<3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x3x!quant.uniform>, tensor<3x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> return %1 : tensor<1x2x!quant.uniform> @@ -428,8 +452,8 @@ func.func @dot_general_per_axis_quantized_filter(%arg0: tensor<1x3x!quant.unifor // Test that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized but has a batch dimension, it is not converted. -// CHECK-LABEL: dot_general_per_axis_quantized_filter_with_batch_dim -func.func @dot_general_per_axis_quantized_filter_with_batch_dim(%arg0: tensor<1x1x3x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter_with_batch_dim +func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_with_batch_dim(%arg0: tensor<1x1x3x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x3x2xi8>} : () -> tensor<1x3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<1x1x3x!quant.uniform>, tensor<1x3x2x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> return %1 : tensor<1x1x2x!quant.uniform> @@ -444,8 +468,8 @@ func.func @dot_general_per_axis_quantized_filter_with_batch_dim(%arg0: tensor<1x // Test that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized but has a batch dim > 1, it is not converted. 
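Taken together, the per-axis-quantized `stablehlo.dot_general` tests above and below outline when the op becomes `tfl.fully_connected`: only a batch-free filter with a single contracting dimension converts, while a batch dimension (including batch > 1, tested next) or multiple contracting dimensions leave the `stablehlo.dot_general` untouched. A rough summary of that criterion as a sketch; the struct and function names are illustrative, not the pass's API:

```cpp
#include <cstdint>

// Illustrative summary only; names are hypothetical.
struct DotGeneralShapeInfo {
  int64_t num_batching_dims = 0;     // batching dims on the quantized filter
  int64_t num_contracting_dims = 1;  // contracting dims on the filter
  bool filter_is_per_axis_quantized = false;
};

inline bool ConvertsToFullyConnected(const DotGeneralShapeInfo& info) {
  // Mirrors the test expectations: per-axis quantized filter, no batching
  // dims, exactly one contracting dim -> tfl.fully_connected; otherwise the
  // stablehlo.dot_general is left as-is.
  return info.filter_is_per_axis_quantized && info.num_batching_dims == 0 &&
         info.num_contracting_dims == 1;
}
```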
-// CHECK-LABEL: dot_general_per_axis_quantized_filter_multibatch -func.func @dot_general_per_axis_quantized_filter_multibatch(%arg0: tensor<3x1x3x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter_multibatch +func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_multibatch(%arg0: tensor<3x1x3x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<3x3x2xi8>} : () -> tensor<3x3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<3x1x3x!quant.uniform>, tensor<3x3x2x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> return %1 : tensor<3x1x2x!quant.uniform> @@ -460,8 +484,8 @@ func.func @dot_general_per_axis_quantized_filter_multibatch(%arg0: tensor<3x1x3x // Test that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized but has more than one contracting dimension, it is not converted. -// CHECK-LABEL: dot_general_per_axis_quantized_filter_with_multiple_contracting_dims -func.func @dot_general_per_axis_quantized_filter_with_multiple_contracting_dims(%arg0: tensor<1x2x3x!quant.uniform>) -> tensor<1x1x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter_with_multiple_contracting_dims +func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_with_multiple_contracting_dims(%arg0: tensor<1x2x3x!quant.uniform>) -> tensor<1x1x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x3x2xi8>} : () -> tensor<1x3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1, 2] x [2, 1] : (tensor<1x2x3x!quant.uniform>, tensor<1x3x2x!quant.uniform>) -> tensor<1x1x!quant.uniform> return %1 : tensor<1x1x!quant.uniform> @@ -470,3 +494,25 @@ func.func @dot_general_per_axis_quantized_filter_with_multiple_contracting_dims( // CHECK: stablehlo.dot_general // CHECK-NOT: tfl.fully_connected // CHECK-NOT: tfl.batch_matmul + +// ----- + +// Test that a simple per-tensor quantized stablehlo.dot_general is properly +// fused with a subsequent requantize (qi32->qi8) op then legalized. 
+// Supports the following format: (lhs: qi8, rhs: qi8) -> result: qi32 + +// CHECK-LABEL: dot_general_full_integer +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x1024x!quant.uniform + func.func @dot_general_full_integer(%arg0: tensor<1x1024x!quant.uniform> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) { + %0 = stablehlo.constant() {value = dense<1> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x1024x!quant.uniform>, tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>>) -> tensor<1x3x!quant.uniform> + %2 = stablehlo.uniform_quantize %1 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + +// CHECK-NOT: stablehlo.dot_general +// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<3x1024x!quant.uniform>, value = dense<1> : tensor<3x1024xi8>} : () -> tensor<3x1024x!quant.uniform> +// CHECK: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform:f32, 2.000000e+00>>, value = dense<0> : tensor<3xi32>} : () -> tensor<3x!quant.uniform:f32, 2.000000e+00>> +// CHECK: "tfl.fully_connected"(%[[ARG_1]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1024x!quant.uniform>, tensor<3x1024x!quant.uniform>, tensor<3x!quant.uniform:f32, 2.000000e+00>>) -> tensor<1x3x!quant.uniform> +// CHECK-NOT: tfl.batch_matmul diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc index a5286025463a52..587c971cdffaef 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc @@ -731,17 +731,18 @@ class ComposeUniformQuantizedConvolutionOp auto combined_scale_constant_op = cast( scale_combined_broadcast_in_dim_op.getOperand().getDefiningOp()); - SmallVector filter_scale_values; + SmallVector filter_scale_values; for (const auto combined_scale_value : combined_scale_constant_op.getValue() .cast() .getValues()) { - const float filter_scale_value = - combined_scale_value * input_inverse_scales_value; + // UniformQuantizedPerAxisType requires scales to have double dtype. + const double filter_scale_value = static_cast( + combined_scale_value * input_inverse_scales_value); filter_scale_values.emplace_back(filter_scale_value); } // Assumes it is symmetric. - SmallVector filter_zero_point_values( + SmallVector filter_zero_point_values( /*Size=*/filter_scale_values.size(), /*Value=*/0); // Use quantization dimension = 3 that corresponds to the output channel @@ -1083,15 +1084,17 @@ class ComposeUniformQuantizedDotGeneralOp // s1 * s2 auto merged_scale_constant_op = cast(multiply_op_second_operand.getDefiningOp()); - SmallVector filter_scale_values; + SmallVector filter_scale_values; for (const auto merged_scale : merged_scale_constant_op.getValue() .cast() .getValues()) { // (s1 * s2) * (1 / s1) = s2 - filter_scale_values.push_back(merged_scale * input_inverse_scale_value); + // UniformQuantizedPerAxisType requires scales to have double dtype. 
+ filter_scale_values.push_back( + static_cast(merged_scale * input_inverse_scale_value)); } - SmallVector filter_zero_point_values( + SmallVector filter_zero_point_values( /*Size=*/filter_scale_values.size(), /*Value=*/0); const int quantization_dimension = GetFilterQuantizationDimension( diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc index f161bb3c90c3ae..d6ca92d5ca89db 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc @@ -3096,8 +3096,20 @@ class ConvertGatherOp : public OpConversionPattern { auto tf_gather_nd_result_type = RankedTensorType::get(transpose_params.canonicalized_output_shape, result_type.getElementType()); + + TF::CastOp cast_op = nullptr; + if (start_indices_type.getElementType().isUnsignedInteger(32)) { + cast_op = rewriter.create( + gather_op->getLoc(), + RankedTensorType::get(start_indices_type.getShape(), + rewriter.getI64Type()), + start_indices); + } + auto tf_gather_nd_op = rewriter.create( - gather_op->getLoc(), tf_gather_nd_result_type, operand, start_indices); + gather_op->getLoc(), tf_gather_nd_result_type, operand, + cast_op ? cast_op.getResult() : start_indices); + if (!need_transpose_after) { rewriter.replaceOp(gather_op, tf_gather_nd_op->getOpResults()); return success(); @@ -3386,9 +3398,6 @@ class ConvertIfOp : public OpConversionPattern { }; // Converts mhlo.pad to tf.PadV2 -// TODO: b/301438955 - This is redundant with the MHLO -> TFLite -// legalization and covers less usecases. We need to check with DarwiNN that -// this can be removed without breaking their workflow. Value ConvertPadOp(PatternRewriter& rewriter, Operation* old_op) { auto pad_op = cast(old_op); mlir::Location loc = pad_op.getLoc(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD index fb2b2d6f068350..4aaf08a8686e5b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD @@ -49,24 +49,6 @@ cc_library( ], ) -cc_library( - name = "pad", - srcs = [ - "pad.cc", - ], - hdrs = [ - "pad.h", - ], - deps = [ - ":util", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - "@local_xla//xla/mlir_hlo", - ], -) - cc_library( name = "dot_general", srcs = [ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc deleted file mode 100644 index 9fd1fcb8402c51..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.h" - -#include - -#include "llvm/ADT/SmallVector.h" -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" - -namespace mlir { -namespace odml { - -ConversionState BuildConversionState(mhlo::PadOp mhlo_pad, - ConversionPatternRewriter& rewriter) { - ConversionState state{ - /*.shlo_op=*/mhlo_pad.getOperation(), - /*.rewriter=*/rewriter, - /*.last_tf_op=*/nullptr, - }; - return state; -} - -// Converts the given StableHLO Pad operation to a chain of TFLite operations. -// -// StableHLO Pad allows dilating, padding and cropping its input, in that order. -// This can be implemented in TFLite as a sequence of these operations. Note -// that all operations do not always need to be called: if there is no dilation -// (resp. pad, crop) we do not need to add it to the chain. -// -// TFLite does not provide a crop operation, the StridedSlice one is used -// instead. -LogicalResult ConvertPadOp::matchAndRewrite( - mhlo::PadOp mhlo_pad, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const { - // We don't need to match the pad op as we always know how to convert it. - ConversionState state = BuildConversionState(mhlo_pad, rewriter); - - // Dilate when interior padding is specified different from 0. - AddDilateOpIfRequired(state, mhlo_pad.getInteriorPadding(), - mhlo_pad.getPaddingValue(), - /*is_padding=*/true); - // Pad when padding has positive values. - AddPadOpIfRequired(state, mhlo_pad.getEdgePaddingLow(), - mhlo_pad.getEdgePaddingHigh(), mhlo_pad.getPaddingValue()); - // Crop when padding has negative values. - // - // Note that there is no crop operation in TFLite so we use the StridedSlice - // operation instead. - const DenseElementsAttr strides_data = CreateDenseElementsAttr( - state.rewriter, - llvm::SmallVector(state.GetOperandShape().size(), 1)); - AddStridedSliceOpIfRequired(state, mhlo_pad.getEdgePaddingLow(), - mhlo_pad.getEdgePaddingHigh(), strides_data); - - if (state.last_tf_op) { - rewriter.replaceOp(mhlo_pad, state.last_tf_op); - } else { - rewriter.replaceOp(mhlo_pad, mhlo_pad.getOperand()); - } - return success(); -} - -} // namespace odml -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.h deleted file mode 100644 index c0fa5017b69236..00000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_PAD_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_PAD_H_ - -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" - -namespace mlir { -namespace odml { - -class ConvertPadOp : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::PadOp mhlo_pad, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const final; -}; - -} // namespace odml -} // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_PAD_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc index 4432cec521b99d..c2f533776d0408 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc @@ -164,210 +164,6 @@ LogicalResult MatchBinaryReduceFunction(mlir::Region& function) { return success(); } -Value ConversionState::GetOperand() const { - if (last_tf_op) { - return last_tf_op->getResult(0); - } - return hlo_op->getOperand(0); -} - -TensorType ConversionState::GetOperandTensorType() const { - if (last_tf_op) { - return last_tf_op->getResult(0).getType().cast(); - } - return hlo_op->getOperand(0).getType().cast(); -} - -llvm::ArrayRef ConversionState::GetOperandShape() const { - return GetOperandTensorType().getShape(); -} - -namespace { - -// Gets the dilation data for TFLite Dilate. -// -// Depending on the definition of the op we are trying to legalize, a dilation -// can be either seen as interior padding or as a scaling factor where: -// -// scaling_factor = interior_padding + 1 -// -// The is_padding parameter is used to take this difference into account. 
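The removed comment above notes that a dilation can be read either as interior padding or as a scaling factor equal to interior_padding + 1. As a concrete check of the removed dilate -> pad -> crop chain, the shapes from the deleted `mhlo_pad_test__pad_and_crop_and_interior_pad` test (5x7 input, interior padding [2, 3], edge padding low [-1, 1] and high [1, -1], 13x25 result) can be recomputed with a short sketch; the helper name is illustrative only:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative helper, not part of the pass: static result shape of an
// mhlo.pad lowered as tfl.dilate -> tfl.padv2 -> tfl.strided_slice.
std::vector<int64_t> PadResultShape(const std::vector<int64_t>& in,
                                    const std::vector<int64_t>& low,
                                    const std::vector<int64_t>& high,
                                    const std::vector<int64_t>& interior) {
  std::vector<int64_t> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    // tfl.dilate: scaling factor = interior_padding + 1.
    const int64_t dilated = in[i] + (in[i] - 1) * interior[i];
    // tfl.padv2: only positive edge padding pads.
    const int64_t padded = dilated + std::max<int64_t>(low[i], 0) +
                           std::max<int64_t>(high[i], 0);
    // tfl.strided_slice: negative edge padding crops.
    out[i] =
        padded + std::min<int64_t>(low[i], 0) + std::min<int64_t>(high[i], 0);
  }
  return out;
}
// PadResultShape({5, 7}, {-1, 1}, {1, -1}, {2, 3}) yields {13, 25}, matching
// the deleted test's intermediate 13x25 (dilate), 14x26 (pad), and final
// 13x25 (crop) shapes.
```

The deleted helpers that implemented this behavior follow.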
-llvm::SmallVector GetDilateData(const DenseElementsAttr& dilation, - const bool is_padding) { - llvm::SmallVector data; - for (const auto& v : dilation.getValues()) { - data.push_back(v.getSExtValue() + static_cast(is_padding)); - } - return data; -} - -} // namespace - -void AddDilateOpIfRequired(ConversionState& state, - const DenseElementsAttr& dilation, - const Value padding_value, const bool is_padding) { - const auto dilate_data = GetDilateData(dilation, is_padding); - if (absl::c_any_of(dilate_data, IsNot(1))) { - const TensorType output_type = state.ComputeResultTensorType( - [](int i, const auto& shape, const auto& dilate_data) { - if (shape[i] < 0) { - return shape[i]; - } - return shape[i] + (shape[i] - 1) * (dilate_data[i] - 1); - }, - dilate_data); - - auto dilate_tensor = AddConstantTensor(state, dilate_data); - auto tfl_dilate = state.rewriter.create( - state.hlo_op->getLoc(), output_type, state.GetOperand(), dilate_tensor, - padding_value); - - state.last_tf_op = tfl_dilate; - } -} - -namespace { - -// Gets the pad data for TFLite PadV2. -// -// StableHLO Pad allows negative values for cropping. This functions replaces -// negative values with 0. -llvm::SmallVector GetPadData( - const DenseElementsAttr& edge_padding_low, - const DenseElementsAttr& edge_padding_high) { - llvm::SmallVector data; - auto low_values = edge_padding_low.getValues(); - auto high_values = edge_padding_high.getValues(); - for (int i = 0; i < edge_padding_low.getNumElements(); ++i) { - const int64_t pad_low = low_values[i].getSExtValue(); - const int64_t pad_high = high_values[i].getSExtValue(); - data.push_back(pad_low < 0 ? 0 : pad_low); - data.push_back(pad_high < 0 ? 0 : pad_high); - } - return data; -} - -template -void AddPadOpIfRequiredImpl(ConversionState& state, const Container& pad_data, - const Value padding_value) { - if (absl::c_any_of(pad_data, IsNot(0))) { - const TensorType output_type = state.ComputeResultTensorType( - [](int i, const auto& shape, const auto& pad) { - if (shape[i] < 0) { - return shape[i]; - } - return shape[i] + pad[2 * i] + pad[2 * i + 1]; - }, - pad_data); - - auto pad_tensor = AddConstantTensor( - state, pad_data, - {static_cast(state.GetOperandShape().size()), 2}); - auto tfl_pad = state.rewriter.create( - state.hlo_op->getLoc(), output_type, state.GetOperand(), pad_tensor, - padding_value); - - state.last_tf_op = tfl_pad; - } -} - -} // namespace - -void AddPadOpIfRequired(ConversionState& state, - const DenseElementsAttr& edge_padding_low, - const DenseElementsAttr& edge_padding_high, - const Value padding_value) { - AddPadOpIfRequiredImpl(state, GetPadData(edge_padding_low, edge_padding_high), - padding_value); -} - -namespace { - -// Holds the data needed to generate a TFLite StridedSlice operation. -struct StridedSliceData { - llvm::SmallVector low; - llvm::SmallVector high; - llvm::SmallVector strides; - int32_t begin_mask = 0; - int32_t end_mask = 0; - - void resize(const size_t size) { - low.resize(size); - high.resize(size); - strides.resize(size); - } -}; - -// Updates the strided slice data with the given values for the `i`th element. -// -// Warning: this expects the data internal buffers to have at least i+1 -// elements. 
-void AppendDataDim(StridedSliceData& data, const int i, const APInt& low, - const APInt& high, const APInt& stride) { - const int64_t pad_low = low.getSExtValue(); - const int64_t pad_high = high.getSExtValue(); - if (pad_low >= 0) { - data.begin_mask |= 1 << i; - data.low[i] = 0; - } else { - data.low[i] = -pad_low; - } - if (pad_high >= 0) { - data.end_mask |= 1 << i; - data.high[i] = 0; - } else { - data.high[i] = pad_high; - } - data.strides[i] = stride.getSExtValue(); -} - -// Gets the data needed to generate a TFLite StridedSlice operation. -StridedSliceData GetStridedSliceData(const DenseElementsAttr& edge_padding_low, - const DenseElementsAttr& edge_padding_high, - const DenseElementsAttr& strides) { - StridedSliceData data; - data.resize(edge_padding_low.getNumElements()); - const auto low_values = edge_padding_low.getValues(); - const auto high_values = edge_padding_high.getValues(); - const auto stride_values = strides.getValues(); - for (int i = 0; i < edge_padding_low.getNumElements(); ++i) { - AppendDataDim(data, i, low_values[i], high_values[i], stride_values[i]); - } - return data; -} - -void AddStridedSliceOpIfRequiredImpl( - ConversionState& state, const StridedSliceData& strided_slice_data) { - if (absl::c_any_of(strided_slice_data.low, IsNot(0)) || - absl::c_any_of(strided_slice_data.high, IsNot(0)) || - absl::c_any_of(strided_slice_data.strides, IsNot(1))) { - const TensorType output_type = state.ComputeResultTensorType( - [](int i, const auto& shape, const auto& high, const auto& low, - const auto& strides) { - if (shape[i] < 0) { - return shape[i]; - } - return (shape[i] + high[i] - low[i]) / strides[i]; - }, - strided_slice_data.high, strided_slice_data.low, - strided_slice_data.strides); - - auto crop_begin_tensor = AddConstantTensor(state, strided_slice_data.low); - auto crop_end_tensor = AddConstantTensor(state, strided_slice_data.high); - auto crop_strides_tensor = - AddConstantTensor(state, strided_slice_data.strides); - auto tfl_crop = state.rewriter.create( - state.hlo_op->getLoc(), output_type, state.GetOperand(), - crop_begin_tensor, crop_end_tensor, crop_strides_tensor, - strided_slice_data.begin_mask, strided_slice_data.end_mask, 0, 0, 0, - false); - - state.last_tf_op = tfl_crop; - } -} - -} // namespace - bool NeedsReformatTypeAndPermutation(int batch_dim, int feature_dim, int spatial_dim_start, int default_batch_dim, @@ -426,15 +222,6 @@ Value InsertTranspose(Value value, int batch_dim, int feature_dim, permutation); } -void AddStridedSliceOpIfRequired(ConversionState& state, - const DenseElementsAttr& edge_padding_low, - const DenseElementsAttr& edge_padding_high, - const DenseElementsAttr& strides) { - StridedSliceData strided_slice_data = - GetStridedSliceData(edge_padding_low, edge_padding_high, strides); - AddStridedSliceOpIfRequiredImpl(state, strided_slice_data); -} - Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) { IntegerType new_ele_type = rewriter.getIntegerType(32); if (auto shaped_type = val.getType().dyn_cast()) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h index 442161ade171f7..c58fdaa76a78d2 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h @@ -139,137 +139,6 @@ LogicalResult MatchBinaryReduceFunction(mlir::Region& function) { template <> 
LogicalResult MatchBinaryReduceFunction(mlir::Region& function); -// Concentrates the data needed to substitute StableHLO operations with TFLite -// ones. -struct ConversionState { - Operation* hlo_op; - ConversionPatternRewriter& rewriter; - Operation* last_tf_op; - - // Returns the main operand of a NEW op to add to the conversion chain. - // - // This is generally the result of the last op that was added to the chain. - Value GetOperand() const; - - // Returns the type of the operand of a NEW op to add to the conversion chain. - // - // This is generally the type of the result of the last op that was added to - // the chain. - TensorType GetOperandTensorType() const; - - llvm::ArrayRef GetOperandShape() const; - - // Computes a new shape from the current operand shape. - // - // - The args are containers that are indexable using operator[]. - // - The callback must be callable have a signature that is: - // `int64_t (int idx, shape, decltype(args)...)` - // - // The callback is called for each element of the operand shape with the - // index of the current loop iteration, the shape and args. - template - llvm::SmallVector ComputeResultShape(F&& callback, - Containers&&... args) const { - llvm::ArrayRef shape = GetOperandShape(); - llvm::SmallVector res; - for (int i = 0; i < shape.size(); ++i) { - if (shape[i] < 0) { - res.push_back(shape[i]); - } else { - res.push_back(callback(i, shape, args...)); - } - } - return res; - } - - template - TensorType ComputeResultTensorType(F&& callback, Containers&&... args) const { - const llvm::SmallVector shape = ComputeResultShape( - static_cast(callback), static_cast(args)...); - return GetOperandTensorType().cloneWith( - shape, GetOperandTensorType().getElementType()); - } -}; - -// Gets the Type associated to type T from the builder. -template -Type GetElementType(OpBuilder& builder); - -#define GET_ELEMENT_TYPE_SPECIALISATION(TYPE, NAME) \ - template <> \ - inline Type GetElementType(OpBuilder & builder) { \ - return builder.get##NAME##Type(); \ - } - -GET_ELEMENT_TYPE_SPECIALISATION(int32_t, I32); -GET_ELEMENT_TYPE_SPECIALISATION(int64_t, I64); - -// Create a DenseElementsAttr from given shape and data. -template > -DenseElementsAttr CreateDenseElementsAttr(OpBuilder& builder, const Data& data, - const Shape& shape = Shape()) { - llvm::SmallVector attr_shape(shape.begin(), shape.end()); - if (attr_shape.empty()) { - attr_shape.push_back(static_cast(data.size())); - } - const Type attr_type = GetElementType(builder); - return DenseElementsAttr::get(RankedTensorType::get(attr_shape, attr_type), - ArrayRef(data)); -} - -// Adds a constant tensor to the conversion chain. -template > -auto AddConstantTensor(ConversionState& state, const Data& data, - const Shape& shape = Shape()) { - const DenseElementsAttr attr = - CreateDenseElementsAttr(state.rewriter, data, shape); - return state.rewriter.create(state.hlo_op->getLoc(), attr); -} - -// Builds a callable object that checks that its argument is not the given -// `value`. -template -auto IsNot(T value) { - return [value](auto v) { return v != value; }; -} - -// Adds a TFLite Dilate operation to the conversion chain. -// -// If the given parameters would end with the identity operation, this does not -// add anything to the chain. 
-// -// Depending on the definition of the op we are trying to legalize, a dilation -// can be either seen as interior padding or as a scaling factor where: -// -// scaling_factor = interior_padding + 1 -// -// The is_padding parameter is used to take this difference into account. -void AddDilateOpIfRequired(ConversionState& state, - const DenseElementsAttr& dilation, - Value padding_value, bool is_padding); - -// Adds a TFLite PadV2 operation to the conversion chain. -// -// If the given parameters would end with the identity operation, this does not -// add anything to the chain. -void AddPadOpIfRequired(ConversionState& state, - const DenseElementsAttr& edge_padding_low, - const DenseElementsAttr& edge_padding_high, - Value padding_value); - -// Adds a TFLite StridedSlice operation to the conversion chain. -// -// This overload is used to legalize a crop operation in TFLite. As such, the -// begin and end specifications of the strided slice are computed from the -// negative values in the padding parameters. -// -// If the given parameters would end with the identity operation, this does not -// add anything to the chain. -void AddStridedSliceOpIfRequired(ConversionState& state, - const DenseElementsAttr& edge_padding_low, - const DenseElementsAttr& edge_padding_high, - const DenseElementsAttr& strides); - // Util that casts 'val' to Int32 by adding a tfl cast Op. Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter); } // namespace odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc index 858fe15a7f492a..3bb9eddbfa5021 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc @@ -125,6 +125,16 @@ void StablehloToTflPass::runOnOperation() { continue; } + if (attr.isa<::mlir::DenseI64ArrayAttr>()) { + auto array_attr = attr.dyn_cast(); + auto start = fbb->StartVector(key); + for (auto int_value : array_attr.asArrayRef()) { + fbb->Add(int_value); + } + fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); + continue; + } + if (attr.isa<::mlir::StringAttr>()) { fbb->String(key, attr.dyn_cast().data()); continue; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc index ec708f70724c84..5e4f79f18ce503 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc @@ -75,7 +75,8 @@ class TflToStablehloPass } llvm::SmallVector ReadAttr(const flexbuffers::Map& m, - Builder* builder) { + Builder* builder, + std::string op_name) { llvm::SmallVector attrs; const auto& keys = m.Keys(); for (size_t i = 0; i < keys.size(); ++i) { @@ -102,10 +103,19 @@ class TflToStablehloPass } else { shape.push_back(vec.size()); } - RankedTensorType ty = tensorflow::GetTypeFromTFTensorShape( - shape, builder->getIntegerType(64)); - auto named_attr = - builder->getNamedAttr(key, DenseIntElementsAttr::get(ty, vec)); + Attribute value; + if (op_name == "stablehlo.broadcast" || + op_name == "stablehlo.dynamic_slice" || + op_name == "stablehlo.fft" || op_name == "stablehlo.pad" || + op_name == "stablehlo.reverse" || op_name == "stablehlo.slice" || + op_name == "stablehlo.transpose") { + value = builder->getDenseI64ArrayAttr(vec); + } else { + RankedTensorType ty = 
tensorflow::GetTypeFromTFTensorShape( + shape, builder->getIntegerType(64)); + value = DenseIntElementsAttr::get(ty, vec); + } + auto named_attr = builder->getNamedAttr(key, value); attrs.push_back(named_attr); break; } @@ -181,7 +191,8 @@ void TflToStablehloPass::runOnOperation() { flexbuffers::GetRoot(option_buf, custom_op.getCustomOption().getValue().size()) .AsMap(); - auto attr = ReadAttr(flex_buffer_map, &builder); + auto attr = + ReadAttr(flex_buffer_map, &builder, custom_op.getCustomCode().str()); OperationState op_state(custom_op.getLoc(), custom_op.getCustomCode().str()); op_state.addOperands(custom_op.getOperands()); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc index e50cb2dad9f4c0..6c07c0c0e4b8d2 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h" // IWYU pragma: keep -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep @@ -57,7 +56,7 @@ void LegalizeHloToTfLitePass::runOnOperation() { MLIRContext& context = getContext(); RewritePatternSet patterns(&getContext()); // Add new conversion patterns here. - patterns.add(&context); + patterns.add(&context); populateWithGenerated(patterns); ConversionTarget target(context); @@ -66,8 +65,7 @@ void LegalizeHloToTfLitePass::runOnOperation() { target.addDynamicallyLegalOp(IsCustomCallLegal); // Converted MHLO ops should be marked illegal here. // TODO: b/304003568 - Add TF_TransposeOp folding logic to tflite. - target.addIllegalOp(); + target.addIllegalOp(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { getOperation().emitError("mhlo to TFLite legalization failed."); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index 18070fe59134e3..3ba5ad97ad579e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h" #define DEBUG_TYPE "uniform-quantized-stablehlo-to-tfl" @@ -46,6 +47,13 @@ namespace mlir { namespace odml { namespace { +// TODO: b/311029361: Add e2e test for verifying this legalization once +// StableHLO Quantizer API migration is complete. 
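// The stablehlo_tfl_pass.cc hunk above writes a DenseI64ArrayAttr into a typed
// (non-fixed) FlexBuffer vector, and ReadAttr in tfl_stablehlo_pass.cc reads the
// custom-option buffer back as a map. A minimal standalone sketch of that
// FlexBuffer round trip, independent of MLIR; the key name and values are
// illustrative and the flatbuffers library is assumed to be on the include path.
#include <cstdint>
#include <iostream>
#include <vector>

#include "flatbuffers/flexbuffers.h"

int main() {
  flexbuffers::Builder fbb;
  const std::vector<int64_t> permutation = {2, 1, 0};  // illustrative values
  const size_t map_start = fbb.StartMap();
  // Mirrors fbb->StartVector(key) / fbb->Add(v) / fbb->EndVector(start, true, false).
  const size_t vec_start = fbb.StartVector("permutation");
  for (int64_t v : permutation) fbb.Add(v);
  fbb.EndVector(vec_start, /*typed=*/true, /*fixed=*/false);
  fbb.EndMap(map_start);
  fbb.Finish();

  // Read it back the way the custom-option buffer is consumed: root map, then
  // the keyed vector of integers.
  const std::vector<uint8_t>& buf = fbb.GetBuffer();
  auto map = flexbuffers::GetRoot(buf.data(), buf.size()).AsMap();
  auto vec = map["permutation"].AsTypedVector();
  for (size_t i = 0; i < vec.size(); ++i) std::cout << vec[i].AsInt64() << " ";
  std::cout << "\n";  // prints 2 1 0
  return 0;
}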
+ +using ::mlir::quant::IsI32F32UniformQuantizedType; +using ::mlir::quant::IsI8F32UniformQuantizedPerAxisType; +using ::mlir::quant::IsI8F32UniformQuantizedType; +using ::mlir::quant::IsSupportedByTfliteQuantizeOrDequantizeOps; using ::mlir::quant::QuantizedType; using ::mlir::quant::UniformQuantizedPerAxisType; using ::mlir::quant::UniformQuantizedType; @@ -60,95 +68,164 @@ class UniformQuantizedStablehloToTflPass void runOnOperation() override; }; -// Determines whether the storage type of a quantized type is supported by -// `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. -bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type) { - if ((storage_type.isSigned() && - !(storage_type.getWidth() == 8 || storage_type.getWidth() == 16)) || - (!storage_type.isSigned() && storage_type.getWidth() != 8)) { - LLVM_DEBUG(llvm::dbgs() - << "Uniform quantize / dequantize op only supports ui8, i8 or " - "i16 for the storage type of uniform quantized type. Got: " - << storage_type << ".\n"); - return false; - } - return true; -} - -// Returns true iff the storage type of `quantized_type` is 8-bit integer. -bool IsStorageTypeI8(QuantizedType quantized_type) { - const Type storage_type = quantized_type.getStorageType(); - return storage_type.isInteger(/*width=*/8); +// Bias scales for matmul-like ops should be input scale * filter scale. Here it +// is assumed that the input is per-tensor quantized and filter is per-channel +// quantized. +SmallVector GetBiasScales(const double input_scale, + const ArrayRef filter_scales) { + SmallVector bias_scales; + absl::c_transform(filter_scales, std::back_inserter(bias_scales), + [input_scale](const double filter_scale) -> double { + return filter_scale * input_scale; + }); + return bias_scales; } -// Returns true iff the expressed type of `quantized_type` is f32. -bool IsExpressedTypeF32(QuantizedType quantized_type) { - const Type expressed_type = quantized_type.getExpressedType(); - return expressed_type.isa(); +// Returns a bias scale for matmul-like ops. Here it is assumed that both input +// and filter are per-tensor quantized. +double GetBiasScale(const double input_scale, const double filter_scale) { + return filter_scale * input_scale; } -// Returns true iff `type` is a uniform quantized type whose storage type is -// 8-bit integer and expressed type is f32. -bool IsI8F32UniformQuantizedType(const Type type) { - auto quantized_type = type.dyn_cast_or_null(); - if (!quantized_type) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a uniform quantized type. Got: " << type << ".\n"); - return false; +// Creates a new `tfl.qconst` op for the quantized filter. Transposes the +// filter value from [i, o] -> [o, i]. This is because we assume `[i, o]` +// format for `stablehlo.dot_general` (i.e. contracting dimension == 1) +// whereas `tfl.fully_connected` accepts an OI format. +TFL::QConstOp CreateTflConstOpForFilter( + stablehlo::ConstantOp filter_constant_op, PatternRewriter& rewriter, + bool is_per_axis) { + const auto filter_values = filter_constant_op.getValue() + .cast() + .getValues(); + + ArrayRef filter_shape = + filter_constant_op.getType().cast().getShape(); + + // Reverse the shapes. This makes sense, assuming that the filter tensor has a + // rank of 2 (no batch dimension). + SmallVector new_filter_shape(filter_shape.rbegin(), + filter_shape.rend()); + + // Construct the value array of transposed filter. Assumes 2D matrix. 
+ SmallVector new_filter_values(filter_values.size(), /*Value=*/0); + for (int i = 0; i < filter_shape[0]; ++i) { + for (int j = 0; j < filter_shape[1]; ++j) { + const int old_idx = i * filter_shape[1] + j; + const int new_idx = j * filter_shape[0] + i; + new_filter_values[new_idx] = filter_values[old_idx]; + } } - if (!IsStorageTypeI8(quantized_type)) { - LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. Got: " - << quantized_type << ".\n"); - return false; - } + auto new_filter_value_attr_type = RankedTensorType::getChecked( + filter_constant_op.getLoc(), new_filter_shape, + /*elementType=*/rewriter.getI8Type()); - if (!IsExpressedTypeF32(quantized_type)) { - LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " - << quantized_type << ".\n"); - return false; + Type new_filter_quantized_type; + + if (is_per_axis) { + auto filter_quantized_type = filter_constant_op.getResult() + .getType() + .cast() + .getElementType() + .cast(); + + new_filter_quantized_type = UniformQuantizedPerAxisType::getChecked( + filter_constant_op.getLoc(), /*flags=*/true, + /*storageType=*/filter_quantized_type.getStorageType(), + /*expressedType=*/filter_quantized_type.getExpressedType(), + /*scales=*/filter_quantized_type.getScales(), + /*zeroPoints=*/filter_quantized_type.getZeroPoints(), + /*quantizedDimension=*/0, /*storageTypeMin=*/llvm::minIntN(8), + /*storageTypeMax=*/llvm::maxIntN(8)); + } else { + auto filter_quantized_type = filter_constant_op.getResult() + .getType() + .cast() + .getElementType() + .cast(); + new_filter_quantized_type = UniformQuantizedType::getChecked( + filter_constant_op.getLoc(), /*flags=*/true, + /*storageType=*/filter_quantized_type.getStorageType(), + /*expressedType=*/filter_quantized_type.getExpressedType(), + /*scale=*/filter_quantized_type.getScale(), + /*zeroPoint=*/filter_quantized_type.getZeroPoint(), + /*storageTypeMin=*/llvm::minIntN(8), + /*storageTypeMax=*/llvm::maxIntN(8)); } - return true; + // Required because the quantized dimension is changed from 3 -> 0. + auto new_filter_result_type = RankedTensorType::getChecked( + filter_constant_op.getLoc(), /*shape=*/new_filter_shape, + /*type=*/new_filter_quantized_type); + + auto new_filter_constant_value_attr = + DenseIntElementsAttr::get(new_filter_value_attr_type, new_filter_values); + return rewriter.create( + filter_constant_op.getLoc(), + /*output=*/TypeAttr::get(new_filter_result_type), + /*value=*/new_filter_constant_value_attr); } -// Returns true iff `type` is a uniform quantized per-axis (per-channel) type -// whose storage type is 8-bit integer and expressed type is f32. -bool IsI8F32UniformQuantizedPerAxisType(const Type type) { - auto quantized_per_axis_type = - type.dyn_cast_or_null(); - if (!quantized_per_axis_type) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a uniform quantized type. Got: " << type << ".\n"); - return false; - } +// Creates a new `tfl.qconst` op for the bias. The bias values are 0s, because +// this bias a dummy bias (note that bias fusion is not considered for this +// transformation). The quantization scale for the bias is input scale * +// filter scale. `filter_const_op` is used to retrieve the filter scales and +// the size of the bias constant. +// TODO - b/309896242: Support bias fusion legalization. 
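// The bias scales used by CreateTflConstOpForDummyBias below come from
// GetBiasScales / GetBiasScale above: each bias scale is input scale * filter
// scale, per output channel in the per-axis case. A minimal standalone sketch
// with plain std types and made-up scale values (the real values come from the
// uniform quantized element types).
#include <iostream>
#include <vector>

// Per-axis case, mirroring GetBiasScales: one bias scale per output channel.
std::vector<double> BiasScales(double input_scale,
                               const std::vector<double>& filter_scales) {
  std::vector<double> bias_scales;
  bias_scales.reserve(filter_scales.size());
  for (double filter_scale : filter_scales) {
    bias_scales.push_back(input_scale * filter_scale);
  }
  return bias_scales;
}

int main() {
  const double input_scale = 0.5;
  const std::vector<double> filter_scales = {0.02, 0.04, 0.08};
  for (double s : BiasScales(input_scale, filter_scales)) {
    std::cout << s << " ";  // prints 0.01 0.02 0.04
  }
  // Per-tensor case (GetBiasScale) is the same product with a single scale.
  std::cout << "\n" << input_scale * 0.02 << "\n";  // prints 0.01
  return 0;
}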
+TFL::QConstOp CreateTflConstOpForDummyBias(const Location loc, + const double input_scale, + TFL::QConstOp filter_const_op, + PatternRewriter& rewriter, + bool is_per_axis) { + const ArrayRef filter_shape = + filter_const_op.getResult().getType().getShape(); + + Type bias_quantized_type; + if (is_per_axis) { + const auto filter_quantized_element_type = + filter_const_op.getResult() + .getType() + .getElementType() + .cast(); - if (!IsStorageTypeI8(quantized_per_axis_type)) { - LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. Got: " - << quantized_per_axis_type << ".\n"); - return false; - } + // The storage type is i32 for bias, which is the precision used for + // accumulation. + bias_quantized_type = UniformQuantizedPerAxisType::getChecked( + loc, /*flags=*/true, /*storageType=*/rewriter.getI32Type(), + /*expressedType=*/rewriter.getF32Type(), /*scales=*/ + GetBiasScales(input_scale, filter_quantized_element_type.getScales()), + /*zeroPoints=*/filter_quantized_element_type.getZeroPoints(), + /*quantizedDimension=*/0, /*storageTypeMin=*/llvm::minIntN(8), + /*storageTypeMax=*/llvm::maxIntN(8)); + } else { + const auto filter_quantized_element_type = + filter_const_op.getResult() + .getType() + .getElementType() + .cast(); - if (!IsExpressedTypeF32(quantized_per_axis_type)) { - LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " - << quantized_per_axis_type << ".\n"); - return false; + // The storage type is i32 for bias, which is the precision used for + // accumulation. + bias_quantized_type = UniformQuantizedType::getChecked( + loc, /*flags=*/true, /*storageType=*/rewriter.getI32Type(), + /*expressedType=*/rewriter.getF32Type(), /*scale=*/ + GetBiasScale(input_scale, filter_quantized_element_type.getScale()), + /*zeroPoint=*/filter_quantized_element_type.getZeroPoint(), + /*storageTypeMin=*/llvm::minIntN(8), + /*storageTypeMax=*/llvm::maxIntN(8)); } - return true; -} + SmallVector bias_shape = {filter_shape[0]}; + auto bias_type = + RankedTensorType::getChecked(loc, bias_shape, bias_quantized_type); -// Bias scales for matmul-like ops should be input scale * filter scale. Here it -// is assumed that the input is per-tensor quantized and filter is per-channel -// quantized. -SmallVector GetBiasScales(const double input_scale, - const ArrayRef filter_scales) { - SmallVector bias_scales; - absl::c_transform(filter_scales, std::back_inserter(bias_scales), - [input_scale](const double filter_scale) -> double { - return filter_scale * input_scale; - }); - return bias_scales; + auto bias_value_type = RankedTensorType::getChecked( + loc, std::move(bias_shape), rewriter.getI32Type()); + auto bias_value = DenseIntElementsAttr::get( + bias_value_type, APInt(/*numBits=*/32, /*value=*/0, /*isSigned=*/true)); + + return rewriter.create( + loc, /*output=*/TypeAttr::get(bias_type), /*value=*/bias_value); } // stablehlo.uniform_quantize -> tfl.quantize @@ -163,10 +240,11 @@ class RewriteUniformQuantizeOp LogicalResult match(stablehlo::UniformQuantizeOp op) const override { const Type input_element_type = op.getOperand().getType().cast().getElementType(); - if (!input_element_type.isa()) { - LLVM_DEBUG(llvm::dbgs() - << "Uniform quantize op's input should be a float type. Got: " - << input_element_type << ".\n"); + if (!(input_element_type.isa() || + IsI32F32UniformQuantizedType(input_element_type))) { + LLVM_DEBUG(llvm::dbgs() << "Uniform quantize op's input should be a " + "float type or int32. 
Got: " + << input_element_type << ".\n"); return failure(); } @@ -257,7 +335,7 @@ class RewriteUniformDequantizeOp // * Not a depthwise convolution. // * Does not consider bias add fusion. // TODO: b/294771704 - Support bias quantization. -class RewriteQuantizedConvolutionOp +class RewriteUpstreamQuantizedConvolutionOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -654,7 +732,7 @@ class RewriteQuantizedConvolutionOp // // TODO: b/293650675 - Relax the conversion condition to support dot_general in // general. -class RewriteFullIntegerQuantizedDotGeneralOp +class RewriteUpstreamQuantizedDotGeneralOpToBatchMatmulOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -662,7 +740,7 @@ class RewriteFullIntegerQuantizedDotGeneralOp static LogicalResult MatchLhs( Value lhs, stablehlo::DotDimensionNumbersAttr dimension_numbers) { auto lhs_type = lhs.getType().cast(); - if (!(IsI8F32UniformQuantizedType(lhs_type.getElementType()))) { + if (!IsI8F32UniformQuantizedType(lhs_type.getElementType())) { LLVM_DEBUG(llvm::dbgs() << "Expected a per-tensor uniform " "quantized (i8->f32) input for dot_general. Got: " @@ -704,7 +782,7 @@ class RewriteFullIntegerQuantizedDotGeneralOp } auto rhs_type = rhs.getType().cast(); - if (!(IsI8F32UniformQuantizedType(rhs_type.getElementType()))) { + if (!IsI8F32UniformQuantizedType(rhs_type.getElementType())) { LLVM_DEBUG(llvm::dbgs() << "Expected a per-tensor uniform " "quantized (i8->f32) weight for dot_general. Got: " @@ -714,6 +792,19 @@ class RewriteFullIntegerQuantizedDotGeneralOp return success(); } + static LogicalResult MatchOutput( + Value output, stablehlo::DotDimensionNumbersAttr dimension_numbers) { + auto output_type = output.getType().cast(); + if (!IsI8F32UniformQuantizedType(output_type.getElementType())) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a per-tensor uniform " + "quantized (i8->f32) output for dot_general. Got: " + << output_type << "\n"); + return failure(); + } + return success(); + } + LogicalResult match(stablehlo::DotGeneralOp op) const override { stablehlo::DotDimensionNumbersAttr dimension_numbers = op.getDotDimensionNumbers(); @@ -746,6 +837,12 @@ class RewriteFullIntegerQuantizedDotGeneralOp return failure(); } + if (failed(MatchOutput(op.getResult(), dimension_numbers))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match output for quantized dot_general.\n"); + return failure(); + } + return success(); } @@ -819,7 +916,7 @@ class RewriteFullIntegerQuantizedDotGeneralOp // `RewriteFullIntegerQuantizedDotGeneralOp`. // TODO: b/295264927 - `stablehlo.dot_general` with per-axis quantized operands // is not specified in the StableHLO dialect. Update the spec to allow this. 
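// The fully-connected rewrites below rely on CreateTflConstOpForFilter (earlier
// in this file) to flip the 2-D filter from the [i, o] layout assumed for
// stablehlo.dot_general (contracting dimension == 1) to the [o, i] layout that
// tfl.fully_connected expects. A standalone sketch of that flat-index remapping
// over a small made-up matrix; sizes and values are illustrative only.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Filter in [i, o] layout: i (contracting) = 2 rows, o = 3 columns, row-major.
  const int64_t rows = 2, cols = 3;
  const std::vector<int8_t> filter = {1, 2, 3,
                                      4, 5, 6};

  // Same remapping as the loop in CreateTflConstOpForFilter:
  // element (i, j) moves from i * cols + j to j * rows + i.
  std::vector<int8_t> transposed(filter.size(), 0);
  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t j = 0; j < cols; ++j) {
      transposed[j * rows + i] = filter[i * cols + j];
    }
  }

  // Prints 1 4 2 5 3 6, i.e. the same matrix in [o, i] (3x2) row-major layout.
  for (int8_t v : transposed) std::cout << static_cast<int>(v) << " ";
  std::cout << "\n";
  return 0;
}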
-class RewriteQuantizedDotGeneralOpToTflFullyConnectedOp +class RewriteUpstreamQuantizedDotGeneralOpToTflFullyConnectedOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -867,15 +964,17 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOp cast(op.getOperand(1).getDefiningOp()); TFL::QConstOp new_filter_constant_op = - CreateTflConstOpForFilter(filter_constant_op, rewriter); + CreateTflConstOpForFilter(filter_constant_op, rewriter, + /*is_per_axis=*/true); const Value input_value = op.getOperand(0); const double input_scale = input_value.getType() .cast() .getElementType() .cast() .getScale(); - TFL::QConstOp bias_constant_op = CreateTflConstOpForBias( - op.getLoc(), input_scale, new_filter_constant_op, rewriter); + TFL::QConstOp bias_constant_op = CreateTflConstOpForDummyBias( + op.getLoc(), input_scale, new_filter_constant_op, rewriter, + /*is_per_axis=*/true); const Value result_value = op.getResult(); // Set to `nullptr` because this attribute only matters when the input is @@ -962,106 +1061,208 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOp return success(); } +}; - // Creates a new `tfl.qconst` op for the quantized filter. Transposes the - // filter value from [i, o] -> [o, i]. This is because we assume `[i, o]` - // format for `stablehlo.dot_general` (i.e. contracting dimension == 1) - // whereas `tfl.fully_connected` accepts an OI format. - TFL::QConstOp CreateTflConstOpForFilter( - stablehlo::ConstantOp filter_constant_op, - PatternRewriter& rewriter) const { - const auto filter_values = filter_constant_op.getValue() - .cast() - .getValues(); +// Rewrites `stablehlo.dot_general` to `tfl.fully_connected` or +// `tfl.batch_matmul` when it accepts uniform quantized tensors. +// +// Conditions for `tfl.fully_connected` conversion: +// * Input and output tensors are per-tensor uniform quantized (i8->f32) +// tensors. +// * The filter tensor is constant a per-tensor uniform quantized (i8->f32) +// tensor. The quantization dimension should be 1 (the non-contracting +// dimension). +// * The input tensor's rank is either 2 or 3. The last dimension of the input +// tensor should be the contracting dimension, i.e. [..., c_x, r_x]. +// * The filter tensor's rank is 2. The contracting dimension should be the +// first dimension (dim 0), i.e. [c_y, r_y] where c_y == r_x. +// * Does not consider activation fusion. +// * Does not consider bias add fusion. +// TODO: b/580909703 - Include conversion conditions for `tfl.batch_matmul` op. +// +// TODO: b/295264927 - `stablehlo.dot_general` with per-axis quantized operands +// is not specified in the StableHLO dialect. Update the spec to allow this. +class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - ArrayRef filter_shape = - filter_constant_op.getType().cast().getShape(); + public: + LogicalResult match(stablehlo::DotGeneralOp op) const override { + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = + op.getDotDimensionNumbers(); + if (const int num_rhs_contracting_dims = + dot_dimension_nums.getRhsContractingDimensions().size(); + num_rhs_contracting_dims != 1) { + LLVM_DEBUG(llvm::dbgs() + << "Expected number of contracting dimensions to be 1. Got: " + << num_rhs_contracting_dims << ".\n"); + return failure(); + } - // Reverse the shapes. This makes sense because it assumes that the filter - // tensor has rank of 2 (no batch dimension). 
- SmallVector new_filter_shape(filter_shape.rbegin(), - filter_shape.rend()); + if (failed(MatchInput(op.getOperand(0)))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match input for quantized dot_general op.\n"); + return failure(); + } - // Construct the value array of transposed filter. Assumes 2D matrix. - SmallVector new_filter_values(filter_values.size(), /*Value=*/0); - for (int i = 0; i < filter_shape[0]; ++i) { - for (int j = 0; j < filter_shape[1]; ++j) { - const int old_idx = i * filter_shape[1] + j; - const int new_idx = j * filter_shape[0] + i; - new_filter_values[new_idx] = filter_values[old_idx]; - } + if (failed(MatchFilter(op.getOperand(1)))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match filter for quantized dot_general op.\n"); + return failure(); } - auto new_filter_value_attr_type = RankedTensorType::getChecked( - filter_constant_op.getLoc(), new_filter_shape, - /*elementType=*/rewriter.getI8Type()); + if (failed(MatchOutput(op.getResult()))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match output for quantized dot_general op.\n"); + return failure(); + } - auto filter_quantized_type = filter_constant_op.getResult() - .getType() - .cast() - .getElementType() - .cast(); + if (failed(MatchUsers(op.getResult()))) { + LLVM_DEBUG(llvm::dbgs() << "Failed to match subsequent requantize for " + "quantized dot_general op.\n"); + return failure(); + } - auto new_filter_quantized_type = UniformQuantizedPerAxisType::getChecked( - filter_constant_op.getLoc(), /*flags=*/true, - /*storageType=*/filter_quantized_type.getStorageType(), - /*expressedType=*/filter_quantized_type.getExpressedType(), - /*scales=*/filter_quantized_type.getScales(), - /*zeroPoints=*/filter_quantized_type.getZeroPoints(), - /*quantizedDimension=*/0, /*storageTypeMin=*/llvm::minIntN(8), - /*storageTypeMax=*/llvm::maxIntN(8)); + return success(); + } - // Required because the quantized dimension is changed from 3 -> 0. - auto new_filter_result_type = RankedTensorType::getChecked( - filter_constant_op.getLoc(), /*shape=*/new_filter_shape, - /*type=*/new_filter_quantized_type); + void rewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const override { + // Create the new filter constant - transpose filter value + // from [i, o] -> [o, i]. This is because we assume `[i, o]` format for + // `stablehlo.dot_general` (i.e. contracting dimension == 1) whereas + // `tfl.fully_connected` accepts an OI format. + auto filter_constant_op = + cast(op.getOperand(1).getDefiningOp()); - auto new_filter_constant_value_attr = DenseIntElementsAttr::get( - new_filter_value_attr_type, new_filter_values); - return rewriter.create( - filter_constant_op.getLoc(), - /*output=*/TypeAttr::get(new_filter_result_type), - /*value=*/new_filter_constant_value_attr); + TFL::QConstOp new_filter_constant_op = CreateTflConstOpForFilter( + filter_constant_op, rewriter, /*is_per_axis=*/false); + const Value input_value = op.getOperand(0); + const double input_scale = input_value.getType() + .cast() + .getElementType() + .cast() + .getScale(); + TFL::QConstOp bias_constant_op = CreateTflConstOpForDummyBias( + op.getLoc(), input_scale, new_filter_constant_op, rewriter, + /*is_per_axis=*/false); + + auto output_op = op.getResult().getDefiningOp(); + Operation* requantize_op = *output_op->getResult(0).getUsers().begin(); + Operation* dequantize_op = *requantize_op->getResult(0).getUsers().begin(); + + // Set to `nullptr` because this attribute only matters when the input is + // dynamic-range quantized. 
+ const BoolAttr asymmetric_quantize_inputs = nullptr; + auto tfl_fully_connected_op = rewriter.create( + op.getLoc(), + /*output=*/ + requantize_op->getResult(0).getType(), // result_value.getType(), + /*input=*/input_value, /*filter=*/new_filter_constant_op.getResult(), + /*bias=*/bias_constant_op.getResult(), + /*fused_activation_function=*/rewriter.getStringAttr("NONE"), + /*weights_format=*/rewriter.getStringAttr("DEFAULT"), + /*keep_num_dims=*/rewriter.getBoolAttr(false), + asymmetric_quantize_inputs); + + auto tfl_dequantize_op = rewriter.create( + op.getLoc(), dequantize_op->getResult(0).getType(), + tfl_fully_connected_op->getResult(0)); + + rewriter.replaceAllUsesWith(dequantize_op->getResult(0), + tfl_dequantize_op->getResult(0)); + + rewriter.replaceAllUsesWith(op.getResult(), + tfl_fully_connected_op.getResult(0)); + + rewriter.eraseOp(op); } - // Creates a new `tfl.qconst` op for the bias. The bias values are 0s, because - // this bias a dummy bias (note that bias fusion is not considered for this - // transformation). The quantization scale for the bias is input scale * - // filter scale. `filter_const_op` is used to retrieve the filter scales and - // the size of the bias constant. - TFL::QConstOp CreateTflConstOpForBias(const Location loc, - const double input_scale, - TFL::QConstOp filter_const_op, - PatternRewriter& rewriter) const { - const ArrayRef filter_shape = - filter_const_op.getResult().getType().getShape(); - const auto filter_quantized_element_type = - filter_const_op.getResult() - .getType() - .getElementType() - .cast(); + private: + static LogicalResult MatchInput(Value input) { + auto input_type = input.getType().cast(); + if (!input_type.hasRank() || + !(input_type.getRank() == 2 || input_type.getRank() == 3)) { + LLVM_DEBUG(llvm::dbgs() << "Input expected to have rank of 2 or 3. Got: " + << input_type << ".\n"); + return failure(); + } - // The storage type is i32 for bias, which is the precision used for - // accumulation. - auto bias_quantized_type = UniformQuantizedPerAxisType::getChecked( - loc, /*flags=*/true, /*storageType=*/rewriter.getI32Type(), - /*expressedType=*/rewriter.getF32Type(), /*scales=*/ - GetBiasScales(input_scale, filter_quantized_element_type.getScales()), - /*zeroPoints=*/filter_quantized_element_type.getZeroPoints(), - /*quantizedDimension=*/0, /*storageTypeMin=*/llvm::minIntN(8), - /*storageTypeMax=*/llvm::maxIntN(8)); + if (const auto input_element_type = input_type.getElementType(); + !IsI8F32UniformQuantizedType(input_element_type)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected an i8->f32 uniform quantized type. Got: " + << input_element_type << ".\n"); + return failure(); + } - SmallVector bias_shape = {filter_shape[0]}; - auto bias_type = - RankedTensorType::getChecked(loc, bias_shape, bias_quantized_type); + return success(); + } - auto bias_value_type = RankedTensorType::getChecked( - loc, std::move(bias_shape), rewriter.getI32Type()); - auto bias_value = DenseIntElementsAttr::get( - bias_value_type, APInt(/*numBits=*/32, /*value=*/0, /*isSigned=*/true)); + static LogicalResult MatchFilter(Value filter) { + auto filter_type = filter.getType().cast(); + if (!filter_type.hasRank() || filter_type.getRank() != 2) { + LLVM_DEBUG(llvm::dbgs() + << "Filter tensor expected to have a tensor rank of 2. 
Got: " + << filter_type << ".\n"); + return failure(); + } - return rewriter.create( - loc, /*output=*/TypeAttr::get(bias_type), /*value=*/bias_value); + const Type filter_element_type = filter_type.getElementType(); + if (!IsI8F32UniformQuantizedType(filter_element_type)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized (i8->f32) type. Got: " + << filter_element_type << "\n"); + return failure(); + } + + if (Operation* filter_op = filter.getDefiningOp(); + filter_op == nullptr || !isa(filter_op)) { + LLVM_DEBUG(llvm::dbgs() << "Filter should be a constant.\n"); + return failure(); + } + + return success(); + } + + static LogicalResult MatchOutput(Value output) { + const Type output_element_type = + output.getType().cast().getElementType(); + if (!IsI32F32UniformQuantizedType(output_element_type)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized (i32->f32) type. Got: " + << output_element_type << ".\n"); + return failure(); + } + return success(); + } + + static LogicalResult MatchUsers(Value output) { + auto output_op = output.getDefiningOp(); + + if (!output_op->hasOneUse()) { + LLVM_DEBUG(llvm::dbgs() << "Expected output to be used only once.\n"); + return failure(); + } + // TODO: b/309896242 - Add support for fused op case. + if (Operation* requantize_op = dyn_cast_or_null( + *output_op->getResult(0).getUsers().begin())) { + const Type requantize_element_type = requantize_op->getResult(0) + .getType() + .cast() + .getElementType(); + if (!IsI8F32UniformQuantizedType(requantize_element_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected a quantize (i8->f32) type. Got: " + << requantize_element_type << ".\n"); + return failure(); + } + if (!isa( + *requantize_op->getResult(0).getUsers().begin())) { + LLVM_DEBUG(llvm::dbgs() << "Expected a dequantize type.\n"); + return failure(); + } + } + return success(); } }; @@ -1071,9 +1272,11 @@ void UniformQuantizedStablehloToTflPass::runOnOperation() { RewritePatternSet patterns(&ctx); patterns.add(&ctx); + RewriteUpstreamQuantizedConvolutionOp, + RewriteUpstreamQuantizedDotGeneralOpToBatchMatmulOp, + RewriteUpstreamQuantizedDotGeneralOpToTflFullyConnectedOp, + RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp>( + &ctx); if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { func_op.emitError() << "Failed to convert stablehlo ops with uniform " diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir index 17b724051cdede..77f634edb94768 100644 --- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir @@ -344,3 +344,32 @@ func.func @trivial_dynamic_update_slice_wrong_update_shape(%arg0: tensor<2x7x14x // CHECK: "tfl.dynamic_update_slice" func.return %1 : tensor<2x7x14xf32> } + +// CHECK-LABEL: OptimizeTranposeWithRank7orMoreEffectiveRank6 +func.func @OptimizeTranposeWithRank7orMoreEffectiveRank6(%arg0: tensor<7x6x5x4x3x2x1xf32> ) -> (tensor<1x2x3x4x5x6x7xf32>) { + %cst = arith.constant dense<[6, 5, 4, 3, 2, 1, 0]> : tensor<7xi32> + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<7x6x5x4x3x2x1xf32>, tensor<7xi32>) -> tensor<1x2x3x4x5x6x7xf32> + return %0 : tensor<1x2x3x4x5x6x7xf32> + // CHECK-DAG: %cst = arith.constant dense<[7, 6, 5, 4, 3, 2]> : tensor<6xi32> + // CHECK-DAG: %cst_0 = arith.constant dense<[5, 4, 3, 2, 1, 0]> : tensor<6xi32> + // CHECK-DAG: %cst_1 = arith.constant dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst) : 
(tensor<7x6x5x4x3x2x1xf32>, tensor<6xi32>) -> tensor<7x6x5x4x3x2xf32> + // CHECK: %1 = "tfl.transpose"(%0, %cst_0) : (tensor<7x6x5x4x3x2xf32>, tensor<6xi32>) -> tensor<2x3x4x5x6x7xf32> + // CHECK: %2 = "tfl.reshape"(%1, %cst_1) : (tensor<2x3x4x5x6x7xf32>, tensor<7xi32>) -> tensor<1x2x3x4x5x6x7xf32> + // CHECK: return %2 +} + +// CHECK-LABEL: OptimizeTranposeWithRank7orMoreEffectiveRank4 +func.func @OptimizeTranposeWithRank7orMoreEffectiveRank4(%arg0: tensor<56x8x56x1x1x1x7xf32> ) -> (tensor<1x1x8x56x56x7x1xf32>) { + %cst = arith.constant dense<[4, 5, 1, 2, 0, 6, 3]> : tensor<7xi32> + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<56x8x56x1x1x1x7xf32>, tensor<7xi32>) -> tensor<1x1x8x56x56x7x1xf32> + return %0 : tensor<1x1x8x56x56x7x1xf32> + // CHECK-DAG: %cst = arith.constant dense<[56, 8, 56, 7]> : tensor<4xi32> + // CHECK-DAG: %cst_0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32> + // CHECK-DAG: %cst_1 = arith.constant dense<[1, 1, 8, 56, 56, 7, 1]> : tensor<7xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<56x8x56x1x1x1x7xf32>, tensor<4xi32>) -> tensor<56x8x56x7xf32> + // CHECK: %1 = "tfl.transpose"(%0, %cst_0) : (tensor<56x8x56x7xf32>, tensor<4xi32>) -> tensor<8x56x56x7xf32> + // CHECK: %2 = "tfl.reshape"(%1, %cst_1) : (tensor<8x56x56x7xf32>, tensor<7xi32>) -> tensor<1x1x8x56x56x7x1xf32> + // CHECK: return %2 +} + diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir index 97e3a647b042a6..33e5cca6e5de17 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir @@ -8,8 +8,8 @@ func.func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x func.return %24 : tensor<1x4xf32> // CHECK-LABEL: main // separate lines since there is no region for this op. 
third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252 -// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> -// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> +// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> +// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> // CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) ({ // CHECK: }) {asymmetric_quantize_inputs = false, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> // CHECK: return %[[RES2]] @@ -46,8 +46,8 @@ func.func @testLSTMAsymAttributeTrue(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {asymmetric_quantize_inputs = true, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> func.return %24 : tensor<1x4xf32> -// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> -// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> +// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> +// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> // CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) ({ // CHECK: }) {asymmetric_quantize_inputs = true, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, 
tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> // CHECK: return %[[RES2]] @@ -63,8 +63,8 @@ func.func @testLSTMAsymAttributeFalse(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4x %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {asymmetric_quantize_inputs = false, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> func.return %24 : tensor<1x4xf32> -// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> -// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> +// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> +// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> // CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) ({ // CHECK: }) {asymmetric_quantize_inputs = false, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> // CHECK: return %[[RES2]] @@ -80,8 +80,8 @@ func.func @testLSTMAsymAttributeDefault(%arg0: tensor<1x4xf32>, %arg1: tensor<4x %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> func.return %24 : tensor<1x4xf32> -// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> -// CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> +// CHECK-DAG: %[[RES0:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> +// 
CHECK-DAG: %[[RES1:.*]] = "tfl.pseudo_const"() {tfl.is_variable, value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32> // CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) ({ // CHECK: }) {asymmetric_quantize_inputs = false, cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = #tfl, proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> // CHECK: return %[[RES2]] diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/metadata_buffer.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/metadata_buffer.mlir deleted file mode 100644 index 6b76b31c9a52bf..00000000000000 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/metadata_buffer.mlir +++ /dev/null @@ -1,9 +0,0 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s - -// CHECK: tfl.metadata_buffer = [3 : i32, 7 : i32] -module attributes {tfl.metadata_buffer = [3 : i32, 7 : i32]} { - func.func @main(%arg0: tensor, %arg1: tensor<3x2xi32>) -> tensor<3x2xi32> { - %0 = "tfl.add" (%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor, tensor<3x2xi32>) -> tensor<3x2xi32> - func.return %0 : tensor<3x2xi32> - } -} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo.mlir index 64567f5c3d5d68..708dd562195398 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo.mlir @@ -320,7 +320,7 @@ func.func @dynamic_update_slice(%arg0: tensor<4x4xi64>, %arg1: tensor<2x3xi64>, func.func @dyanmic_slice(%arg0: tensor<3x3xi64>, %arg1: tensor, %arg2: tensor) -> tensor<3x3xi64> { %0 = "stablehlo.dynamic_slice"(%arg0, %arg1, %arg2) { - slice_sizes = dense<[3, 3]> : tensor<2xi64> + slice_sizes = array } : (tensor<3x3xi64>, tensor, tensor) -> tensor<3x3xi64> return %0 : tensor<3x3xi64> } @@ -524,7 +524,7 @@ func.func @gather(%operand: tensor<3x4x2xi32>, %start_indices: tensor<2x3x2xi64> // CHECK-NEXT:} func.func @transpose(%arg0: tensor<2x3x2xi32>) -> tensor<2x3x2xi32> { - %0 = "stablehlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32> + %0 = "stablehlo.transpose"(%arg0) {permutation = array} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32> return %0 : tensor<2x3x2xi32> } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/variable.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/variable.mlir new file mode 100644 index 00000000000000..0914fc37016771 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/variable.mlir @@ -0,0 +1,8 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s + +// CHECK-LABEL: main +func.func @main() -> tensor<3x2xi32> { + // CHECK: "tfl.pseudo_const"() 
{tfl.is_variable, value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %0 = "tfl.pseudo_const"() {value = dense<0> : tensor<3x2xi32>, tfl.is_variable} : () -> tensor<3x2xi32> loc("variable") + func.return %0 : tensor<3x2xi32> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/deduplicate_const.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/deduplicate_const.mlir new file mode 100644 index 00000000000000..c0c1bf4f70dc7f --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/deduplicate_const.mlir @@ -0,0 +1,93 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s + +module { +func.func @add(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> attributes {tf.entry_function = {inputs = "serving_default_x", outputs = "outputs"}} { + %0 = "tfl.pseudo_const" () {value = dense<[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]> : tensor<3x2xf32>} : () -> tensor<3x2xf32> + %1 = "tfl.add" (%0, %arg0) {fused_activation_function = "NONE"} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> + func.return %1 : tensor<3x2xf32> +} + +func.func @sub(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> attributes {tf.entry_function = {inputs = "serving_default_x", outputs = "outputs"}} { + %0 = "tfl.pseudo_const" () {value = dense<[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]> : tensor<3x2xf32>} : () -> tensor<3x2xf32> + %1 = "tfl.sub" (%0, %arg0) {fused_activation_function = "NONE"} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> + func.return %1 : tensor<3x2xf32> +} +} + +// CHECK: { +// CHECK: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 3, 2 ], +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "serving_default_x", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: has_rank: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 3, 2 ], +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "tfl.pseudo_const", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: has_rank: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 3, 2 ], +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "outputs", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: has_rank: true +// CHECK-NEXT: } ], +// CHECK: name: "add" +// CHECK-NEXT: }, { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 3, 2 ], +// CHECK-NEXT: buffer: 4, +// CHECK-NEXT: name: "serving_default_x", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: has_rank: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 3, 2 ], +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "tfl.pseudo_const1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: has_rank: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 3, 2 ], +// CHECK-NEXT: buffer: 6, +// CHECK-NEXT: name: "outputs", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: has_rank: true +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0 ], +// CHECK-NEXT: outputs: [ 2 ], +// CHECK: name: "sub" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", +// CHECK-NEXT: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192, 64 ] +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { 
+// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 54, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK: } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir index c1504b979afa5b..3bab94ea490f37 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir @@ -273,7 +273,7 @@ func.func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf3 // CHECK-NEXT: }, { // CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { -// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_asym_attr.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_asym_attr.mlir index 62cf9336ad0f29..6e5a70a6b5bb66 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_asym_attr.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_asym_attr.mlir @@ -273,7 +273,7 @@ func.func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf3 // CHECK-NEXT: }, { // CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { -// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata_buffer.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata_buffer.mlir deleted file mode 100644 index f53f3954f14211..00000000000000 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata_buffer.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s - -module attributes {tfl.metadata_buffer = [3 : i32, 7 : i32]} { - func.func @main(%arg0: tensor, %arg1: tensor<3x2xi32>) -> tensor<3x2xi32> { - %0 = "tfl.add" (%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor, tensor<3x2xi32>) -> tensor<3x2xi32> - func.return %0 : tensor<3x2xi32> - } -} - -// CHECK: metadata_buffer: [ 3, 7 ], -// CHECK-NEXT: metadata: \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/signature_def.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/signature_def.mlir index 8253e8215f9d38..e5c9b4802c15e4 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/signature_def.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/signature_def.mlir @@ -44,7 +44,7 @@ // CHECK-NEXT: has_rank: true // CHECK-NEXT: }, { // CHECK-NEXT: shape: [ 5, 384 ], -// CHECK-NEXT: buffer: 5, +// CHECK-NEXT: buffer: 4, // CHECK-NEXT: name: "arith.constant2", // CHECK-NEXT: quantization: { // CHECK-EMPTY: diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/u16_quant.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/u16_quant.mlir new file mode 100644 index 00000000000000..251e8bd389cfd1 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/u16_quant.mlir @@ -0,0 +1,19 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s + +func.func @main(%arg0: tensor<*x!quant.uniform>) -> tensor<*x!quant.uniform> { +// CHECK: { +// CHECK-NEXT: version: 3, +// 
CHECK-NEXT: operator_codes: [ ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ ], +// CHECK-NEXT: type: UINT16, +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "arg0", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: scale: [ 2.0 ], +// CHECK-NEXT: zero_point: [ 37 ] +// CHECK: } +// CHECK-NEXT: } ], + return %arg0 : tensor<*x!quant.uniform> +} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir index 669f6068e948b5..738b413c09268b 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s --dump-input=always +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s func.func @main(tensor<4x4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4x4xf32> { // CHECK: { @@ -298,7 +298,7 @@ func.func @main(tensor<4x4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4x // CHECK-NEXT: }, { // CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { -// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-EMPTY: // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variable.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variable.mlir new file mode 100644 index 00000000000000..2b393f7fecaa8d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/variable.mlir @@ -0,0 +1,40 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s + +func.func @main() -> tensor<3x2xi32> { + %0 = "tfl.pseudo_const" () {value = dense<0> : tensor<3x2xi32>, tfl.is_variable} : () -> tensor<3x2xi32> loc("variable") + func.return %0 : tensor<3x2xi32> +} + +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 3, 2 ], +// CHECK-NEXT: type: INT32, +// CHECK-NEXT: name: "variable", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: is_variable: true +// CHECK-NEXT: has_rank: true +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ ], +// CHECK-NEXT: outputs: [ 0 ], +// CHECK-NEXT: operators: [ ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", +// CHECK-NEXT: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ {{.*}} ] +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ {{.*}} ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", 
+// CHECK-NEXT: buffer: 2 +// CHECK-NEXT: } ] +// CHECK-NEXT: signature_defs: [ ] +// CHECK-NEXT: } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index fff00820ce353b..0769e768507ee7 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -226,7 +226,7 @@ func.func @matmulNoTransposeAOrB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1 // CHECK-LABEL: matmulNoTransposeAOrB // CHECK: %[[RES:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor // CHECK: %[[TRANS:.*]] = "tf.Transpose"(%arg1, %[[RES]]) : (tensor<1280x1000xf32>, tensor) -> tensor<*xf32> - // CHECK: %[[MM:.*]] = "tf.MatMul"(%arg0, %[[TRANS]]) <{transpose_a = false, transpose_b = true}> : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32> + // CHECK: %[[MM:.*]] = "tf.MatMul"(%arg0, %[[TRANS]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = true}> : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32> // CHECK: return %[[MM]] : tensor<1x1000xf32> } @@ -238,7 +238,7 @@ func.func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000 // CHECK: %[[RES:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor // CHECK: %[[TRANS1:.*]] = "tf.Transpose"(%arg0, %[[RES]]) : (tensor<1x1280xf32>, tensor) -> tensor<*xf32> // CHECK: %[[TRANS2:.*]] = "tf.Transpose"(%arg1, %[[RES]]) : (tensor<1280x1000xf32>, tensor) -> tensor<*xf32> - // CHECK: %[[MM:.*]] = "tf.MatMul"(%[[TRANS1]], %[[TRANS2]]) <{transpose_a = false, transpose_b = true}> : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32> + // CHECK: %[[MM:.*]] = "tf.MatMul"(%[[TRANS1]], %[[TRANS2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = true}> : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32> // CHECK: return %[[MM]] : tensor<1x1000xf32> } @@ -718,7 +718,7 @@ func.func @QuantDequantTranspose(%arg0: tensor<2x3xf32>) -> (tensor<2x4xf32>) { // CHECK: %[[QUANT:.*]] = "tfl.quantize"(%[[CST_0]]) {qtype = tensor<3x4x!quant.uniform>} : (tensor<3x4xf32>) -> tensor<3x4x!quant.uniform> // CHECK: %[[DEQUANT:.*]] = "tfl.dequantize"(%[[QUANT]]) : (tensor<3x4x!quant.uniform>) -> tensor<3x4xf32> // CHECK: %[[TRANSPOSE:.*]] = "tf.Transpose"(%[[DEQUANT]], %[[CST]]) : (tensor<3x4xf32>, tensor) -> tensor<*xf32> - // CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[TRANSPOSE]]) <{transpose_a = false, transpose_b = true}> : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<2x4xf32> + // CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[TRANSPOSE]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = true}> : (tensor<2x3xf32>, tensor<*xf32>) -> tensor<2x4xf32> // CHECK: return %[[MATMUL]] : tensor<2x4xf32> } diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc index ec970065be3576..8a3abc94e2af57 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc @@ -407,7 +407,7 @@ void DenseToSparsePass::runOnOperation() { } if (result.needs_densify) { - const auto value = op->getOperand(operand); + auto value = op->getOperand(operand); auto densify = builder.create(op->getLoc(), value.getType(), value); value.replaceAllUsesWith(densify); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td 
b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index a99d4a9a1c688e..a2ea10fe199736 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -387,13 +387,13 @@ def LegalizeDiv : Pat<(TF_DivOp $lhs, $rhs), // fall through to here and convert to TF Lite BatchMatMul. // TODO(b/207064634): CreateEmptyBoolAttr is a temporary workaround for this bug. def LegalizeBatchMatMulV3UnknownBatch : Pat< - (TF_BatchMatMulV3Op $lhs, $rhs, $adj_x, $adj_y), + (TF_BatchMatMulV3Op $lhs, $rhs, $adj_x, $adj_y, $grad_x, $grad_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y, CreateEmptyBoolAttr:$adj_y)>; def LegalizeBatchMatMulV2UnknownBatch : Pat< - (TF_BatchMatMulV2Op $lhs, $rhs, $adj_x, $adj_y), + (TF_BatchMatMulV2Op $lhs, $rhs, $adj_x, $adj_y, $grad_x, $grad_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y, CreateEmptyBoolAttr:$adj_y)>; def LegalizeBatchMatMulUnknownBatch : Pat< - (TF_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y), + (TF_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y, $grad_x, $grad_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y, CreateEmptyBoolAttr:$adj_y)>; def LegalizeFakeQuantWithMinMaxVars: Pat< diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 64bff681053f6e..2ed3c0519d8526 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -362,34 +362,6 @@ TypeAttr RescaleQtype(Type input, Attribute factor) { return quant::RescaleQuantizedType(input, factor); } -// Returns shape of a ranked tensor. -// Precondition: output_val's is ranked tensor. -// Returns a truncated shape when `truncate` is set to true. -DenseElementsAttr GetShape(Value output_val, bool truncate = false) { - auto output_shape = output_val.getType().dyn_cast().getShape(); - - SmallVector shape; - shape.reserve(output_shape.size()); - - bool needs_truncation = true; - for (size_t dim_idx = 0; dim_idx < output_shape.size(); ++dim_idx) { - int64_t dim = output_shape[dim_idx]; - if (truncate && needs_truncation && dim == 1) { - continue; - } else if (needs_truncation && dim != 1) { - needs_truncation = false; - } - shape.push_back(ShapedType::isDynamic(dim) ? -1 - : static_cast(dim)); - } - - return mlir::DenseElementsAttr::get( - RankedTensorType::get( - {static_cast(shape.size())}, - mlir::IntegerType::get(output_val.getContext(), 32)), - llvm::ArrayRef(shape)); -} - // Utility function to map final permutation to initial permutation // initial -> permutation1 -> permutation2 -> final DenseElementsAttr RemapPermutation(Value permutation1, Value permutation2) { diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index dccf57ab1ecf40..e1e3d766ed3e5b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -41,11 +41,6 @@ class HasRankAtMost : Constraint< CPred<"$0.getType().cast().hasRank() && " "$0.getType().cast().getRank() <= " # n>>; -// Checks if the value has rank at most 'n'. -class HasRankAtLeast : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() >= " # n>>; - // Checks if the value has rank 'n'. 
class HasRank : Constraint< CPred<"$0.getType().cast().hasRank() && " @@ -698,14 +693,10 @@ foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp, } } -// Returns shape of a ranked tensor. -// if called without a ranked tensor it will fail. -def GetShape: NativeCodeCall<"GetShape($0)">; - // Returns truncated shape of a ranked-tensor. -// Truncated, here, means eliminating any contiguous 1s' in the lower +// Prefix-Truncated, here, means eliminating any contiguous 1s' in the lower // dimentions of the tensor -def GetTruncatedShape: NativeCodeCall<"GetShape($0, true)">; +def GetPrefixTruncatedShape: NativeCodeCall<"GetShape($0, true)">; // Returns True if the operand type is RankedTensorType and valid. def HasValidRankedTensor : Constraint().getNumDynamicDims() <= 1">>; // Check if the truncated shape of the lhs is equal to the shape of rhs -def IsTruncatedShapeEqualTo : Constraint>; def ConvertSqueezeToReshape : Pat< @@ -735,9 +726,9 @@ def ConvertTrasposeReshapeTransposeToReshape : Pat< (TFL_TransposeOp:$first_transpose $input, $permutation2), $shape), $permutation1), - (TFL_ReshapeOp $input, (Arith_ConstantOp (GetTruncatedShape $input))), - [(IsTruncatedShapeEqualTo $first_transpose, $middle_reshape), - (IsTruncatedShapeEqualTo $input, $second_transpose)]>; + (TFL_ReshapeOp $input, (Arith_ConstantOp (GetPrefixTruncatedShape $input))), + [(IsPrefixTruncatedShapeEqualTo $first_transpose, $middle_reshape), + (IsPrefixTruncatedShapeEqualTo $input, $second_transpose)]>; // TODO(b/294385379): This pattern only appears when we convert // from shlo due to differences in broadcasting behavior diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td index abd57fe7372ef8..c625b329be6413 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td @@ -51,21 +51,21 @@ class TFi32 : ConstantAttr(v)>; // Matmul without transpose on b to matmul with explicit transpose op and // transposed b. def ConvertMatmulWithoutTransposeToWithTranspose : - Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse), + Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$at, ConstBoolAttrFalse, $grad_a, $grad_b), (TF_MatMulOp $a, (TF_TransposeOp $b, (TF_SubOp (TF_RangeOp /*start=*/(TF_RankOp $b), /*limit=*/(TF_ConstOp TFi32<0>), /*delta=*/(TF_ConstOp TFi32<-1>)), (TF_ConstOp TFi32<1>))), - $at, ConstBoolAttrTrue)>; + $at, ConstBoolAttrTrue, $grad_a, $grad_b)>; // Matmul with transpose on a to matmul with explicit transpose op and a not // transposed. 
-def ConvertMatmulWithTranspose : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt), +def ConvertMatmulWithTranspose : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt, $grad_a, $grad_b), (TF_MatMulOp (TF_TransposeOp $a, (TF_SubOp (TF_RangeOp /*start=*/(TF_RankOp $a), /*limit=*/(TF_ConstOp TFi32<0>), /*delta=*/(TF_ConstOp TFi32<-1>)), (TF_ConstOp TFi32<1>))), $b, - ConstBoolAttrFalse, $bt)>; + ConstBoolAttrFalse, $bt, $grad_a, $grad_b)>; // Partially supported in TFLite, treated as passthrough IdentityOp def ConvertCheckNumerics : Pat<(TF_CheckNumericsOp $arg, $msg), (TF_IdentityOp $arg)>; diff --git a/tensorflow/compiler/mlir/lite/utils/utils.h b/tensorflow/compiler/mlir/lite/utils/utils.h index 9fe43f34b256cf..6130bab6531ba2 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.h +++ b/tensorflow/compiler/mlir/lite/utils/utils.h @@ -107,6 +107,34 @@ inline ShapedType GetTransposedType(Value input, return transposed_type; } +// Returns shape of a ranked tensor. +// Precondition: output_val's is ranked tensor. +// Returns a truncated shape when `truncate` is set to true. +inline DenseElementsAttr GetShape(Value output_val, bool truncate = false) { + auto output_shape = output_val.getType().dyn_cast().getShape(); + + SmallVector shape; + shape.reserve(output_shape.size()); + + bool needs_truncation = true; + for (size_t dim_idx = 0; dim_idx < output_shape.size(); ++dim_idx) { + int64_t dim = output_shape[dim_idx]; + if (truncate && needs_truncation && dim == 1) { + continue; + } else if (needs_truncation && dim != 1) { + needs_truncation = false; + } + shape.push_back(ShapedType::isDynamic(dim) ? -1 + : static_cast(dim)); + } + + return mlir::DenseElementsAttr::get( + RankedTensorType::get( + {static_cast(shape.size())}, + mlir::IntegerType::get(output_val.getContext(), 32)), + llvm::ArrayRef(shape)); +} + } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/utils.td b/tensorflow/compiler/mlir/lite/utils/utils.td index c2b953d2cf0585..e64b591ae78eda 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.td +++ b/tensorflow/compiler/mlir/lite/utils/utils.td @@ -17,6 +17,16 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/IR/PatternBase.td" + +// Returns shape of a ranked tensor. +// if called without a ranked tensor it will fail. +def GetShape: NativeCodeCall<"GetShape($0)">; + +// Checks if the value has rank at most 'n'. +class HasRankAtLeast : Constraint< + CPred<"$0.getType().cast().hasRank() && " + "$0.getType().cast().getRank() >= " # n>>; // Checks value is not produced by a TFL_Quant or // from TFL_Quant Op with same quant type. 
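Aside on the `GetShape` helper that the patch moves into `utils.h` above: when called with `truncate = true` (as the renamed `GetPrefixTruncatedShape` TableGen hook does), it drops only the leading run of size-1 dimensions and maps dynamic dimensions to -1. The following is a minimal, self-contained sketch of that prefix-truncation logic over plain C++ containers rather than MLIR types; `PrefixTruncatedShape` is a hypothetical name used only for illustration, not part of the patch.

#include <cstdint>
#include <vector>

// Simplified model of GetShape(output_val, /*truncate=*/true): skip the
// leading run of unit dimensions, then copy the remaining extents verbatim.
// Dynamic dimensions are modelled here as negative extents and emitted as -1,
// mirroring the ShapedType::isDynamic handling in the real helper.
std::vector<int32_t> PrefixTruncatedShape(const std::vector<int64_t>& dims) {
  std::vector<int32_t> shape;
  bool still_truncating = true;
  for (int64_t dim : dims) {
    if (still_truncating && dim == 1) continue;  // Drop a leading 1.
    still_truncating = false;  // First non-1 extent ends the truncation.
    shape.push_back(dim < 0 ? -1 : static_cast<int32_t>(dim));
  }
  return shape;
}

// PrefixTruncatedShape({1, 1, 3, 1, 2}) -> {3, 1, 2}  (interior 1s are kept)
// PrefixTruncatedShape({4, 1, 5})       -> {4, 1, 5}  (nothing to truncate)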
diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index acc7bd1a8fb01e..afc088517dc35f 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -1,3 +1,4 @@ +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -7,6 +8,17 @@ package( licenses = ["notice"], ) +bool_flag( + name = "disable_mlir", + build_setting_default = False, +) + +config_setting( + name = "disable_mlir_config", + flag_values = {":disable_mlir": "True"}, + visibility = ["//visibility:public"], +) + cc_library( name = "mlir", srcs = ["mlir.cc"], diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc index 6042a896709d9e..8c82fc9bc12b42 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc @@ -29,8 +29,7 @@ PYBIND11_MODULE(filecheck_wrapper, m) { llvm::SMLoc()); SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(check), llvm::SMLoc()); - llvm::Regex regex = fc.buildCheckPrefixRegex(); - fc.readCheckFile(SM, llvm::StringRef(check), regex); + fc.readCheckFile(SM, llvm::StringRef(check)); return fc.checkInput(SM, llvm::StringRef(input)); }); } diff --git a/tensorflow/compiler/mlir/quantization/common/BUILD b/tensorflow/compiler/mlir/quantization/common/BUILD new file mode 100644 index 00000000000000..a39f7e5a64d268 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/BUILD @@ -0,0 +1,123 @@ +load("@llvm-project//mlir:tblgen.bzl", "td_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + # By default, these targets should only be used within the quantization library. 
+ default_visibility = [ + "//learning/brain/mlir/quantization:__subpackages__", + "//tensorflow/compiler/mlir/quantization:__subpackages__", + ], + licenses = ["notice"], +) + +td_library( + name = "lift_as_function_call_td_files", + srcs = [ + "lift_as_function_call.td", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "@llvm-project//mlir:FuncTdFiles", + ], +) + +cc_library( + name = "lift_as_function_call", + srcs = ["lift_as_function_call.cc"], + hdrs = ["lift_as_function_call.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/stablehlo:stablehlo_type_utils", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:quantization_unit_loc", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", + "//tensorflow/core:framework_lite", + "//tensorflow/core/ir/types:Dialect", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "lift_as_function_call_test", + srcs = ["lift_as_function_call_test.cc"], + deps = [ + ":lift_as_function_call", + ":test_base", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], +) + +cc_library( + name = "test_base", + testonly = 1, + srcs = [], + hdrs = ["test_base.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/core:test", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:QuantOps", + "@stablehlo//:stablehlo_ops", + ], +) + +cc_library( + name = "attrs_and_constraints", + srcs = [ + "attrs_and_constraints.cc", + ], + hdrs = [ + "attrs_and_constraints.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + +td_library( + name = "quant_td_files", + srcs = [ + "attrs_and_constraints.td", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", + "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call_td_files", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:FuncTdFiles", + ], +) diff --git 
a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc similarity index 79% rename from tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc rename to tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc index a72098d3fa8aae..a5d4f745a7d02a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc @@ -12,16 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" - -#include +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -namespace mlir { -namespace quant { +namespace mlir::quant { bool HasQuantizedTensors(Operation* op) { if (!IsOpQuantizable(op)) return false; @@ -72,5 +76,4 @@ SmallVector CloneOpWithReplacedOperands( return builder.clone(*op, mapping)->getResults(); } -} // namespace quant -} // namespace mlir +} // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h similarity index 90% rename from tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h rename to tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h index 320b6b93aa536d..791e608dc064dc 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,23 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_UTILS_H_ -#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_ +#include #include -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -namespace mlir { -namespace quant { +namespace mlir::quant { constexpr char kQuantizeFuncName[] = "quantize_i8"; constexpr char kDequantizeFuncName[] = "dequantize_i8"; @@ -132,6 +132,6 @@ bool AreSplatValuesEqual(Value x, Value y) { SmallVector CloneOpWithReplacedOperands( OpBuilder &builder, Operation *op, const SmallVector &new_operands); -} // namespace quant -} // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_UTILS_H_ +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td similarity index 92% rename from tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td rename to tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td index 654e4af58d3fc6..a5d1d8544ae931 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td @@ -147,3 +147,15 @@ def GetDefiningOp : NativeCodeCall<"$0.getDefiningOp()">; def CloneOpWithReplacedOperands : NativeCodeCall< "CloneOpWithReplacedOperands(" "$_builder, $0, llvm::SmallVector{$1...}).front()">; + +// Checks whether the value of a constant equals the given float, regardless +// of the tensor dimension. +class FloatValueEquals : Constraint>; + +// Fetches the default or null attribute, used for pattern matching. +def DefaultOrNullAttr : NativeCodeCall<"DefaultOrNullAttr($_builder, $0)">; + +// Returns true if the given op is a StableHLO constant op. +def IsStableHLOConstantOp : Constraint($0.getDefiningOp())">>; + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc similarity index 73% rename from tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.cc rename to tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc index 37d9b56ac1a7de..d74c8a952c8c24 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" #include #include @@ -22,50 +22,74 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/TypeRange.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" +#include "tensorflow/core/ir/types/dialect.h" #include "tensorflow/core/platform/mutex.h" -namespace mlir { -namespace quant { +namespace mlir::quant { + +// Default version number for native serialization. +constexpr int64_t kDefaultVersion = 9; +// Default platform for XlaCallModuleOp. +constexpr StringRef kPlatformCpu = "CPU"; +// Name of `tf.XlaCallModule`'s dictionary attribute for keeping the +// deserialized stablehlo module's attributes. +constexpr llvm::StringRef kStablehloModuleAttrsAttrName = + "_stablehlo_module_attrs"; +// Attribute required for running shape refinement pass enabled in XlaCallModule +// version 8 and above. +constexpr llvm::StringRef kUsesShapePolymorphismAttr = + "jax.uses_shape_polymorphism"; // Checks if the op is inside a lifted function. -bool IsInLiftedFunc(Operation *op) { - return op->getParentOfType()->hasAttr(kFusedFunctionAttr); +bool IsInLiftedFunc(Operation& op) { + return op.getParentOfType()->hasAttr(kFusedFunctionAttr); } // Inserts the function to the symbol table of the module thread-safely. 
-StringAttr InsertToSymbolTable(Operation *module, Operation *function, - const std::string &func_name) { - static tensorflow::mutex *mtx = new tensorflow::mutex(); +StringAttr InsertToSymbolTable(Operation& module, Operation& function, + const std::string& func_name) { + static tensorflow::mutex* mtx = new tensorflow::mutex(); tensorflow::mutex_lock lock(*mtx); - SymbolTable symbol_table(module); + SymbolTable symbol_table(&module); std::string unique_name = func_name; int32_t uniquing_counter = 0; while (symbol_table.lookup(unique_name) != nullptr) { ++uniquing_counter; unique_name = func_name + "_" + std::to_string(uniquing_counter); } - function->setAttr("sym_name", - StringAttr::get(module->getContext(), unique_name)); - return symbol_table.insert(function); + function.setAttr("sym_name", + StringAttr::get(module.getContext(), unique_name)); + return symbol_table.insert(&function); } // Creates the TF::PartitionedCallOp with the given arguments and output types. @@ -100,15 +124,16 @@ ValueRange createTFXlaCallModuleOp(OpBuilder builder, Location location, tf_type::ShapeAttr::get(ctx, result_type.cast())); } auto empty_array_attr = ArrayAttr::get(ctx, {}); + auto platforms = ArrayAttr::get(ctx, {StringAttr::get(ctx, kPlatformCpu)}); TF::XlaCallModuleOp call_op = builder.create( location, /*output=*/output_types, /*args=*/args, - /*version=*/5, /*module=*/"", + /*version=*/kDefaultVersion, /*module=*/"", /*Sout=*/ArrayAttr::get(ctx, shape_attrs), /*dim_args_spec=*/empty_array_attr, - /*platforms=*/empty_array_attr, + /*platforms=*/platforms, /*function_list=*/empty_array_attr, /*has_token_input_output=*/false, /*disabled_checks=*/empty_array_attr); @@ -130,6 +155,12 @@ ValueRange createTFXlaCallModuleOp(OpBuilder builder, Location location, builder.getStringAttr(llvm::StringRef( std::string(QuantTraitValues[QuantizationTrait::FullyQuantizable])))); + // Set jax.uses_shape_polymorphism=true to enable shape refinement at runtime. + // This is needed for native serialization version >= 8. + call_op->setAttr(kStablehloModuleAttrsAttrName, + builder.getDictionaryAttr(builder.getNamedAttr( + kUsesShapePolymorphismAttr, builder.getBoolAttr(true)))); + return call_op.getOutput(); } @@ -152,14 +183,14 @@ ValueRange createFunctionCallOp(OpBuilder builder, Location location, // Finds ops in the paths from arguments to results. The ops is listed in an // order that the former ops shouldn't have any dependencies on the later ones. -llvm::SmallVector FindOpsFromArgumentsToResults( - const llvm::SmallVector &arguments, - const llvm::SmallVector &results) { +llvm::SmallVector FindOpsFromArgumentsToResults( + const llvm::SmallVector& arguments, + const llvm::SmallVector& results) { std::queue value_queue; for (Value result : results) { value_queue.push(result); } - absl::flat_hash_set argument_set; + absl::flat_hash_set argument_set; for (Value argument : arguments) { argument_set.insert(argument.getImpl()); } @@ -167,15 +198,15 @@ llvm::SmallVector FindOpsFromArgumentsToResults( // Searching for ops from results to arguments. Duplicate ops in the op stack // are intentional in order to make sure the op on the top of the stack // doesn't depends on any ops below it. 
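As a side note on the ordering produced by `FindOpsFromArgumentsToResults`, here is a minimal, self-contained C++ model of the same traversal, using a toy `Node` struct in place of MLIR `Operation`/`Value` and omitting the argument set. It shows why the intentional duplicate pushes are needed so that the popped, de-duplicated order lists producers before their consumers; the names are illustrative only.

#include <cstdio>
#include <queue>
#include <stack>
#include <string>
#include <unordered_set>
#include <vector>

// Toy stand-in for an operation: a name plus the nodes producing its operands.
struct Node {
  std::string name;
  std::vector<Node*> operands;
};

// Walk backwards from the results, push every reached node onto a stack
// (duplicates included), then de-duplicate while popping. The kept occurrence
// of each node is its latest push, which always sits above its consumers.
std::vector<Node*> OpsFromArgumentsToResults(const std::vector<Node*>& results) {
  std::queue<Node*> queue;
  for (Node* result : results) queue.push(result);

  std::stack<Node*> stack;
  while (!queue.empty()) {
    Node* node = queue.front();
    queue.pop();
    stack.push(node);  // Duplicates are intentional.
    for (Node* operand : node->operands) queue.push(operand);
  }

  std::vector<Node*> sorted;
  std::unordered_set<Node*> seen;
  while (!stack.empty()) {
    Node* node = stack.top();
    stack.pop();
    if (seen.insert(node).second) sorted.push_back(node);
  }
  return sorted;
}

int main() {
  // a feeds both b and c; b feeds c; c is the sole result.
  Node a{"a", {}}, b{"b", {&a}}, c{"c", {&a, &b}};
  for (Node* node : OpsFromArgumentsToResults({&c})) {
    std::printf("%s ", node->name.c_str());
  }
  std::printf("\n");  // Prints "a b c"; skipping duplicate pushes would give "b a c",
                      // placing consumer b ahead of its producer a.
  return 0;
}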
- std::stack op_stack; + std::stack op_stack; while (!value_queue.empty()) { Value current_value = value_queue.front(); value_queue.pop(); - Operation *defining_node = current_value.getDefiningOp(); + Operation* defining_node = current_value.getDefiningOp(); if (defining_node == nullptr) continue; op_stack.push(defining_node); - for (const auto &arg : defining_node->getOperands()) { + for (const auto& arg : defining_node->getOperands()) { if (!argument_set.contains(arg.getImpl())) { value_queue.push(arg); } @@ -183,10 +214,10 @@ llvm::SmallVector FindOpsFromArgumentsToResults( } // Remove duplicate ops from the op stack. - llvm::SmallVector sorted_ops; - absl::flat_hash_set unique_ops; + llvm::SmallVector sorted_ops; + absl::flat_hash_set unique_ops; while (!op_stack.empty()) { - Operation *current_op = op_stack.top(); + Operation* current_op = op_stack.top(); op_stack.pop(); if (unique_ops.contains(current_op)) continue; sorted_ops.push_back(current_op); @@ -206,21 +237,20 @@ llvm::SmallVector FindOpsFromArgumentsToResults( // identifiers. // This function returns success if all attributes could be found. LogicalResult SetAttributeMap( - MLIRContext *context, const llvm::SmallVector &attributes, - const llvm::SmallVector &ops) { + MLIRContext& context, const llvm::SmallVector& attributes, + const llvm::SmallVector& ops) { // A map to find which operation an attribute belongs to. // The key for this map uses the entire NamedAttribute object, i.e. the // {attribute_name, attribute_value} pair. - llvm::SmallDenseMap attr_to_op_map; - for (Operation *op : ops) { - for (const auto &named_attr : op->getAttrs()) { + llvm::SmallDenseMap attr_to_op_map; + for (Operation* op : ops) { + for (const NamedAttribute named_attr : op->getAttrs()) { attr_to_op_map.insert({named_attr, op}); } } for (int idx : llvm::seq(0, attributes.size())) { - const NamedAttribute &attribute = attributes[idx]; - + const NamedAttribute& attribute = attributes[idx]; // Skip the following steps if the attribute value is `NullAttribute`. if (const auto string_attr = attribute.getValue().dyn_cast_or_null(); @@ -229,27 +259,38 @@ LogicalResult SetAttributeMap( continue; } - if (attr_to_op_map.count(attribute) == 0) { - mlir::emitError(UnknownLoc::get(context), + if (std::find_if( + attr_to_op_map.begin(), attr_to_op_map.end(), [&](auto attr_op) { + return std::get<0>(attr_op).getName() == attribute.getName(); + }) == attr_to_op_map.end()) { + mlir::emitError(UnknownLoc::get(&context), "Could not find attribute: " + attribute.getName().str()); return failure(); } - Operation *owner_op = attr_to_op_map[attribute]; - - std::string new_attr_map_str{}; - if (owner_op->hasAttr(kAttrMapAttribute)) { - new_attr_map_str = - owner_op->getAttrOfType(kAttrMapAttribute).str(); - absl::StrAppend(&new_attr_map_str, ","); + Operation* owner_op; + for (const auto& [attr, val] : attr_to_op_map) { + if (attr.getName() == attribute.getName()) owner_op = val; } + if (stablehlo::IsStablehloOp(owner_op)) { + owner_op->setAttr(StringRef(attribute.getName()), attribute.getValue()); + } else { + owner_op = attr_to_op_map[attribute]; + + std::string new_attr_map_str{}; + if (owner_op->hasAttr(kAttrMapAttribute)) { + new_attr_map_str = + owner_op->getAttrOfType(kAttrMapAttribute).str(); + absl::StrAppend(&new_attr_map_str, ","); + } - // Append ":". Ex) "0:transpose_a". 
- const std::string identifier = std::to_string(idx); - const mlir::StringAttr attribute_name = attribute.getName(); - absl::StrAppend(&new_attr_map_str, identifier, ":", attribute_name.str()); - owner_op->setAttr(kAttrMapAttribute, - StringAttr::get(context, new_attr_map_str)); + // Append ":". Ex) "0:transpose_a". + const std::string identifier = std::to_string(idx); + const mlir::StringAttr attribute_name = attribute.getName(); + absl::StrAppend(&new_attr_map_str, identifier, ":", attribute_name.str()); + owner_op->setAttr(kAttrMapAttribute, + StringAttr::get(&context, new_attr_map_str)); + } } return success(); } @@ -257,15 +298,15 @@ LogicalResult SetAttributeMap( // Creates a function to wrap the section between arguments and results. llvm::SmallVector LiftAsFunctionCall( OpBuilder builder, Location location, FunctionCallOpType call_op_type, - StringRef func_name, const llvm::SmallVector &arguments, - const llvm::SmallVector &results, - const llvm::SmallVector &attributes) { - MLIRContext *context = builder.getContext(); + StringRef func_name, const llvm::SmallVector& arguments, + const llvm::SmallVector& results, + const llvm::SmallVector& attributes) { + MLIRContext* context = builder.getContext(); if (results.empty()) { mlir::emitError(UnknownLoc::get(context), "No result values specified"); return {}; } - Operation *result_op = results[0].getDefiningOp(); + Operation* result_op = results[0].getDefiningOp(); auto module = result_op->getParentOfType(); // Create a private function and copy all ops between arguments and results. @@ -277,7 +318,7 @@ llvm::SmallVector LiftAsFunctionCall( auto func_type = FunctionType::get(context, arg_types, result_types); llvm::SmallVector arg_locs; - for (const auto &arg : arguments) { + for (const auto& arg : arguments) { arg_locs.push_back(arg.getLoc()); } auto wrap_func = builder.create(location, func_name, func_type); @@ -298,7 +339,7 @@ llvm::SmallVector LiftAsFunctionCall( auto cloning_ops = FindOpsFromArgumentsToResults(arguments, results); // Set the location of call op to QuantizationUnitLoc if found. Location call_op_loc = location; - for (Operation *op : cloning_ops) { + for (Operation* op : cloning_ops) { std::optional unit = FindQuantizationUnitFromLoc(op->getLoc()); if (unit.has_value()) { @@ -306,10 +347,10 @@ llvm::SmallVector LiftAsFunctionCall( } } - if (failed(SetAttributeMap(context, attributes, cloning_ops))) { + if (failed(SetAttributeMap(*context, attributes, cloning_ops))) { current_func.emitError() << "Some attributes couldn't be found."; } - for (Operation *op : cloning_ops) { + for (Operation* op : cloning_ops) { builder.clone(*op, mapping); } @@ -321,7 +362,7 @@ llvm::SmallVector LiftAsFunctionCall( // Create a function call to the newly created function. 
StringAttr new_func_name = - InsertToSymbolTable(module, wrap_func, func_name.str()); + InsertToSymbolTable(*module, *wrap_func, func_name.str()); builder.setInsertionPointAfter(result_op); ValueRange new_results = createFunctionCallOp(builder, call_op_loc, call_op_type, @@ -331,15 +372,15 @@ llvm::SmallVector LiftAsFunctionCall( llvm::SmallVector LiftAsFunctionCall( OpBuilder builder, Location location, FunctionCallOpType call_op_type, - StringRef func_name, const llvm::SmallVector &arguments, - const llvm::SmallVector &results) { + StringRef func_name, const llvm::SmallVector& arguments, + const llvm::SmallVector& results) { llvm::SmallVector attributes; return LiftAsFunctionCall(builder, location, call_op_type, func_name, arguments, results, attributes); } llvm::SmallVector AppendToVector( - const llvm::SmallVector &arguments, Value append) { + const llvm::SmallVector& arguments, Value append) { llvm::SmallVector ret(arguments); ret.push_back(append); return ret; @@ -422,5 +463,4 @@ bool IsEinsumSupportedByXlaDotV2(mlir::StringAttr equation_attr) { rhs_out_idx_start >= batch_dim_size; } -} // namespace quant -} // namespace mlir +} // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h similarity index 76% rename from tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h rename to tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h index 83f1ed2ce6d59d..c796fbbca32a2f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,22 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_LIFT_AS_FUNCTION_CALL_UTILS_H_ -#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_LIFT_AS_FUNCTION_CALL_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_LIFT_AS_FUNCTION_CALL_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_LIFT_AS_FUNCTION_CALL_H_ #include "absl/strings/string_view.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project -// This header file defines common utils used by TF-Quant transformation -// passes to lift op compositions to a function. -namespace mlir { -namespace quant { +namespace mlir::quant { // This attribute will be set for functions created by this pass. 
+// Presence of this attribute will mark the function as quantization target. inline constexpr absl::string_view kFusedFunctionAttr = "tf_quant.composite_function"; // The keyword to detect if this is a `NullAttribute`. @@ -43,7 +43,7 @@ inline constexpr absl::string_view kOriginalStablehloEntryFunctionAttrName = enum FunctionCallOpType { TFPartitionedCallOp = 0, TFXlaCallModuleOp = 1 }; // Checks if the op is inside a lifted function. -bool IsInLiftedFunc(Operation *op); +bool IsInLiftedFunc(Operation &op); // Checks if the given einsum op is supported for XlaDotV2 quantization. bool IsEinsumSupportedByXlaDotV2(mlir::StringAttr equation_attr); @@ -70,6 +70,6 @@ llvm::SmallVector LiftAsFunctionCall( llvm::SmallVector AppendToVector( const llvm::SmallVector &arguments, Value append); -} // namespace quant -} // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_LIFT_AS_FUNCTION_CALL_UTILS_H_ +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_LIFT_AS_FUNCTION_CALL_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.td b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td similarity index 96% rename from tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.td rename to tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td index 6110a38c721f98..a4437b50ac0cf0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.td +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -59,7 +59,7 @@ class NamedAttr : // Checks if the value is not defined inside a lifted function by checking the // `tf_quant.composite_function` attribute. def IsNotInLiftedFunc : - Constraint>; + Constraint>; // Checks if the given einsum op is supported for XlaDotV2 quantization. def IsEinsumSupportedByXlaDotV2 : diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc new file mode 100644 index 00000000000000..4947fcd910e64b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc @@ -0,0 +1,131 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" + +#include +#include "absl/strings/string_view.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/test_base.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::quant::common { +namespace { + +class LiftAsFunctionCallTest : public QuantizationTestBase {}; + +constexpr absl::string_view kModuleLifted = R"mlir( + module { + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } + } +)mlir"; + +TEST_F(LiftAsFunctionCallTest, LiftedFunctionSucceeds) { + OwningOpRef module_op_ref = ParseModuleOpString(kModuleLifted); + func::FuncOp composite_dot_general_fn = + GetFunctionFromModule(*module_op_ref, "composite_dot_general_fn_1"); + Operation* dot_general_op = + FindOperationOfType( + composite_dot_general_fn); + EXPECT_TRUE(IsInLiftedFunc(*dot_general_op)); +} + +constexpr absl::string_view kModuleStableHlo = R"mlir( + module { + func.func private @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } + } +)mlir"; + +TEST_F(LiftAsFunctionCallTest, FunctionLiftedAsXlaCallModuleOp) { + OwningOpRef module_op_ref = ParseModuleOpString(kModuleStableHlo); + func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + Operation* dot_general_op = + FindOperationOfType(main_fn); + + const SmallVector& attributes = { + builder_.getNamedAttr("precision_config", + builder_.getArrayAttr(SmallVector( + 1, stablehlo::PrecisionAttr::get( + &ctx_, stablehlo::Precision::DEFAULT)))), + }; + Operation* lifted_op = + LiftAsFunctionCall(builder_, dot_general_op->getLoc(), + FunctionCallOpType::TFXlaCallModuleOp, + "composite_dot_general_fn", + dot_general_op->getOperands(), + dot_general_op->getResults(), attributes)[0] + .getDefiningOp(); + const auto entry_function_symbol_ref = + lifted_op->getAttrOfType("_entry_function"); + SymbolTable symbol_table(*module_op_ref); + auto entry_func = dyn_cast_or_null( + symbol_table.lookup(entry_function_symbol_ref.getValue())); + Operation* lifted_dot_general_op = + FindOperationOfType(entry_func); + + EXPECT_TRUE(isa(lifted_op)); + EXPECT_EQ(lifted_op->getAttr("_original_entry_function").cast(), + "composite_dot_general_fn_1"); + EXPECT_EQ( + lifted_dot_general_op->getAttr("precision_config").cast(), + 
builder_.getArrayAttr(SmallVector( + 1, stablehlo::PrecisionAttr::get(&ctx_, + stablehlo::Precision::DEFAULT)))); +} + +TEST_F(LiftAsFunctionCallTest, FunctionNoAttrLiftedAsXlaCallModuleOp) { + OwningOpRef module_op_ref = ParseModuleOpString(kModuleStableHlo); + func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + Operation* dot_general_op = + FindOperationOfType(main_fn); + Operation* lifted_op = + LiftAsFunctionCall( + builder_, dot_general_op->getLoc(), + FunctionCallOpType::TFXlaCallModuleOp, "composite_dot_general_fn", + dot_general_op->getOperands(), dot_general_op->getResults())[0] + .getDefiningOp(); + EXPECT_TRUE(isa(lifted_op)); + EXPECT_EQ(lifted_op->getAttr("_original_entry_function").cast(), + "composite_dot_general_fn_1"); +} + +TEST_F(LiftAsFunctionCallTest, EinsumSupportedForXlaDotV2Succeeds) { + StringAttr einsum_supported_by_xla_dot_v2_attr = + builder_.getStringAttr("ijk,ikm->ijm"); + StringAttr einsum_one_operand = builder_.getStringAttr("ijk->ikj"); + StringAttr einsum_ellipsis = builder_.getStringAttr("...gse->...gs"); + EXPECT_TRUE(IsEinsumSupportedByXlaDotV2(einsum_supported_by_xla_dot_v2_attr)); + EXPECT_FALSE(IsEinsumSupportedByXlaDotV2(einsum_one_operand)); + EXPECT_FALSE(IsEinsumSupportedByXlaDotV2(einsum_ellipsis)); +} + +} // namespace +} // namespace mlir::quant::common diff --git a/tensorflow/compiler/mlir/quantization/common/test_base.h b/tensorflow/compiler/mlir/quantization/common/test_base.h new file mode 100644 index 00000000000000..ad847a29477779 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/test_base.h @@ -0,0 +1,79 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TEST_BASE_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TEST_BASE_H_ + +#include +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/core/platform/test.h" + +namespace mlir::quant::common { + +using ::testing::Test; + +class QuantizationTestBase : public Test { + protected: + QuantizationTestBase() { + ctx_.loadDialect(); + } + + // Parses `module_op_str` to create a `ModuleOp`. Checks whether the created + // module op is valid. 
+ OwningOpRef ParseModuleOpString( + const absl::string_view module_op_str) { + auto module_op_ref = parseSourceString(module_op_str, &ctx_); + EXPECT_TRUE(module_op_ref); + return module_op_ref; + } + + // Gets the function with the given name from the module. + func::FuncOp GetFunctionFromModule(ModuleOp module, + absl::string_view function_name) { + SymbolTable symbol_table(module); + return symbol_table.lookup(function_name); + } + + // Returns the first operation with the given type in the function. + template + OpType FindOperationOfType(func::FuncOp function) { + for (auto op : function.getBody().getOps()) { + return op; + } + return nullptr; + } + + mlir::MLIRContext ctx_{}; + OpBuilder builder_{&ctx_}; +}; + +} // namespace mlir::quant::common + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TEST_BASE_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 640cd2e6cb7366..6ab58f78aac025 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -25,13 +25,13 @@ package( licenses = ["notice"], ) -# TODO(b/264218457): Add quantize and post_quantize passes. cc_library( name = "passes", srcs = [ "passes/lift_quantizable_spots_as_functions.cc", "passes/lift_quantizable_spots_as_functions_fusion.inc", "passes/lift_quantizable_spots_as_functions_simple.inc", + "passes/populate_shape.cc", "passes/post_quantize.cc", "passes/prepare_quantize.cc", "passes/quantize.cc", @@ -40,6 +40,7 @@ cc_library( "passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc", "passes/restore_function_name.cc", "passes/unfuse_mhlo_batch_norm.cc", + "passes/unwrap_xla_call_module_op.cc", ], hdrs = [ "passes/passes.h", @@ -51,6 +52,7 @@ cc_library( ":lift_quantizable_spots_as_functions_fusion_inc_gen", ":lift_quantizable_spots_as_functions_simple_inc_gen", ":quantization_options_proto_cc", + ":quantization_patterns", ":stablehlo_passes_inc_gen", ":stablehlo_type_utils", ":uniform_quantized_types", @@ -58,15 +60,20 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", - "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call", + "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", "//tensorflow/compiler/mlir/quantization/tensorflow/ops:tf_op_quant_spec", - "//tensorflow/compiler/mlir/quantization/tensorflow/utils:lift_as_function_call_utils", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/ir/types:Dialect", "//tensorflow/core/platform:path", "//tensorflow/core/tpu:tpu_defs", @@ -88,6 +95,7 @@ cc_library( "@llvm-project//mlir:Pass", 
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", @@ -99,9 +107,42 @@ cc_library( "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", ], - # Alwayslink is required for registering the MLIR passes. - # TODO(b/255530126): Split the pass registration from the definitions to avoid binary size bloat. - alwayslink = True, +) + +cc_library( + name = "quantization_patterns", + srcs = ["passes/quantization_patterns.cc"], + hdrs = [ + "passes/quantization_patterns.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":bridge_passes", + ":uniform_quantized_types", + "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", + "//tensorflow/compiler/mlir/quantization/tensorflow:passes", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:path", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:stablehlo_ops", + ], ) td_library( @@ -109,13 +150,12 @@ td_library( srcs = [ "passes/lift_quantizable_spots_as_functions_fusion.td", "passes/lift_quantizable_spots_as_functions_simple.td", - "passes/utils.td", ], compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", + "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call_td_files", "//tensorflow/compiler/mlir/quantization/tensorflow:quant_td_files", - "//tensorflow/compiler/mlir/quantization/tensorflow/utils:lift_as_function_call_utils_td_files", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", "@llvm-project//mlir:ArithOpsTdFiles", "@llvm-project//mlir:FuncTdFiles", @@ -134,7 +174,10 @@ gentbl_cc_library( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/lift_quantizable_spots_as_functions_simple.td", - deps = [":quant_td_files"], + deps = [ + ":quant_td_files", + "//tensorflow/compiler/mlir/quantization/common:quant_td_files", + ], ) gentbl_cc_library( @@ -148,7 +191,10 @@ gentbl_cc_library( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/lift_quantizable_spots_as_functions_fusion.td", - deps = [":quant_td_files"], + deps = [ + ":quant_td_files", + "//tensorflow/compiler/mlir/quantization/common:quant_td_files", + ], ) gentbl_cc_library( @@ -222,8 +268,6 @@ cc_library( "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", "@stablehlo//:chlo_ops", ], - # Force link to ensure ConvertTFQuantOpsToMHLOPass is registered. 
- alwayslink = True, ) tf_cc_test( @@ -331,11 +375,8 @@ cc_library( ":fill_quantization_options", ":passes", ":quantization_options_proto_cc", - "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", "//tensorflow/core/platform:path", - "@com_google_absl//absl/container:flat_hash_set", - "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:Pass", ], @@ -487,6 +528,27 @@ tf_proto_library( # ) # copybara:uncomment_end +# OSS only: This target is header-only. Link `quantization_config_proto_cc_impl` only to +# `libtensorflow_framework.so` via `lib_internal_impl`. Do NOT link +# `quantization_config_proto_cc_impl` directly unless the target does not link +# `libtensorflow_framework.so`. +tf_proto_library( + name = "quantization_config_proto", + srcs = ["quantization_config.proto"], + cc_api_version = 2, + make_default_target_header_only = True, + visibility = ["//visibility:public"], +) + +# copybara:uncomment_begin(google-only) +# py_proto_library( +# name = "quantization_config_py_pb2", +# api_version = 2, +# visibility = [":internal_visibility_allowlist_package"], +# deps = [":quantization_config_proto"], +# ) +# copybara:uncomment_end + exports_files([ "run_lit.sh", ]) @@ -503,14 +565,20 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", + "//tensorflow/core/ir/types:Dialect", "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", + "@local_xla//xla/mlir_hlo:mhlo_passes", "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_passes", ], ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD new file mode 100644 index 00000000000000..5c94eb06e617df --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -0,0 +1,161 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow:tensorflow.default.bzl", + "get_compatible_with_portable", +) + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:__subpackages__", + "//tensorflow/compiler/mlir/quantization/tensorflow:__subpackages__", + ], + licenses = ["notice"], +) + +cc_library( + name = "component", + hdrs = ["component.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "@com_google_absl//absl/status:statusor", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "io", + srcs = ["io.cc"], + hdrs = ["io.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:statusor", + ], +) + +tf_cc_test( + name = "io_test", + srcs = ["io_test.cc"], + deps = [ + ":io", + "@com_google_absl//absl/functional:any_invocable", + 
"@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:types", + ], +) + +cc_library( + name = "graph_def", + srcs = [], + hdrs = ["graph_def.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "graph_def_test", + srcs = ["graph_def_test.cc"], + deps = [ + ":graph_def", + "//tensorflow/core:protos_all_cc", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:protobuf", + ], +) + +cc_library( + name = "debugger", + srcs = ["debugger.cc"], + hdrs = ["debugger.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:graph_def", + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "export", + srcs = ["export.cc"], + hdrs = ["export.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:Pass", + ], +) + +tf_cc_test( + name = "export_test", + srcs = ["export_test.cc"], + deps = [ + ":export", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:protobuf", + ], +) + +cc_library( + name = "precalibration", + srcs = ["precalibration.cc"], + hdrs = ["precalibration.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":component", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@local_tsl//tsl/platform:errors", + ], +) + +tf_cc_test( + name = "precalibration_test", + srcs = ["precalibration_test.cc"], + deps = [ + ":precalibration", + "//tensorflow/compiler/mlir/quantization/common:test_base", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", 
+ "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:status_matchers", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD new file mode 100644 index 00000000000000..946a733c2e7da9 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD @@ -0,0 +1,61 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:__subpackages__", + "//tensorflow/compiler/mlir/quantization/tensorflow:__subpackages__", + ], + licenses = ["notice"], +) + +cc_library( + name = "min_max_value", + srcs = [], + hdrs = ["min_max_value.h"], + compatible_with = get_compatible_with_portable(), + deps = [], +) + +cc_library( + name = "statistics", + srcs = ["statistics.cc"], + hdrs = ["statistics.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:graph_def", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "assign_ids", + srcs = ["assign_ids.cc"], + hdrs = ["assign_ids.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:graph_def", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "assign_ids_test", + srcs = ["assign_ids_test.cc"], + deps = [ + ":assign_ids", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl", + "//tensorflow/core:protos_all_cc", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:protobuf", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.cc new file mode 100644 index 00000000000000..31e990bbcf20a5 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.cc @@ -0,0 +1,43 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.h" + +#include <cstdint> + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" + +namespace stablehlo::quantization { +namespace { + +using ::tensorflow::GraphDef; +using ::tensorflow::NodeDef; +using ::tensorflow::calibrator::CalibratorSingleton; + +} // namespace + +void AssignIdsToCustomAggregatorOps(GraphDef& graph_def) { + MutateNodeDefs(graph_def, [](NodeDef& node_def) { + if (node_def.op() == "CustomAggregator") { + const int64_t new_id = CalibratorSingleton::IssueNewId(); + (*node_def.mutable_attr())["id"].set_s(absl::StrCat(new_id)); + } + }); +} + +} // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.h new file mode 100644 index 00000000000000..6feaa81cc16ce4 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.h @@ -0,0 +1,30 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_ASSIGN_IDS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_ASSIGN_IDS_H_ + +#include "tensorflow/core/framework/graph.pb.h" + +namespace stablehlo::quantization { + +// Assigns unique ids to each CustomAggregator op found in `graph_def`. The +// ids are set to the `id` attribute. The ids are used during the calibration +// step to identify the collected quantization statistics for each +// CustomAggregator op. +void AssignIdsToCustomAggregatorOps(tensorflow::GraphDef& graph_def); + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_ASSIGN_IDS_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids_test.cc new file mode 100644 index 00000000000000..488315a32271b5 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids_test.cc @@ -0,0 +1,63 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.h" + +#include +#include +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep for tsl::protobuf + +namespace stablehlo::quantization { +namespace { + +using ::tensorflow::GraphDef; +using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::SizeIs; +using ::tsl::protobuf::TextFormat; + +TEST(AssignIdsTest, IdsAddedToCustomAggregatorOps) { + GraphDef graph_def; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + node { op: "CustomAggregator" name: "foo" } + )pb", + &graph_def)); + + AssignIdsToCustomAggregatorOps(graph_def); + + ASSERT_THAT(graph_def.node(), SizeIs(1)); + EXPECT_TRUE(graph_def.node()[0].attr().contains("id")); + EXPECT_THAT(graph_def.node()[0].attr().at("id").s(), Not(IsEmpty())); +} + +TEST(AssignIdsTest, IdsNotAddedForNonCustomAggregatorOps) { + GraphDef graph_def; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + node { op: "NotCustomAggregator" name: "bar" } + )pb", + &graph_def)); + + AssignIdsToCustomAggregatorOps(graph_def); + + ASSERT_THAT(graph_def.node(), SizeIs(1)); + EXPECT_FALSE(graph_def.node()[0].attr().contains("id")); +} + +} // namespace +} // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/min_max_value.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/min_max_value.h new file mode 100644 index 00000000000000..5302bad49dd5e8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/min_max_value.h @@ -0,0 +1,28 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_MIN_MAX_VALUE_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_MIN_MAX_VALUE_H_ + +#include + +namespace stablehlo::quantization { + +// Represents the (min, max) value pair, representing the range of values after +// calibrating for quantization. 
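The alias declared just below is a plain pair of floats; calibration/statistics.cc unpacks values of this shape with a structured binding. A minimal caller-side sketch (illustrative only, not part of the patch; `RangeWidth` is a hypothetical helper):

  #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/min_max_value.h"

  // Illustrative only: computes the width of a calibrated [min, max] range.
  float RangeWidth(const stablehlo::quantization::MinMaxValue& range) {
    const auto [min_value, max_value] = range;
    return max_value - min_value;
  }
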
+using MinMaxValue = std::pair; + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_MIN_MAX_VALUE_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc new file mode 100644 index 00000000000000..6fe1f8d9cd4f8f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.cc @@ -0,0 +1,71 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" + +namespace stablehlo::quantization { +namespace { + +using ::tensorflow::GraphDef; +using ::tensorflow::NodeDef; +using ::tensorflow::calibrator::CalibrationStatistics; +using ::tensorflow::calibrator::CalibratorSingleton; +using ::tensorflow::quantization::CalibrationOptions; +using ::tensorflow::quantization::PyFunctionLibrary; + +} // namespace + +absl::Status AddCalibrationStatistics( + GraphDef& graph_def, const CalibrationOptions& calibration_options, + const PyFunctionLibrary& py_function_library) { + absl::Status status = absl::OkStatus(); + MutateNodeDefs(graph_def, [&py_function_library, &calibration_options, + &status](NodeDef& node_def) { + if (node_def.op() != "CustomAggregator") return; + const std::string& id = node_def.attr().at("id").s(); + std::optional statistics = + CalibratorSingleton::GetStatistics(id); + if (statistics == std::nullopt) { + status = absl::InternalError( + absl::StrFormat("Calibrated data does not exist. Cannot find " + "statistics. 
value for id: %s", + id)); + return; + } + + const auto [min_value, max_value] = + py_function_library.GetCalibrationMinMaxValue(*statistics, + calibration_options); + CalibratorSingleton::ClearData(id); + + (*node_def.mutable_attr())["min"].set_f(min_value); + (*node_def.mutable_attr())["max"].set_f(max_value); + }); + return status; +} + +} // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h new file mode 100644 index 00000000000000..c1a551806f287c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_STATISTICS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_STATISTICS_H_ + +#include "absl/status/status.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/framework/graph.pb.h" + +namespace stablehlo::quantization { + +// Adds calibrated min / max values to CustomAggregator nodes in `graph_def`. +// The min and max values will be added to the "min" and "max" attributes, +// respectively. `calibration_options` provides the strategy to retrieve min and +// max values. +absl::Status AddCalibrationStatistics( + tensorflow::GraphDef& graph_def, + const tensorflow::quantization::CalibrationOptions& calibration_options, + const tensorflow::quantization::PyFunctionLibrary& py_function_library); + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_STATISTICS_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h new file mode 100644 index 00000000000000..a1ddb5cb4688ff --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h @@ -0,0 +1,40 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_COMPONENT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_COMPONENT_H_ + +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::quant::stablehlo { + +// Component is a public abstraction for StableHLO Quantizer that represents the +// most basic unit of action applied to the StableHLO graph. Derived classes +// should override the `Run` method to implement the action. +class Component { + public: + virtual ~Component() = default; + + // Runs the action to the StableHLO graph, passed by the `module_op`. `config` + // should provide information necessary to configure the action's behavior. + virtual absl::StatusOr Run( + ModuleOp module_op, + const ::stablehlo::quantization::QuantizationConfig& config) = 0; +}; + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_COMPONENT_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc new file mode 100644 index 00000000000000..4588d5f00a7523 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc @@ -0,0 +1,73 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace stablehlo::quantization { +namespace { + +using ::tensorflow::NodeDef; +using ::tensorflow::SignatureDef; +using ::tensorflow::quantization::DebuggerOptions; +using ::tensorflow::quantization::ExportedModel; +using ::tensorflow::quantization::PyFunctionLibrary; + +} // namespace + +void EnableDebugging( + ExportedModel& exported_model, const DebuggerOptions& debugger_options, + const PyFunctionLibrary& py_function_library, + const absl::string_view src_saved_model_path, + const std::unordered_set& tags, + const absl::flat_hash_map& signature_def_map) { + // Enable `DumpTensor` nodes in `graph_def`. DumpTensor is disabled by + // default to avoid logging data during calibration. 
+ MutateNodeDefs(*exported_model.mutable_graph_def(), [](NodeDef& node_def) { + if (node_def.op() == "DumpTensor") { + (*node_def.mutable_attr())["enabled"].set_b(true); + } + }); + + if (debugger_options.debugger_type() == + DebuggerOptions::DEBUGGER_TYPE_WHOLE_MODEL) { + // TODO: b/295139417 - Remove CustomAggregator op in unquantized dump model. + // TODO: b/296916287 - Create a separate function for saving unquantized + // dump model. + py_function_library.SaveExportedModel( + debugger_options.unquantized_dump_model_path(), exported_model, + src_saved_model_path, tags, signature_def_map); + + // Update the `DumpTensor` ops' file name in `graph_def`. + MutateNodeDefs(*exported_model.mutable_graph_def(), [](NodeDef& node_def) { + if (node_def.op() == "DumpTensor") { + (*node_def.mutable_attr())["file_name"].set_s( + "quantized_tensor_data.pb"); + } + }); + } +} + +} // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h new file mode 100644 index 00000000000000..6bb427ecbdf1fd --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h @@ -0,0 +1,50 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_DEBUGGER_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_DEBUGGER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace stablehlo::quantization { + +// Enables debugging on `exported_model` by updating the `DumpTensor` ops. +// +// Saves the current model to `debugger_options.unquantized_dump_model_path()` +// if the debugger type is `DEBUGGER_TYPE_WHOLE_MODEL`. This is required because +// in whole-model debugging mode the `DumpTensor` ops for the unquantized +// tensors are only inserted in the unquantized model whereas `DumpTensor` ops +// for the quantized tensors are only inserted in the quantized model. Both +// models are required to be able to dump both quantized and unquantized tensors +// and compare them offline. 
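The declaration follows below. A caller-side sketch of the whole-model debugging flow (illustrative only, not part of the patch); every argument is assumed to be supplied by the surrounding quantization driver, and the saved-model path is a placeholder:

  #include <string>
  #include <unordered_set>

  #include "absl/container/flat_hash_map.h"
  #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h"

  // Illustrative only: turns on the DumpTensor nodes of an already-exported
  // model, dumping the unquantized model when whole-model debugging is
  // requested in `debugger_options`.
  void EnableWholeModelDump(
      tensorflow::quantization::ExportedModel& exported_model,
      const tensorflow::quantization::DebuggerOptions& debugger_options,
      const tensorflow::quantization::PyFunctionLibrary& py_function_library,
      const std::unordered_set<std::string>& tags,
      const absl::flat_hash_map<std::string, tensorflow::SignatureDef>&
          signature_def_map) {
    stablehlo::quantization::EnableDebugging(
        exported_model, debugger_options, py_function_library,
        /*src_saved_model_path=*/"/tmp/original_saved_model", tags,
        signature_def_map);
  }
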
+void EnableDebugging( + tensorflow::quantization::ExportedModel& exported_model, + const tensorflow::quantization::DebuggerOptions& debugger_options, + const tensorflow::quantization::PyFunctionLibrary& py_function_library, + absl::string_view src_saved_model_path, + const std::unordered_set& tags, + const absl::flat_hash_map& + signature_def_map); + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_DEBUGGER_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/export.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/export.cc new file mode 100644 index 00000000000000..bf90f153bf0f91 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/export.cc @@ -0,0 +1,89 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/export.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saver.pb.h" + +namespace stablehlo::quantization { + +using ::tensorflow::AssetFileDef; +using ::tensorflow::GraphDef; +using ::tensorflow::SaverDef; +using ::tensorflow::quantization::ExportedModel; + +ExportedModel CreateExportedModel( + GraphDef&& graph_def, const absl::string_view init_node_name, + const absl::string_view checkpoint_dir, + const std::optional saver_def, + const absl::flat_hash_map& function_aliases, + const std::vector& asset_file_defs) { + ExportedModel exported_model{}; + *exported_model.mutable_graph_def() = graph_def; + exported_model.set_init_node_name(std::string(init_node_name)); + exported_model.set_checkpoint_dir(std::string(checkpoint_dir)); + + exported_model.mutable_function_aliases()->insert(function_aliases.begin(), + function_aliases.end()); + + for (const AssetFileDef& asset_file_def : asset_file_defs) { + *exported_model.mutable_asset_file_defs()->Add() = asset_file_def; + } + + if (saver_def != std::nullopt) { + *exported_model.mutable_saver_def() = *std::move(saver_def); + } + + return exported_model; +} + +// TODO: b/315746734 - Test this function using a test-only pass. 
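The definition of `AddExportPasses` follows. A minimal sketch of how a caller might drive the export pipeline (illustrative only, not part of the patch; `RunExportPipeline` is a hypothetical helper):

  #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
  #include "mlir/Pass/PassManager.h"  // from @llvm-project
  #include "mlir/Support/LogicalResult.h"  // from @llvm-project
  #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/export.h"

  // Illustrative only: adds the export passes to a fresh PassManager and runs
  // them on a module that is about to be converted back to a GraphDef.
  mlir::LogicalResult RunExportPipeline(mlir::ModuleOp module_op) {
    mlir::PassManager pm(module_op.getContext());
    stablehlo::quantization::AddExportPasses(
        pm, /*duplicate_shape_determining_constants=*/true);
    return pm.run(module_op);
  }
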
+void AddExportPasses(mlir::PassManager& pm, + const bool duplicate_shape_determining_constants) { + if (duplicate_shape_determining_constants) { + pm.addNestedPass( + mlir::quant::CreateDuplicateShapeDeterminingConstantsPass()); + } + + pm.addPass(mlir::quant::CreateInsertMainFunctionPass()); + pm.addPass(mlir::quant::CreateLiftHashTableOpsAsArgsPass()); + pm.addNestedPass( + mlir::CreateFunctionalToExecutorDialectConversionPass()); + pm.addPass(mlir::CreateBreakUpIslandsPass()); + pm.addPass(mlir::quant::CreateMergeInitializerFunctionOpsToMainPass()); + pm.addPass(mlir::quant::CreateMergeSaveFunctionOpsToMainPass()); + pm.addNestedPass( + mlir::quant::CreateMergeDuplicateResourceOpsPass()); + + // Used to clean up the "tf._noinliner" attribute that is previously used to + // prevent certain functions from being inlined (see + // `MarkFunctionsNoinlinePass`). InlinerPass must not come after this pass. + pm.addPass(mlir::TF::CreateStripNoinlineAttributePass()); +} + +} // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/export.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/export.h new file mode 100644 index 00000000000000..9c5117cf97e4c5 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/export.h @@ -0,0 +1,79 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_EXPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_EXPORT_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saver.pb.h" + +namespace stablehlo::quantization { + +// Suffix string for the module export step. Used for debugging. +constexpr absl::string_view kExportStepSuffix = "_export"; + +// Options when running passes for exporting an MLIR ModuleOp. +struct ExportOptions { + // If set to `true`, it runs `DuplicateShapeDeterminingConstantsPass` before + // lowering to tf_executor dialect. + bool duplicate_shape_determining_constants = true; + + // If set to `true`, unfreezes constants into variables and saves them to a + // checkpoint file. Setting this to `true` is an experimental feature that has + // no stability guarantees. + bool unfreeze_constants = false; + + // Path to the directory where checkpoint files are saved. + std::string checkpoint_dir = ""; + + // Name used to identify the ModuleOp this is exporting. Only used for + // debugging and does not modify the behavior of the export. + std::string debug_name = "stablehlo_quant"; +}; + +// Factory function for `ExportedModel`. 
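The declaration is immediately below. A small usage sketch mirroring the accompanying unit test (illustrative only, not part of the patch; node and path names are placeholders):

  #include <optional>
  #include <utility>

  #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/export.h"
  #include "tensorflow/core/framework/graph.pb.h"

  // Illustrative only: wraps a GraphDef into an ExportedModel with no SaverDef,
  // no function aliases and no asset files.
  tensorflow::quantization::ExportedModel WrapGraphDef(
      tensorflow::GraphDef graph_def) {
    return stablehlo::quantization::CreateExportedModel(
        std::move(graph_def), /*init_node_name=*/"init_op",
        /*checkpoint_dir=*/"/tmp/checkpoint_dir", /*saver_def=*/std::nullopt,
        /*function_aliases=*/{}, /*asset_file_defs=*/{});
  }
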
+[[nodiscard]] tensorflow::quantization::ExportedModel CreateExportedModel( + tensorflow::GraphDef&& graph_def, absl::string_view init_node_name, + absl::string_view checkpoint_dir, + std::optional saver_def, + const absl::flat_hash_map& function_aliases, + const std::vector& asset_file_defs); + +// Adds passes for transforming the MLIR module op so that it can be exported +// back to GraphDef. Roughly, this consists of: +// 1) Inserting the @main function, which will become the main Graph. +// 2) Duplicating shape-determining constants. +// 3) Converting TF dialect -> tf_executor dialect. +// 4) Adding initializer function's ops into @main function for correct +// resource initialization when loading the exported model. +// +// Duplicating shape-determining constants is required to place constants that +// affect the shape of a tensor to be placed in the TPU graph instead of in the +// CPU graph, when the graph gets converted for TPU inference. This allows these +// constants to be known at XLA compilation time. +void AddExportPasses(mlir::PassManager& pm, + bool duplicate_shape_determining_constants); + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_EXPORT_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/export_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/export_test.cc new file mode 100644 index 00000000000000..b6749c6621de31 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/export_test.cc @@ -0,0 +1,102 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/export.h" + +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/protobuf/saver.pb.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +namespace stablehlo::quantization { +namespace { + +using ::tensorflow::AssetFileDef; +using ::tensorflow::GraphDef; +using ::tensorflow::SaverDef; +using ::tensorflow::quantization::ExportedModel; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::StrEq; +using ::tsl::protobuf::TextFormat; + +TEST(CreateExportedModelTest, CreateExportedModelBasicFieldsSet) { + GraphDef graph_def{}; + ASSERT_TRUE( + TextFormat::ParseFromString(R"pb(node { name: "foo" })pb", &graph_def)); + + const ExportedModel exported_model = + CreateExportedModel(std::move(graph_def), "init_node_name", + "checkpoint_dir", /*saver_def=*/std::nullopt, + /*function_aliases=*/{}, /*asset_file_defs=*/{}); + ASSERT_THAT(exported_model.graph_def().node(), SizeIs(1)); + EXPECT_THAT(exported_model.graph_def().node()[0].name(), StrEq("foo")); + + EXPECT_THAT(exported_model.init_node_name(), StrEq("init_node_name")); + EXPECT_THAT(exported_model.checkpoint_dir(), StrEq("checkpoint_dir")); + EXPECT_FALSE(exported_model.has_saver_def()); + EXPECT_THAT(exported_model.function_aliases(), IsEmpty()); + EXPECT_THAT(exported_model.asset_file_defs(), IsEmpty()); +} + +TEST(CreateExportedModelTest, CreateExportedModelWithAddedFunctionAliases) { + const ExportedModel exported_model = CreateExportedModel( + GraphDef(), /*init_node_name=*/"", /*checkpoint_dir=*/"", + /*saver_def=*/std::nullopt, + /*function_aliases=*/{{"func1", "alias1"}, {"func2", "alias2"}}, + /*asset_file_defs=*/{}); + ASSERT_THAT(exported_model.function_aliases(), SizeIs(2)); + EXPECT_TRUE(exported_model.function_aliases().contains("func1")); + EXPECT_THAT(exported_model.function_aliases().at("func1"), StrEq("alias1")); + EXPECT_TRUE(exported_model.function_aliases().contains("func2")); + EXPECT_THAT(exported_model.function_aliases().at("func2"), StrEq("alias2")); +} + +TEST(CreateExportedModelTest, CreateExportedModelWithAddedAssetFileDefs) { + AssetFileDef asset1; + ASSERT_TRUE( + TextFormat::ParseFromString(R"pb(filename: "fname1")pb", &asset1)); + + AssetFileDef asset2; + ASSERT_TRUE( + TextFormat::ParseFromString(R"pb(filename: "fname2")pb", &asset2)); + + const ExportedModel exported_model = CreateExportedModel( + GraphDef(), /*init_node_name=*/"", /*checkpoint_dir=*/"", + /*saver_def=*/std::nullopt, /*function_aliases=*/{}, + /*asset_file_defs=*/{asset1, asset2}); + ASSERT_THAT(exported_model.asset_file_defs(), SizeIs(2)); + EXPECT_THAT(exported_model.asset_file_defs()[0].filename(), StrEq("fname1")); + EXPECT_THAT(exported_model.asset_file_defs()[1].filename(), StrEq("fname2")); +} + +TEST(CreateExportedModelTest, CreateExportedModelWithAddedSaverDef) { + SaverDef saver_def; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb(filename_tensor_name: "my_file")pb", &saver_def)); + + const ExportedModel exported_model = CreateExportedModel( + GraphDef(), /*init_node_name=*/"", /*checkpoint_dir=*/"", saver_def, + /*function_aliases=*/{}, /*asset_file_defs=*/{}); + EXPECT_THAT(exported_model.saver_def().filename_tensor_name(), "my_file"); +} + +} // namespace +} // namespace stablehlo::quantization diff --git 
a/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h new file mode 100644 index 00000000000000..5796b18e65d632 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_GRAPH_DEF_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_GRAPH_DEF_H_ + +#include + +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace stablehlo::quantization { + +// Mutates all `NodeDef`s in `graph_def` by applying `func`. It modifies the +// top-level `NodeDef`s as well as all `NodeDef`s in the function library. +// `func` should accept a `NodeDef` reference. +template >> +void MutateNodeDefs(tensorflow::GraphDef& graph_def, FuncT&& func) { + for (tensorflow::NodeDef& node_def : *graph_def.mutable_node()) { + func(node_def); + } + + for (tensorflow::FunctionDef& function_def : + *graph_def.mutable_library()->mutable_function()) { + for (tensorflow::NodeDef& node_def : *function_def.mutable_node_def()) { + func(node_def); + } + } +} + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_GRAPH_DEF_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def_test.cc new file mode 100644 index 00000000000000..58796acc4231bf --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h" + +#include +#include +#include "tensorflow/core/framework/node_def.pb.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +namespace stablehlo::quantization { +namespace { + +using ::tensorflow::GraphDef; +using ::tensorflow::NodeDef; +using ::testing::SizeIs; +using ::testing::StrEq; +using ::tsl::protobuf::TextFormat; + +TEST(GraphDefTest, MutateNodeDefsMutatesTopLevelNodeDefs) { + GraphDef graph_def; + ASSERT_TRUE(TextFormat::ParseFromString(R"pb( + node { name: "foo" } + )pb", + &graph_def)); + MutateNodeDefs(graph_def, + [](NodeDef& node_def) { node_def.set_name("bar"); }); + + ASSERT_THAT(graph_def.node(), SizeIs(1)); + EXPECT_THAT(graph_def.node()[0].name(), StrEq("bar")); +} + +TEST(GraphDefTest, MutateNodeDefsMutatesFunctionNodeDefs) { + GraphDef graph_def; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + library { function { node_def { name: "foo" } } } + )pb", + &graph_def)); + + MutateNodeDefs(graph_def, + [](NodeDef& node_def) { node_def.set_name("bar"); }); + + ASSERT_THAT(graph_def.library().function(), SizeIs(1)); + ASSERT_THAT(graph_def.library().function()[0].node_def(), SizeIs(1)); + EXPECT_THAT(graph_def.library().function()[0].node_def()[0].name(), + StrEq("bar")); +} + +} // namespace +} // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.cc new file mode 100644 index 00000000000000..16a1013ae25166 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.cc @@ -0,0 +1,56 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "tsl/platform/env.h" +#include "tsl/platform/statusor.h" + +namespace stablehlo::quantization::io { + +absl::StatusOr GetLocalTmpFileName(tsl::Env* const env) { + std::string tmp_fname{}; + if (!env->LocalTempFilename(&tmp_fname)) { + return absl::InternalError("Failed to create tmp file name."); + } + + return tmp_fname; +} + +absl::StatusOr GetLocalTmpFileName() { + return GetLocalTmpFileName(tsl::Env::Default()); +} + +absl::StatusOr CreateTmpDir(tsl::Env* const env) { + TF_ASSIGN_OR_RETURN(std::string tmp_dir, GetLocalTmpFileName(env)); + + if (!env->RecursivelyCreateDir(tmp_dir).ok()) { + return absl::InternalError( + absl::StrFormat("Failed to create tmp dir: '%s'", tmp_dir)); + } + + return tmp_dir; +} + +absl::StatusOr CreateTmpDir() { + // The overloaded function uses the default env. 
+ return CreateTmpDir(tsl::Env::Default()); +} + +} // namespace stablehlo::quantization::io diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h new file mode 100644 index 00000000000000..bf17ba641f9da5 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_IO_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_IO_H_ + +#include <string> + +#include "absl/status/statusor.h" +#include "tsl/platform/env.h" + +namespace stablehlo::quantization::io { + +// Generates a unique local tmp file name. This function only generates the name +// (path) and doesn't actually create the file. +absl::StatusOr<std::string> GetLocalTmpFileName(tsl::Env* env); + +// Generates a unique local tmp file name. This function only generates the name +// (path) and doesn't actually create the file. The default environment +// `tsl::Env::Default` is used to generate the name. +absl::StatusOr<std::string> GetLocalTmpFileName(); + +// Creates a temporary directory in an environment defined by the implementation +// of `tsl::Env` and returns its path. Returns an InternalError status on +// failure. +absl::StatusOr<std::string> CreateTmpDir(tsl::Env* env); + +// Creates a temporary directory and returns its path. Returns an InternalError +// status on failure. The file system used will be the default environment +// returned by `tsl::Env::Default`. +absl::StatusOr<std::string> CreateTmpDir(); + +} // namespace stablehlo::quantization::io + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_IO_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc new file mode 100644 index 00000000000000..b5cee2fc492f85 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc @@ -0,0 +1,144 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" + +#include +#include +#include + +#include +#include +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tsl/platform/env.h" +#include "tsl/platform/file_system.h" +#include "tsl/platform/status.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/types.h" + +namespace stablehlo::quantization::io { +namespace { + +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Not; +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +// A test-only derived class of `tsl::Env` which is broken. Used to cause +// failure for the `CreateTmpDir` function. Each of the overridden member +// functions implements a dummy functionality just to be able to create an +// instance of this class. +class TestEnvBrokenFileSystem : public tsl::Env { + public: + TestEnvBrokenFileSystem() = default; + + bool MatchPath(const tsl::string& path, const tsl::string& pattern) override { + return false; + } + + void SleepForMicroseconds(int64_t micros) override {} + + tsl::string GetRunfilesDir() override { return tsl::string("dummy_path"); } + + int32_t GetCurrentThreadId() override { return 0; } + + tsl::Thread* StartThread(const tsl::ThreadOptions& thread_options, + const tsl::string& name, + absl::AnyInvocable fn) override { + return nullptr; + } + + bool GetCurrentThreadName(tsl::string* name) override { return false; } + + void SchedClosure(absl::AnyInvocable closure) override {} + + void SchedClosureAfter(int64_t micros, + absl::AnyInvocable closure) override {} + + absl::Status LoadDynamicLibrary(const char* library_filename, + void** handle) override { + return tsl::OkStatus(); + } + + absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) override { + return tsl::OkStatus(); + } + + tsl::string FormatLibraryFileName(const tsl::string& name, + const tsl::string& version) override { + return tsl::string("dummy_path"); + } + + // This is the part that would break the `CreateTmpDir` function because it + // fails to provide a valid file system. + absl::Status GetFileSystemForFile(const std::string& fname, + tsl::FileSystem** result) override { + return absl::InternalError("Broken file system"); + } + + private: + void GetLocalTempDirectories(std::vector* list) override { + list->push_back("/tmp"); + } +}; + +// Represents an environment with broken file system and no available local tmp +// directories. +class TestEnvBrokenFileSystemAndNoLocalTempDirs + : public TestEnvBrokenFileSystem { + private: + // This is the part that essentially breaks the `GetLocalTmpFileName` function + // because it doesn't provide any available temp dirs. 
+ void GetLocalTempDirectories(std::vector* list) override {} +}; + +TEST(IoTest, GetLocalTmpFileNameGivesValidFileName) { + absl::StatusOr tmp_file_name = GetLocalTmpFileName(); + + ASSERT_THAT(tmp_file_name, IsOk()); + EXPECT_THAT(*tmp_file_name, Not(IsEmpty())); +} + +TEST(IoTest, GetLocalTmpFileNameWhenNoTempDirsReturnsInternalError) { + TestEnvBrokenFileSystemAndNoLocalTempDirs broken_env; + absl::StatusOr tmp_file_name = GetLocalTmpFileName(&broken_env); + + EXPECT_THAT(tmp_file_name, + StatusIs(absl::StatusCode::kInternal, + HasSubstr("Failed to create tmp file name"))); +} + +TEST(IoTest, CreateTmpDirReturnsValidTmpPath) { + absl::StatusOr tmp_dir = CreateTmpDir(); + + ASSERT_THAT(tmp_dir, IsOk()); + + auto* const env = tsl::Env::Default(); + EXPECT_THAT(env->FileExists(*tmp_dir), IsOk()); +} + +TEST(IoTest, CreateTmpDirWhenInvalidPathReturnsInternalError) { + TestEnvBrokenFileSystem test_env{}; + absl::StatusOr tmp_dir = CreateTmpDir(&test_env); + + EXPECT_THAT(tmp_dir, StatusIs(absl::StatusCode::kInternal, + HasSubstr("Failed to create tmp dir"))); +} + +} // namespace +} // namespace stablehlo::quantization::io diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/precalibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/precalibration.cc new file mode 100644 index 00000000000000..7bc1233bf35e04 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/precalibration.cc @@ -0,0 +1,52 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/precalibration.h" + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h" +#include "tsl/platform/errors.h" + +namespace mlir::quant::stablehlo { +namespace { + +using ::stablehlo::quantization::QuantizationConfig; +using ::tensorflow::quantization::RunPasses; + +// Name of the post-training quantization pre-calibration step. Used for +// debugging purposes. 
+constexpr absl::string_view kQuantPtqPreCalibrationStepName = + "quant_ptq_pre_calibration"; + +} // namespace + +absl::StatusOr PreCalibrationComponent::Run( + ModuleOp module_op, const QuantizationConfig& config) { + TF_RETURN_IF_ERROR(RunPasses( + /*name=*/kQuantPtqPreCalibrationStepName, + /*add_passes_func=*/ + [this](mlir::PassManager& pm) { + AddQuantizePtqPreCalibrationStablehloPasses(pm, calibration_options_); + }, + ctx_, module_op)); + return module_op; +} + +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/precalibration.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/precalibration.h new file mode 100644 index 00000000000000..8a0d90935825df --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/precalibration.h @@ -0,0 +1,57 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PRECALIBRATION_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PRECALIBRATION_H_ + +#include + +#include "absl/log/die_if_null.h" +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir::quant::stablehlo { + +// Performs pre-calibration graph transformation as part of post-training +// static-range quantization. + +// The resulting `ModuleOp` contains `TF::CustomAggregatorOp`s for collecting +// quantization statistics, along with `TF::XlaCallModuleOp`s that correspond to +// lifted quantizable functions. +class PreCalibrationComponent : public Component { + public: + PreCalibrationComponent( + MLIRContext* ctx, + tensorflow::quantization::CalibrationOptions calibration_options) + : ctx_(*ABSL_DIE_IF_NULL(ctx)), // Crash OK + calibration_options_(std::move(calibration_options)) {} + + absl::StatusOr Run( + ModuleOp, + const ::stablehlo::quantization::QuantizationConfig& config) override; + + private: + MLIRContext& ctx_; + // TODO: b/315747711 - Allow `QuantizationConfig` to express calibration + // options and remove this field. 
+ tensorflow::quantization::CalibrationOptions calibration_options_; +}; + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PRECALIBRATION_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/precalibration_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/precalibration_test.cc new file mode 100644 index 00000000000000..7a0440d9c461d2 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/precalibration_test.cc @@ -0,0 +1,118 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/precalibration.h" + +#include + +#include +#include +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/test_base.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tsl/platform/status_matchers.h" + +namespace mlir::quant::stablehlo { +namespace { + +using ::mlir::quant::common::QuantizationTestBase; +using ::stablehlo::quantization::QuantizationConfig; +using ::tensorflow::quantization::CalibrationOptions; +using ::testing::Contains; +using ::testing::SizeIs; +using ::testing::StartsWith; +using ::testing::StrEq; +using ::tsl::testing::IsOk; + +// Matches an operation whose `getSymName` equals `name`. +MATCHER_P(HasSymName, name, "") { + auto non_const_arg = const_cast>(arg); + *result_listener << "where the name is " << non_const_arg.getSymName().str(); + return non_const_arg.getSymName() == name; +} + +// Matches an operation that has a StringAttr whose name is `name` and value +// matches `value_matcher`. +MATCHER_P2(HasStringAttr, name, value_matcher, + absl::StrCat(negation ? "doesn't have" : "has", + "string attribute: ", name, ", with desirable value")) { + auto non_const_arg = const_cast>(arg); + return non_const_arg->template hasAttrOfType(name) && + ExplainMatchResult( + value_matcher, + non_const_arg->template getAttrOfType(name).str(), + result_listener); +} + +// TODO: b/315746734 - Use test-only passes for in-depth and easier testing. 
+class PreCalibrationComponentTest : public QuantizationTestBase {}; + +TEST_F(PreCalibrationComponentTest, + HasCustomAggregatorOpAndQuantizableFuncForSimpleDotGeneral) { + PreCalibrationComponent component(&ctx_, CalibrationOptions()); + OwningOpRef module_op = ParseModuleOpString(R"mlir( + module attributes {} { + func.func @main(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> attributes {} { + %0 = stablehlo.constant dense<1.0> : tensor<4x3xf32> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + } + )mlir"); + + absl::StatusOr pre_calibration_result = + component.Run(*module_op, QuantizationConfig()); + + EXPECT_THAT(pre_calibration_result, IsOk()); + + SmallVector func_ops; + for (auto func_op : pre_calibration_result->getOps()) { + func_ops.push_back(func_op); + } + ASSERT_THAT(func_ops, SizeIs(1)); + EXPECT_THAT(func_ops, Contains(HasSymName("main"))); + + // Tests that there is a XlaCallModuleOp that is a serialized quantizable + // function. + SmallVector xla_call_module_ops; + for (auto xla_call_module_op : func_ops[0].getOps()) { + xla_call_module_ops.push_back(xla_call_module_op); + } + ASSERT_THAT(xla_call_module_ops, SizeIs(2)); + EXPECT_THAT( + xla_call_module_ops, + Contains(HasStringAttr("_tfl_quant_trait", StrEq("fully_quantizable")))); + EXPECT_THAT(xla_call_module_ops, + Contains(HasStringAttr("_original_entry_function", + StartsWith("composite_dot_general_fn")))); + + // Tests that there are CustomAggregatorOps inserted. + SmallVector custom_aggregator_ops; + for (auto custom_aggregator_op : + func_ops[0].getOps()) { + custom_aggregator_ops.push_back(custom_aggregator_op); + } + EXPECT_THAT(custom_aggregator_ops, SizeIs(2)); +} + +} // namespace +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD new file mode 100644 index 00000000000000..d3bf62dfce4923 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD @@ -0,0 +1,29 @@ +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist_package", + ], + licenses = ["notice"], +) + +cc_library( + name = "stablehlo_op_quant_spec", + srcs = [ + "stablehlo_op_quant_spec.cc", + ], + hdrs = ["stablehlo_op_quant_spec.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc new file mode 100644 index 00000000000000..1a20e3d6d995f8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc @@ -0,0 +1,104 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" + +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::quant::stablehlo { + +std::unique_ptr GetStableHloOpQuantSpec(Operation* op) { + auto spec = std::make_unique(); + if (auto call_op = dyn_cast_or_null(op)) { + auto entry_function = + call_op->getAttrOfType("_entry_function"); + StringRef function_name = entry_function.getValue(); + if (!function_name.startswith("composite_")) { + return spec; + } + if (function_name.contains("conv")) { + spec->coeff_op_quant_dim[1] = 3; + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, + quant::GetUniformQuantizedTypeForBias}; + } + } else if (function_name.contains("dot_general")) { + spec->coeff_op_quant_dim[1] = -1; + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, + quant::GetUniformQuantizedTypeForBias}; + } + } else if (function_name.contains("dot")) { + spec->coeff_op_quant_dim[1] = -1; + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, + quant::GetUniformQuantizedTypeForBias}; + } + } + for (auto quantizable_operand : spec->coeff_op_quant_dim) { + spec->quantizable_operands.insert(quantizable_operand.first); + } + } + return spec; +} + +std::unique_ptr GetStableHloQuantScaleSpec(Operation* op) { + auto scale_spec = std::make_unique(); + if (llvm::isa(op)) { + scale_spec->has_same_scale_requirement = true; + } + return scale_spec; +} + +bool IsOpQuantizableStableHlo(Operation* op) { + if (mlir::isa(op)) { + // Constant ops do not have QuantizableResult attribute but can be + // quantized. + return true; + } else if (op->hasTrait() || + isa(op)) { + // Terminators, qcast and decast are not quantizable. 
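`GetStableHloOpQuantSpec` above dispatches purely on the name stored in the lifted call's `_entry_function` attribute. The sketch below works two names through those branches; the `_fn_1` suffix, the `call_op` variable, and the spec field spellings are assumptions for illustration, only the substring checks come from the code above.

// Sketch of the resulting spec for two hypothetical entry-function names;
// `call_op` stands for a lifted call op obtained elsewhere.
auto spec = GetStableHloOpQuantSpec(call_op);
//
// _entry_function = "composite_conv_with_bias_fn_1":
//   spec->coeff_op_quant_dim == {{1, 3}}   // weight operand 1, channel dim 3
//   spec->biases_params.count(2) == 1      // operand 2 handled as a bias
//   spec->quantizable_operands == {1}      // the keys of coeff_op_quant_dim
//
// _entry_function = "composite_dot_general_fn_1":
//   spec->coeff_op_quant_dim == {{1, -1}}  // -1 means per-tensor weight
//   spec->biases_params.empty()            // no "with_bias" in the name
//   spec->quantizable_operands == {1}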
+ return false; + } + + if (GetStableHloQuantScaleSpec(op)->has_same_scale_requirement) { + return true; + } + + const bool attr_enforced_quantizable = + op->hasAttrOfType(kQuantTraitAttrName) && + op->getAttrOfType(kQuantTraitAttrName).getValue().str() == + QuantTraitValues[QuantizationTrait::FullyQuantizable]; + return attr_enforced_quantizable; +} + +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h new file mode 100644 index 00000000000000..c898a99c08f68f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h @@ -0,0 +1,41 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_STABLEHLO_OP_QUANT_SPEC_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_STABLEHLO_OP_QUANT_SPEC_H_ + +#include + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir::quant::stablehlo { + +// Returns StableHLO quantization specs for an op. +std::unique_ptr GetStableHloOpQuantSpec(Operation* op); + +// Returns quantization scale specs (fixed output, same scale) for a StableHLO +// op. +std::unique_ptr GetStableHloQuantScaleSpec(Operation* op); + +// Checks if an op is quantizable in StableHLO quantizer. Argument op is not +// necessarily a StableHLO op. +bool IsOpQuantizableStableHlo(Operation* op); + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_STABLEHLO_OP_QUANT_SPEC_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc index 2ff1ba9200261d..e51f55d14aeecf 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc @@ -53,78 +53,8 @@ limitations under the License. namespace mlir::quant::stablehlo { namespace { -// This helper function create ops to requantize `input` tensor and returns the -// output tensor. Clamping is done if output integer bit-width < 32. -// -// Requantization is essentially dequantize --> quantize. -// -// Dequantize: (input - zp) * scale -// Quantize: input / scale + zp -// -// Hence, -// output = (input - input_zp) * input_scale / output_scale + output_zp -// -// This is simplified as: -// output = input * merged_scale + merged_zp -// where: -// merged_zp = output_zp - input_zp * merged_scale. 
-// merged_scale = input_scale / output_scale. -Value Requantize(mlir::OpState op, Value input, - UniformQuantizedType input_quantized_type, - UniformQuantizedType output_quantized_type, - TensorType output_tensor_type, - ConversionPatternRewriter &rewriter) { - // Skip requantization when input and result have the same type. - if (input_quantized_type == output_quantized_type) { - return rewriter.create(op->getLoc(), output_tensor_type, - input); - } - - double merged_scale_fp = - input_quantized_type.getScale() / output_quantized_type.getScale(); - Value merged_scale = rewriter.create( - op->getLoc(), - rewriter.getF32FloatAttr(static_cast(merged_scale_fp))); - - auto float_tensor_type = - input.getType().cast().clone(rewriter.getF32Type()); - Value output_float = - rewriter.create(op->getLoc(), float_tensor_type, input); - - output_float = rewriter.create( - op->getLoc(), float_tensor_type, output_float, merged_scale, nullptr); - - // Add merged_zp only when it is non-zero. - double merged_zp_fp = output_quantized_type.getZeroPoint() - - input_quantized_type.getZeroPoint() * merged_scale_fp; - if (merged_zp_fp != 0) { - Value merged_zp = rewriter.create( - op->getLoc(), - rewriter.getF32FloatAttr(static_cast(merged_zp_fp))); - output_float = rewriter.create( - op->getLoc(), float_tensor_type, output_float, merged_zp, nullptr); - } - - // Clamp output if the output integer bit-width <32. - if (output_tensor_type.getElementType().cast().getWidth() < 32) { - Value quantization_min = rewriter.create( - op->getLoc(), rewriter.getF32FloatAttr(static_cast( - output_quantized_type.getStorageTypeMin()))); - Value quantization_max = rewriter.create( - op->getLoc(), rewriter.getF32FloatAttr(static_cast( - output_quantized_type.getStorageTypeMax()))); - // Clamp results by [quantization_min, quantization_max]. - output_float = rewriter.create( - op->getLoc(), float_tensor_type, quantization_min, output_float, - quantization_max); - } - - output_float = rewriter.create( - op->getLoc(), float_tensor_type, output_float); - return rewriter.create(op->getLoc(), output_tensor_type, - output_float); -} - +// TODO: b/311218165 - consider extract this to common utils and better ways to +// handle polymorphism. using QuantType = std::variant; FailureOr GetQuantType(Type type) { @@ -139,6 +69,22 @@ FailureOr GetQuantType(Type type) { } } +bool IsPerTensorType(QuantType quant_type) { + return std::holds_alternative(quant_type); +} + +bool IsPerChannelType(QuantType quant_type) { + return std::holds_alternative(quant_type); +} + +UniformQuantizedType GetPerTensorType(QuantType quant_type) { + return std::get(quant_type); +} + +UniformQuantizedPerAxisType GetPerChannelType(QuantType quant_type) { + return std::get(quant_type); +} + // Extract scale and zero point info from input quant type info. 
void GetQuantizationParams(OpBuilder &builder, Location loc, QuantType quant_type, Value &scales, @@ -161,7 +107,7 @@ void GetQuantizationParams(OpBuilder &builder, Location loc, } else { auto &quant_per_channel_type = std::get(quant_type); - llvm::SmallVector scales_vec; + SmallVector scales_vec; for (auto scale : quant_per_channel_type.getScales()) scales_vec.push_back(scale); scales = builder.create( @@ -172,7 +118,7 @@ void GetQuantizationParams(OpBuilder &builder, Location loc, builder.getF32Type()), scales_vec)); if (output_zero_point_in_fp) { - llvm::SmallVector zero_points_vec; + SmallVector zero_points_vec; for (auto zero_point : quant_per_channel_type.getZeroPoints()) zero_points_vec.push_back(zero_point); zero_points = builder.create( @@ -183,7 +129,7 @@ void GetQuantizationParams(OpBuilder &builder, Location loc, builder.getF32Type()), zero_points_vec)); } else { - llvm::SmallVector zero_points_vec; + SmallVector zero_points_vec; for (auto zero_point : quant_per_channel_type.getZeroPoints()) zero_points_vec.push_back(zero_point); zero_points = builder.create( @@ -241,6 +187,147 @@ Type GetQuantStorageType(Type type) { } } +Type GetQuantStorageType(QuantType type) { + if (IsPerTensorType(type)) { + return GetPerTensorType(type).getStorageType(); + } else { + return GetPerChannelType(type).getStorageType(); + } +} + +Value ApplyMergedScalesAndZps(OpBuilder &builder, Location loc, + QuantType input_quant_type, + QuantType output_quant_type, + Value input_float_tensor) { + // Use single merged scale and merged zp if both input and output are + // per-tensor quantized. Otherwise use a vector. + if (IsPerTensorType(input_quant_type) && IsPerTensorType(output_quant_type)) { + UniformQuantizedType input_per_tensor_tyep = + GetPerTensorType(input_quant_type); + UniformQuantizedType output_per_tensor_tyep = + GetPerTensorType(output_quant_type); + double merged_scale_fp = + input_per_tensor_tyep.getScale() / output_per_tensor_tyep.getScale(); + auto merged_scale = builder.create( + loc, builder.getF32FloatAttr(static_cast(merged_scale_fp))); + input_float_tensor = builder.create( + loc, input_float_tensor, merged_scale, + /*broadcast_dimensions=*/nullptr); + // Add merged_zp only when it is non-zero. + double merged_zp_fp = + output_per_tensor_tyep.getZeroPoint() - + input_per_tensor_tyep.getZeroPoint() * merged_scale_fp; + if (merged_zp_fp != 0) { + Value merged_zp = builder.create( + loc, builder.getF32FloatAttr(static_cast(merged_zp_fp))); + input_float_tensor = builder.create( + loc, input_float_tensor, merged_zp, /*broadcast_dimensions=*/nullptr); + } + } else { + int64_t channel_size = + IsPerChannelType(output_quant_type) + ? GetPerChannelType(output_quant_type).getScales().size() + : GetPerChannelType(input_quant_type).getScales().size(); + int64_t quantized_dimension = + IsPerChannelType(output_quant_type) + ? GetPerChannelType(output_quant_type).getQuantizedDimension() + : GetPerChannelType(input_quant_type).getQuantizedDimension(); + SmallVector merged_scale_double, merged_zp_double; + merged_scale_double.resize(channel_size); + merged_zp_double.resize(channel_size); + for (int i = 0; i < channel_size; ++i) { + merged_scale_double[i] = + (IsPerChannelType(input_quant_type) + ? GetPerChannelType(input_quant_type).getScales()[i] + : GetPerTensorType(input_quant_type).getScale()) / + (IsPerChannelType(output_quant_type) + ? 
GetPerChannelType(output_quant_type).getScales()[i] + : GetPerTensorType(output_quant_type).getScale()); + merged_zp_double[i] = + (IsPerChannelType(output_quant_type) + ? GetPerChannelType(output_quant_type).getZeroPoints()[i] + : GetPerTensorType(output_quant_type).getZeroPoint()) - + (IsPerChannelType(input_quant_type) + ? GetPerChannelType(input_quant_type).getZeroPoints()[i] + : GetPerTensorType(input_quant_type).getZeroPoint()) * + merged_scale_double[i]; + } + SmallVector merged_scale_float(merged_scale_double.begin(), + merged_scale_double.end()), + merged_zp_float(merged_zp_double.begin(), merged_zp_double.end()); + + auto broadcast_dims = DenseIntElementsAttr::get( + RankedTensorType::get({1}, builder.getI64Type()), + {quantized_dimension}); + Value merged_scale = builder.create( + loc, DenseFPElementsAttr::get( + RankedTensorType::get({channel_size}, builder.getF32Type()), + merged_scale_float)); + input_float_tensor = builder.create( + loc, input_float_tensor, merged_scale, broadcast_dims); + if (llvm::any_of(merged_zp_float, [](double zp) { return zp != 0; })) { + Value merged_zp = builder.create( + loc, DenseFPElementsAttr::get( + RankedTensorType::get({channel_size}, builder.getF32Type()), + merged_zp_float)); + input_float_tensor = builder.create( + loc, input_float_tensor, merged_zp, broadcast_dims); + } + } + return input_float_tensor; +} + +// This helper function create ops to requantize `input` tensor and returns the +// output tensor. Clamping is done if output integer bit-width < i32. It assumes +// that if both input and output tensor are per-channel quantized, they have the +// same quantization axis. +// +// Requantization is essentially dequantize --> quantize. +// +// Dequantize: (input - zp) * scale +// Quantize: input / scale + zp +// +// Hence, +// output = (input - input_zp) * input_scale / output_scale + output_zp +// +// This is simplified as: +// output = input * merged_scale + merged_zp +// where: +// merged_zp = output_zp - input_zp * merged_scale. +// merged_scale = input_scale / output_scale. +Value Requantize(mlir::OpState op, Value input, QuantType input_quant_type, + QuantType output_quant_type, TensorType output_tensor_type, + ConversionPatternRewriter &rewriter) { + // Skip requantization when input and result have the same type. + if (input_quant_type == output_quant_type) { + return rewriter.create(op->getLoc(), output_tensor_type, + input); + } + + auto float_tensor_type = output_tensor_type.clone(rewriter.getF32Type()); + Value output_float = + rewriter.create(op->getLoc(), float_tensor_type, input); + + output_float = + ApplyMergedScalesAndZps(rewriter, op->getLoc(), input_quant_type, + output_quant_type, output_float); + + // Clamp output if the output integer bit-width <32. + if (output_tensor_type.getElementType().cast().getWidth() < 32) { + Value quantization_min, quantization_max; + GetQuantizationStorageInfo(rewriter, op->getLoc(), output_quant_type, + quantization_min, quantization_max); + // Clamp results by [quantization_min, quantization_max]. 
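To make the merged-scale algebra in the requantization comment above concrete, here is a small standalone check; the scales and zero points are illustrative values only.

#include <cassert>

int main() {
  // Input quantization:  scale = 0.5,  zero point = 10.
  // Output quantization: scale = 0.25, zero point = 4.
  const double input_scale = 0.5, output_scale = 0.25;
  const int input_zp = 10, output_zp = 4;

  const double merged_scale = input_scale / output_scale;        // 2.0
  const double merged_zp = output_zp - input_zp * merged_scale;  // -16.0

  // A stored value of 30 represents (30 - 10) * 0.5 = 10.0 in real terms.
  const int q_in = 30;
  const double q_out = q_in * merged_scale + merged_zp;          // 44.0

  // The requantized value encodes the same real number under the output
  // parameters: (44 - 4) * 0.25 == (30 - 10) * 0.5 == 10.0.
  assert((q_out - output_zp) * output_scale ==
         (q_in - input_zp) * input_scale);
  return 0;
}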
+ output_float = rewriter.create( + op->getLoc(), quantization_min, output_float, quantization_max); + } + + output_float = rewriter.create( + op->getLoc(), float_tensor_type, output_float); + return rewriter.create(op->getLoc(), output_tensor_type, + output_float); +} + class ConvertUniformQuantizeOp : public OpConversionPattern { public: @@ -255,10 +342,24 @@ class ConvertUniformQuantizeOp if (succeeded(quant_type)) { return matchAndRewriteQuantize(op, adaptor, rewriter, *quant_type); } - } else if (input_element_type.isa()) { - return matchAndRewriteRequantize(op, adaptor, rewriter); + } else if (input_element_type.isa()) { + auto input_quant_type = GetQuantType(input_element_type); + auto output_quant_type = GetQuantType(op.getResult().getType()); + if (succeeded(input_quant_type) && succeeded(output_quant_type)) { + if (IsPerChannelType(*input_quant_type) && + IsPerChannelType(*output_quant_type) && + GetPerChannelType(*input_quant_type).getQuantizedDimension() != + GetPerChannelType(*output_quant_type).getQuantizedDimension()) { + op->emitError("Cannot requantize while changing quantization_axis"); + return failure(); + } + return matchAndRewriteRequantize(op, adaptor, rewriter, + *input_quant_type, *output_quant_type); + } } - return rewriter.notifyMatchFailure(op, "Unsupported input element type."); + op->emitError("Unsupported input element type."); + return failure(); } LogicalResult matchAndRewriteQuantize(mhlo::UniformQuantizeOp op, @@ -298,16 +399,14 @@ class ConvertUniformQuantizeOp LogicalResult matchAndRewriteRequantize( mhlo::UniformQuantizeOp op, mhlo::UniformQuantizeOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto input_quantized_type = getElementTypeOrSelf(op.getOperand().getType()) - .cast(); - auto output_quantized_type = getElementTypeOrSelf(op.getResult().getType()) - .cast(); + ConversionPatternRewriter &rewriter, QuantType input_quant_type, + QuantType output_quant_type) const { rewriter.replaceOp( - op, Requantize(op, adaptor.getOperand(), input_quantized_type, - output_quantized_type, + op, Requantize(op, adaptor.getOperand(), input_quant_type, + output_quant_type, + /*output_tensor_type=*/ op.getResult().getType().cast().clone( - output_quantized_type.getStorageType()), + GetQuantStorageType(output_quant_type)), rewriter)); return success(); } @@ -357,18 +456,18 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::AddOp op, mhlo::AddOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto lhs_element_type = - op.getLhs().getType().getElementType().dyn_cast(); - auto rhs_element_type = - op.getRhs().getType().getElementType().dyn_cast(); - auto result_element_type = op.getResult() - .getType() - .getElementType() - .dyn_cast(); + auto lhs_quant_type = + GetQuantType(getElementTypeOrSelf(op.getLhs().getType())); + auto rhs_quant_type = + GetQuantType(getElementTypeOrSelf(op.getRhs().getType())); + auto res_quant_type = + GetQuantType(getElementTypeOrSelf(op.getResult().getType())); // We only handle cases where lhs, rhs and results all have quantized // element type. 
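In the add lowering below, both operands are first requantized to the result's scale and zero point, the integer values are summed, and the result zero point is subtracted exactly once. The identity behind that single subtraction is (lhs_q - z)*s + (rhs_q - z)*s = ((lhs_q + rhs_q - z) - z)*s; a quick standalone check with illustrative numbers:

#include <cassert>

int main() {
  // Result quantization parameters (illustrative): scale 0.25, zero point 5.
  const double s = 0.25;
  const int z = 5;
  // Two operands already expressed with the result's scale and zero point.
  const int lhs_q = 13, rhs_q = 21;     // real values 2.0 and 4.0

  const int res_q = lhs_q + rhs_q - z;  // subtract the zero point once: 29
  // 29 encodes (29 - 5) * 0.25 = 6.0, which is 2.0 + 4.0.
  assert((res_q - z) * s == (lhs_q - z) * s + (rhs_q - z) * s);
  return 0;
}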
- if (!lhs_element_type || !rhs_element_type || !result_element_type) { + if (failed(lhs_quant_type) || IsPerChannelType(*lhs_quant_type) || + failed(rhs_quant_type) || IsPerChannelType(*rhs_quant_type) || + failed(res_quant_type) || IsPerChannelType(*res_quant_type)) { op->emitError( "AddOp requires the same quantized element type for all operands and " "results"); @@ -384,17 +483,17 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { // TODO: b/260280919 - Consider avoiding conversion to int32. Value lhs = adaptor.getLhs(); Value lhs_int32_tensor = - Requantize(op, lhs, lhs_element_type, result_element_type, + Requantize(op, lhs, *lhs_quant_type, *res_quant_type, res_int32_tensor_type, rewriter); Value rhs = adaptor.getRhs(); Value rhs_int32_tensor = - Requantize(op, rhs, rhs_element_type, result_element_type, + Requantize(op, rhs, *rhs_quant_type, *res_quant_type, res_int32_tensor_type, rewriter); Value zero_point = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(static_cast( - result_element_type.getZeroPoint()))); + GetPerTensorType(*res_quant_type).getZeroPoint()))); // Now the lhs and rhs have been coverted to the same scale and zps. // Given: @@ -411,24 +510,26 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { Value res_int32 = rewriter.create( op->getLoc(), res_int32_tensor_type, add_result, zero_point, nullptr); - if (result_element_type.getStorageType().isInteger(32)) { + if (GetQuantStorageType(*res_quant_type).isInteger(32)) { // For i32, clamping is not needed. rewriter.replaceOp(op, res_int32); } else { // Clamp results by [quantization_min, quantization_max] when storage type // is not i32. Value result_quantization_min = rewriter.create( - op->getLoc(), rewriter.getI32IntegerAttr(static_cast( - result_element_type.getStorageTypeMin()))); + op->getLoc(), + rewriter.getI32IntegerAttr(static_cast( + GetPerTensorType(*res_quant_type).getStorageTypeMin()))); Value result_quantization_max = rewriter.create( - op->getLoc(), rewriter.getI32IntegerAttr(static_cast( - result_element_type.getStorageTypeMax()))); + op->getLoc(), + rewriter.getI32IntegerAttr(static_cast( + GetPerTensorType(*res_quant_type).getStorageTypeMax()))); res_int32 = rewriter.create( op->getLoc(), res_int32_tensor_type, result_quantization_min, res_int32, result_quantization_max); // Convert results back to result storage type. auto res_final_tensor_type = - res_int32_tensor_type.clone(result_element_type.getStorageType()); + res_int32_tensor_type.clone(GetQuantStorageType(*res_quant_type)); rewriter.replaceOpWithNewOp(op, res_final_tensor_type, res_int32); } @@ -445,12 +546,12 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { // dimensions are defined in // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general. 
struct DotLikeDimensionNumbers { - ArrayRef lhs_batching_dims; - ArrayRef lhs_spatial_dims; - ArrayRef lhs_contracting_dims; - ArrayRef rhs_batching_dims; - ArrayRef rhs_spatial_dims; - ArrayRef rhs_contracting_dims; + SmallVector lhs_batching_dims; + SmallVector lhs_spatial_dims; + SmallVector lhs_contracting_dims; + SmallVector rhs_batching_dims; + SmallVector rhs_spatial_dims; + SmallVector rhs_contracting_dims; }; // A shared matchAndRewrite implementation for dot-like hybrid quantized @@ -503,7 +604,7 @@ LogicalResult matchAndRewriteDotLikeHybridOp( Value CreateZeroPointPartialOffset(OpBuilder &builder, Location loc, Value tensor, const int64_t other_tensor_zp, - ArrayRef reduction_dims) { + SmallVector reduction_dims) { // This function calculates part of the zero-point-offset by using // mhlo::Reduce to sum over the contracting dims of the tensor, and then // multiply by zp of the other tensor. @@ -512,7 +613,7 @@ Value CreateZeroPointPartialOffset(OpBuilder &builder, Location loc, // Calculate the output tensor shape. This is input tensor dims minus // contracting dims. auto ranked_tensor = tensor.getType().cast(); - llvm::SmallVector output_dims; + SmallVector output_dims; for (int64_t i = 0; i < ranked_tensor.getRank(); ++i) { if (absl::c_count(reduction_dims, i) == 0) { output_dims.push_back(ranked_tensor.getDimSize(i)); @@ -581,7 +682,7 @@ Value CalculateDynamicOutputDims(OpBuilder &builder, Location loc, Value lhs, // Calculate each output dim and concatenate into a 1D tensor. // Output dims are batching dims, spatial dims, LHS result dims, RHS result // dims. - llvm::SmallVector output_dims; + SmallVector output_dims; for (int64_t i = 0; i < lhs_shape.getRank(); ++i) { if (absl::c_count(dims.lhs_batching_dims, i) != 0) { output_dims.push_back(GetDimValue(builder, loc, lhs, lhs_shape, i)); @@ -612,8 +713,8 @@ Value CalculateDynamicOutputDims(OpBuilder &builder, Location loc, Value lhs, Value BroadcastZpContribution(OpBuilder &builder, Location loc, Value zp_contribution, - llvm::ArrayRef reduction_dims, - llvm::ArrayRef batching_dims, + ArrayRef reduction_dims, + ArrayRef batching_dims, int64_t non_batching_starting_idx, TensorType output_tensor_type, Value &output_dims_value, Value lhs, Value rhs, @@ -623,7 +724,7 @@ Value BroadcastZpContribution(OpBuilder &builder, Location loc, // broadcast. auto zp_contribution_rank = zp_contribution.getType().cast().getRank(); - llvm::SmallVector broadcast_dims; + SmallVector broadcast_dims; broadcast_dims.resize(zp_contribution_rank, 0); // Result tensor will have batching dims first, then LHS result dims, then // RHS result dims. So non-batching result dims index doesn't start from 0. @@ -643,9 +744,9 @@ Value BroadcastZpContribution(OpBuilder &builder, Location loc, broadcast_dims[idx] = result_batching_idx++; } } - // Use broadcast_in_dim or dyanmic_broadcast_in_dim based on input shape + // Use broadcast_in_dim or dyanmic_broadcast_in_dim based on output shape // dynamism. - if (zp_contribution.getType().cast().hasStaticShape()) { + if (output_tensor_type.cast().hasStaticShape()) { zp_contribution = builder.create( loc, output_tensor_type, zp_contribution, DenseIntElementsAttr::get( @@ -677,9 +778,8 @@ Value CalculateZeroPointOffset(OpBuilder &builder, Location loc, Value lhs, Value output_dims_value = nullptr; // Calculate LHS contribution when RHS zp is non-zero. 
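The zero-point offset helpers around `CreateZeroPointPartialOffset` implement the standard expansion of a quantized dot product: over a contracting dimension of size k, sum((lhs_q - lhs_zp) * (rhs_q - rhs_zp)) = sum(lhs_q * rhs_q) - rhs_zp * sum(lhs_q) - lhs_zp * sum(rhs_q) + k * lhs_zp * rhs_zp, so the kernel can run on raw integers and the correction terms are added back afterwards (the constant k term is handled outside the hunks shown here). A tiny standalone check of the identity with illustrative values:

#include <cassert>

int main() {
  // One contracting dimension of size k = 3, illustrative quantized values.
  const int lhs_q[3] = {7, 2, 9}, rhs_q[3] = {4, 6, 1};
  const int lhs_zp = 3, rhs_zp = 2, k = 3;

  int exact = 0, raw = 0, lhs_sum = 0, rhs_sum = 0;
  for (int i = 0; i < k; ++i) {
    exact += (lhs_q[i] - lhs_zp) * (rhs_q[i] - rhs_zp);
    raw += lhs_q[i] * rhs_q[i];
    lhs_sum += lhs_q[i];
    rhs_sum += rhs_q[i];
  }
  // The raw integer dot minus the zero-point corrections reproduces the
  // exact zero-point-adjusted result.
  assert(exact ==
         raw - rhs_zp * lhs_sum - lhs_zp * rhs_sum + k * lhs_zp * rhs_zp);
  return 0;
}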
if (rhs_zp != 0) { - llvm::SmallVector reduction_dims = - llvm::to_vector(llvm::concat(dims.lhs_spatial_dims, - dims.lhs_contracting_dims)); + SmallVector reduction_dims = to_vector(llvm::concat( + dims.lhs_spatial_dims, dims.lhs_contracting_dims)); Value lhs_zp_contribution = CreateZeroPointPartialOffset(builder, loc, lhs, rhs_zp, reduction_dims); // Broadcast lhs ZP contribution to result tensor shape. @@ -691,9 +791,8 @@ Value CalculateZeroPointOffset(OpBuilder &builder, Location loc, Value lhs, } // Calculate RHS contribution when LHS zp is non-zero. if (lhs_zp != 0) { - llvm::SmallVector reduction_dims = - llvm::to_vector(llvm::concat(dims.rhs_spatial_dims, - dims.rhs_contracting_dims)); + SmallVector reduction_dims = to_vector(llvm::concat( + dims.rhs_spatial_dims, dims.rhs_contracting_dims)); Value rhs_zp_contribution = CreateZeroPointPartialOffset(builder, loc, rhs, lhs_zp, reduction_dims); // Broadcast rhs ZP contribution to result tensor shape. @@ -762,11 +861,13 @@ Value CreateDotLikeKernel(OpBuilder &builder, Location loc, Value &rhs, ArrayRef attrs) { // We only handle the case where RHS zp is zero. - auto original_padding = op.getPaddingAttr().getValues(); - // Explicitly pad LHS with zp and update LHS value. - llvm::SmallVector new_attrs(attrs); - if (llvm::any_of(original_padding, [](int64_t x) { return x != 0; })) { + SmallVector new_attrs(attrs); + if (op.getPadding().has_value() && + llvm::any_of(op.getPaddingAttr().getValues(), + [](int64_t x) { return x != 0; })) { + auto original_padding = op.getPaddingAttr().getValues(); + Value zp = builder.create( loc, DenseIntElementsAttr::get( @@ -779,7 +880,7 @@ Value CreateDotLikeKernel(OpBuilder &builder, Location loc, // mhlo::Convolution. But mhlo::Pad require those for all dimensions. Hence // we add 0 to the beginning and end of the padding vectors. 
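As the comment above notes, `mhlo.convolution` padding covers only the spatial dimensions while `mhlo.pad` needs entries for every dimension, so the loop below prepends and appends zeros. A standalone check of that index arithmetic for a rank-4 NHWC input with spatial padding [[1, 1], [2, 2]]; the NHWC layout and the assumption that the high side takes the odd entries are illustrative, only the low-side indexing is visible in this hunk.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Flattened spatial padding [[1, 1], [2, 2]] as stored in the padding attr.
  const std::vector<int64_t> original_padding = {1, 1, 2, 2};
  const int64_t rank = 4;  // e.g. NHWC

  std::vector<int64_t> padding_low(rank, 0), padding_high(rank, 0);
  for (int64_t i = 1; i < rank - 1; ++i) {
    padding_low[i] = original_padding[i * 2 - 2];
    padding_high[i] = original_padding[i * 2 - 1];
  }
  // Batch and feature dimensions stay unpadded.
  assert((padding_low == std::vector<int64_t>{0, 1, 2, 0}));
  assert((padding_high == std::vector<int64_t>{0, 1, 2, 0}));
  return 0;
}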
int64_t rank = lhs.getType().cast().getRank(); - llvm::SmallVector padding_low(rank, 0), padding_high(rank, 0), + SmallVector padding_low(rank, 0), padding_high(rank, 0), padding_interior(rank, 0); for (int64_t i = 1; i < rank - 1; ++i) { padding_low[i] = original_padding[i * 2 - 2]; @@ -962,7 +1063,7 @@ class ConvertUniformQuantizedDotOp : public OpConversionPattern { rewriter.getContext(), /*lhsBatchingDimensions=*/{}, /*rhsBatchingDimensions=*/{}, /*lhsContractingDimensions=*/{1}, /*rhsContractingDimensions=*/{0}); - llvm::SmallVector attrs(op->getAttrs()); + SmallVector attrs(op->getAttrs()); attrs.push_back( {StringAttr::get(rewriter.getContext(), "dot_dimension_numbers"), dims}); @@ -997,12 +1098,14 @@ class ConvertUniformQuantizedDotGeneralOp return matchAndRewriteDotLikeOp( op, adaptor, op->getAttrs(), DotLikeDimensionNumbers{ - op.getDotDimensionNumbers().getLhsBatchingDimensions(), + to_vector(op.getDotDimensionNumbers().getLhsBatchingDimensions()), /*lhs_spatial_dims=*/{}, - op.getDotDimensionNumbers().getLhsContractingDimensions(), - op.getDotDimensionNumbers().getRhsBatchingDimensions(), + to_vector( + op.getDotDimensionNumbers().getLhsContractingDimensions()), + to_vector(op.getDotDimensionNumbers().getRhsBatchingDimensions()), /*rhs_spatial_dims=*/{}, - op.getDotDimensionNumbers().getRhsContractingDimensions()}, + to_vector( + op.getDotDimensionNumbers().getRhsContractingDimensions())}, rewriter); } } @@ -1088,7 +1191,7 @@ FailureOr VerifyAndConstructDims( auto res_element_quant_per_channel_type = getElementTypeOrSelf(op.getResult()) .cast(); - llvm::SmallVector scale_ratios( + SmallVector scale_ratios( res_element_quant_per_channel_type.getScales().size()); for (int i = 0; i < scale_ratios.size(); ++i) { scale_ratios[i] = @@ -1106,7 +1209,8 @@ FailureOr VerifyAndConstructDims( } } // lhs_dilation must not exist. - if (llvm::any_of(op.getLhsDilationAttr().getValues(), + if (op.getLhsDilation().has_value() && + llvm::any_of(op.getLhsDilationAttr().getValues(), [](int64_t dilate) { return dilate != 1; })) { op->emitError("lhs_dilation must be 1."); return failure(); @@ -1160,6 +1264,7 @@ class ConvertUniformQuantizedConvolutionOp // This pattern lowers a generic MHLO op for uq->int. // This pattern essentially just performs type change, with no algorithm change. +// TODO: b/310685906 - Add operand/result type validations. class ConvertGenericOp : public ConversionPattern { public: explicit ConvertGenericOp(MLIRContext *ctx) @@ -1169,36 +1274,16 @@ class ConvertGenericOp : public ConversionPattern { Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // This pattern only handle selected ops. - if (!isa(op)) { + if (!isa(op)) { return failure(); } - // Check that all operands and result uq types are the same. - llvm::SmallVector uq_types; - for (auto result_type : op->getResultTypes()) { - auto type = - getElementTypeOrSelf(result_type).dyn_cast(); - if (type) { - uq_types.push_back(type); - } - } - for (auto operand : op->getOperands()) { - auto type = getElementTypeOrSelf(operand.getType()) - .dyn_cast(); - if (type) { - uq_types.push_back(type); - } - } - for (auto type : uq_types) { - if (type != uq_types.front()) { - return failure(); - } - } - // Determine new result type: use storage type for uq types; use original // type otherwise. 
- llvm::SmallVector new_result_types; + SmallVector new_result_types; for (auto result_type : op->getResultTypes()) { new_result_types.push_back(GetQuantStorageType(result_type)); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc index f20d1b3609361e..1987b607392379 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc @@ -62,7 +62,9 @@ limitations under the License. namespace mlir::quant::stablehlo { namespace { -class ConvertTfQuantToMhloIntTest : public ::testing::Test { +using ::testing::Test; + +class ConvertTfQuantToMhloIntTest : public Test { protected: void SetUp() override { DialectRegistry dialects; @@ -281,7 +283,7 @@ class ConvertTfQuantToMhloIntTest : public ::testing::Test { absl::BitGen bitgen_; }; -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAndDequantize) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAndDequantizeToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { %scale = "tf.Const"() { value = dense<0.347> : tensor } : () -> tensor @@ -306,7 +308,7 @@ func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { kProgram, {&arg0}, /*tf_program=*/std::nullopt, /*error_tolerance=*/0.35); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizePerChannel) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizePerChannelToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main( %arg0: tensor<10x10xf32>, %scale: tensor<10xf32>, %zp: tensor<10xi32> @@ -330,7 +332,7 @@ func.func @main( /*error_tolerance=*/1.0); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformDequantizePerChannel) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformDequantizePerChannelToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main( %arg0: tensor<10x10xi8>, %scale: tensor<10xf32>, %zp: tensor<10xi32> @@ -350,7 +352,7 @@ func.func @main( ExecuteAndCompareResultsWithTfKernel(kProgram, {&arg0, &scale, &zp}); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeConvolution) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeConvolutionToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main(%input: tensor<1x9x9x9xi8>, %filter: tensor<3x3x9x10xi8>) -> tensor<1x9x9x10xi32> { %input_scale = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor @@ -389,7 +391,8 @@ func.func @main(%input: tensor<1x9x9x9xi8>, %filter: tensor<3x3x9x10xi8>) -> ten ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeConvolutionPerChannel) { +TEST_F(ConvertTfQuantToMhloIntTest, + UniformQuantizeConvolutionPerChannelToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main( %input: tensor<1x9x9x9xi8>, %filter: tensor<3x3x9x10xi8>, %scale: tensor<10xf32> @@ -428,7 +431,8 @@ func.func @main( ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter, &scale}); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeConvolutionHybrid) { +TEST_F(ConvertTfQuantToMhloIntTest, + UniformQuantizeConvolutionHybridToValidGraph) { constexpr absl::string_view kTfProgram = R"mlir( func.func @main(%input: tensor<2x10x10x10xf32>, %filter: tensor<3x3x10x20xi8>) -> tensor<2x10x10x20xf32> { 
%filter_scale = "tf.Const"() { value = dense<0.047> : tensor } : () -> tensor @@ -476,7 +480,7 @@ func.func @main(%input: tensor<2x10x10x10xf32>, %filter: tensor<3x3x10x20xi8>) - ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}, kTfProgram); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeDot) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeDotToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main(%input: tensor<8x9xi8>, %filter: tensor<9x10xi8>) -> tensor<8x10xi32> { %input_scale = "tf.Const"() { value = dense<0.588> : tensor } : () -> tensor @@ -513,7 +517,7 @@ func.func @main(%input: tensor<8x9xi8>, %filter: tensor<9x10xi8>) -> tensor<8x10 ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeDotHybrid) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeDotHybridToValidGraph) { constexpr absl::string_view kTfProgram = R"mlir( func.func @main(%input: tensor<8x9xf32>, %filter: tensor<9x10xi8>) -> tensor<8x10xf32> { %filter_scale = "tf.Const"() { value = dense<0.0235> : tensor } : () -> tensor @@ -550,7 +554,7 @@ func.func @main(%input: tensor<8x9xf32>, %filter: tensor<9x10xi8>) -> tensor<8x1 ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}, kTfProgram); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantize) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantizeToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main(%input: tensor<10xi8>) -> tensor<10xi8> { %input_scale = "tf.Const"() { value = dense<0.2235> : tensor } : () -> tensor @@ -579,7 +583,131 @@ func.func @main(%input: tensor<10xi8>) -> tensor<10xi8> { ExecuteAndCompareResultsWithTfKernel(kProgram, {&input}); } -TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAdd) { +TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantizePerChannelToValidGraph) { + constexpr absl::string_view kProgram = R"mlir( +func.func @main( + %input: tensor<10x10xi8>, %input_scale: tensor<10xf32>, + %input_zp: tensor<10xi32>, %output_scale: tensor<10xf32>, + %output_zp: tensor<10xi32> + ) -> tensor<10x10xi8> { + %0 = "tf.Cast"(%input) {} : (tensor<10x10xi8>) -> tensor<10x10x!tf_type.qint8> + %1 = "tf.UniformRequantize"( + %0, %input_scale, %input_zp, %output_scale, %output_zp + ) { + Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT8", attr_map = "", + device = "", input_quantization_axis = 1, + input_quantization_max_val = 127 : i64, + input_quantization_min_val = -128 : i64, + output_quantization_axis = 1 : i64, + output_quantization_max_val = 127 : i64, + output_quantization_min_val = -128 : i64 + } : ( + tensor<10x10x!tf_type.qint8>, tensor<10xf32>, tensor<10xi32>, + tensor<10xf32>, tensor<10xi32> + ) -> tensor<10x10x!tf_type.qint8> + %2 = "tf.Cast"(%1) {} : (tensor<10x10x!tf_type.qint8>) -> tensor<10x10xi8> + return %2 : tensor<10x10xi8> +})mlir"; + TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomI8Literal({10, 10})); + TF_ASSERT_OK_AND_ASSIGN( + auto input_scale, + CreateRandomF32Literal({10}, /*min=*/0.0001, /*max=*/2)); + TF_ASSERT_OK_AND_ASSIGN(auto input_zp, CreateRandomI32Literal({10})); + TF_ASSERT_OK_AND_ASSIGN( + auto output_scale, + CreateRandomF32Literal({10}, /*min=*/0.0001, /*max=*/2)); + TF_ASSERT_OK_AND_ASSIGN(auto output_zp, CreateRandomI32Literal({10})); + // error_tolerance is set to be 1 because different rounding implementations + // in TF kernel and the lowering passes may cause +/-1 differences. 
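The +/-1 tolerance mentioned in these tests reflects the usual half-way rounding discrepancy: a real value exactly between two integers may legitimately round either way depending on the rounding rule, so two correct kernels can disagree by one quantized step. The snippet below only illustrates that two standard rounding rules differ at such points; it makes no claim about which rule each implementation uses.

#include <cassert>
#include <cmath>

int main() {
  // 2.5 is exactly half-way between 2 and 3.
  const double x = 2.5;
  // Round-half-away-from-zero (std::round) gives 3 ...
  assert(std::round(x) == 3.0);
  // ... while round-half-to-even (the default floating-point rounding mode
  // used by std::nearbyint) gives 2, one quantized step lower.
  assert(std::nearbyint(x) == 2.0);
  return 0;
}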
+ ExecuteAndCompareResultsWithTfKernel( + kProgram, {&input, &input_scale, &input_zp, &output_scale, &output_zp}, + /*tf_program=*/std::nullopt, + /*error_tolerance=*/1.0); +} + +TEST_F(ConvertTfQuantToMhloIntTest, + UniformRequantizePerTensorToPerChannelToValidGraph) { + constexpr absl::string_view kProgram = R"mlir( +func.func @main( + %input: tensor<10x10xi8>, %input_scale: tensor, %input_zp: tensor, + %output_scale: tensor<10xf32>, %output_zp: tensor<10xi32> + ) -> tensor<10x10xi8> { + %0 = "tf.Cast"(%input) {} : (tensor<10x10xi8>) -> tensor<10x10x!tf_type.qint8> + %1 = "tf.UniformRequantize"( + %0, %input_scale, %input_zp, %output_scale, %output_zp + ) { + Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT8", attr_map = "", + device = "", input_quantization_axis = -1, + input_quantization_max_val = 127 : i64, + input_quantization_min_val = -128 : i64, + output_quantization_axis = 1 : i64, + output_quantization_max_val = 127 : i64, + output_quantization_min_val = -128 : i64 + } : ( + tensor<10x10x!tf_type.qint8>, tensor, tensor, + tensor<10xf32>, tensor<10xi32> + ) -> tensor<10x10x!tf_type.qint8> + %2 = "tf.Cast"(%1) {} : (tensor<10x10x!tf_type.qint8>) -> tensor<10x10xi8> + return %2 : tensor<10x10xi8> +})mlir"; + TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomI8Literal({10, 10})); + TF_ASSERT_OK_AND_ASSIGN( + auto input_scale, CreateRandomF32Literal({}, /*min=*/0.0001, /*max=*/2)); + TF_ASSERT_OK_AND_ASSIGN(auto input_zp, CreateRandomI32Literal({})); + TF_ASSERT_OK_AND_ASSIGN( + auto output_scale, + CreateRandomF32Literal({10}, /*min=*/0.0001, /*max=*/2)); + TF_ASSERT_OK_AND_ASSIGN(auto output_zp, CreateRandomI32Literal({10})); + // error_tolerance is set to be 1 because different rounding implementations + // in TF kernel and the lowering passes may cause +/-1 differences. 
+ ExecuteAndCompareResultsWithTfKernel( + kProgram, {&input, &input_scale, &input_zp, &output_scale, &output_zp}, + /*tf_program=*/std::nullopt, + /*error_tolerance=*/1.0); +} + +TEST_F(ConvertTfQuantToMhloIntTest, + UniformRequantizePerChannelToPerTensorToValidGraph) { + constexpr absl::string_view kProgram = R"mlir( +func.func @main( + %input: tensor<10x10xi8>, %input_scale: tensor<10xf32>, + %input_zp: tensor<10xi32>, %output_scale: tensor, %output_zp: tensor + ) -> tensor<10x10xi8> { + %0 = "tf.Cast"(%input) {} : (tensor<10x10xi8>) -> tensor<10x10x!tf_type.qint8> + %1 = "tf.UniformRequantize"( + %0, %input_scale, %input_zp, %output_scale, %output_zp + ) { + Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT8", attr_map = "", + device = "", input_quantization_axis = 1, + input_quantization_max_val = 127 : i64, + input_quantization_min_val = -128 : i64, + output_quantization_axis = -1 : i64, + output_quantization_max_val = 127 : i64, + output_quantization_min_val = -128 : i64 + } : ( + tensor<10x10x!tf_type.qint8>, tensor<10xf32>, tensor<10xi32>, + tensor, tensor + ) -> tensor<10x10x!tf_type.qint8> + %2 = "tf.Cast"(%1) {} : (tensor<10x10x!tf_type.qint8>) -> tensor<10x10xi8> + return %2 : tensor<10x10xi8> +})mlir"; + TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomI8Literal({10, 10})); + TF_ASSERT_OK_AND_ASSIGN( + auto input_scale, + CreateRandomF32Literal({10}, /*min=*/0.0001, /*max=*/2)); + TF_ASSERT_OK_AND_ASSIGN(auto input_zp, CreateRandomI32Literal({10})); + TF_ASSERT_OK_AND_ASSIGN( + auto output_scale, CreateRandomF32Literal({}, /*min=*/0.0001, /*max=*/2)); + TF_ASSERT_OK_AND_ASSIGN(auto output_zp, CreateRandomI32Literal({})); + // error_tolerance is set to be 1 because different rounding implementations + // in TF kernel and the lowering passes may cause +/-1 differences. 
+ ExecuteAndCompareResultsWithTfKernel( + kProgram, {&input, &input_scale, &input_zp, &output_scale, &output_zp}, + /*tf_program=*/std::nullopt, + /*error_tolerance=*/1.0); +} + +TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAddToValidGraph) { constexpr absl::string_view kProgram = R"mlir( func.func @main(%lhs: tensor<10x10xi32>, %rhs: tensor<10x10xi32>) -> tensor<10x10xi32> { %lhs_scale = "tf.Const"() { value = dense<0.518> : tensor } : () -> tensor diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc index 856bbd49930341..9a5e6c53d3d1d6 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc @@ -38,11 +38,12 @@ using ::mlir::MLIRContext; using ::mlir::ModuleOp; using ::mlir::OwningOpRef; using ::tensorflow::monitoring::testing::CellReader; +using ::testing::Test; static constexpr char kMetricsName[] = "/tensorflow/core/tf2xla/tf_quant_op_count"; -class LegalizeTfTypesTest : public ::testing::Test { +class LegalizeTfTypesTest : public Test { protected: void CreateModule(const char* module_string) { DialectRegistry mlir_registry; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc index 1fd1a0b6bab721..4c20b6bebdcdad 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc @@ -34,7 +34,9 @@ limitations under the License. namespace mlir::quant::stablehlo { namespace { -class LegalizeTFQuantTest : public ::testing::Test { +using ::testing::Test; + +class LegalizeTFQuantTest : public Test { protected: void TestBridgeLowering(llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes) { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc index 361d98c7775abe..2825195addea12 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc @@ -57,7 +57,7 @@ bool IsQuantType(Type type) { IsTFQintType(element_type); } -bool IsMhloUniformQuantizedOp(Operation* op) { +bool IsMhloUniformQuantizedOp(Operation& op) { return llvm::isa(op); } @@ -68,7 +68,7 @@ void VerifyQuantLegalization::runOnOperation() { // Verify all uq and qint types are lowered. 
if (llvm::any_of(op->getOperandTypes(), IsQuantType) || llvm::any_of(op->getResultTypes(), IsQuantType) || - IsTFUniformQuantizedOp(op) || IsMhloUniformQuantizedOp(op)) { + IsTFUniformQuantizedOp(op) || IsMhloUniformQuantizedOp(*op)) { op->emitOpError("is illegal as it is a UQ op or contains uq/qint types"); LOG(ERROR) << "Found illegal op containing uq/qint type: " << op->getName().getStringRef().str(); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc index 383f2430c94eee..6f13634b317aa4 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc @@ -29,9 +29,8 @@ limitations under the License. #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h" -// TODO - b/303543789: Remove TF Quantizer util dependency. +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" namespace mlir::quant::stablehlo { @@ -42,7 +41,7 @@ namespace { // TODO - b/303543789: Move the helper functions below to a separate util. // Fetches the default or null attribute, used for pattern matching. -static Attribute DefaultOrNullAttr(OpBuilder& builder, Attribute& attr) { +Attribute DefaultOrNullAttr(OpBuilder& builder, const Attribute& attr) { if (!attr) { return builder.getStringAttr(kNullAttributeValue); } @@ -51,7 +50,7 @@ static Attribute DefaultOrNullAttr(OpBuilder& builder, Attribute& attr) { // Checks whether the value of a constant equals the given float, regardless // of the tensor dimension. -static bool FloatValueEquals(const Attribute& attr, double value) { +bool FloatValueEquals(const Attribute& attr, const double value) { auto fp_attr = attr.dyn_cast_or_null(); if (!fp_attr) return false; @@ -101,7 +100,7 @@ void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() { } // Remove all attr_map attributes. 
- module_op.walk([&](Operation* op) { op->removeAttr(kAttrMapAttribute); }); + module_op.walk([](Operation* op) { op->removeAttr(kAttrMapAttribute); }); } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td index 116037d9130df2..0e7706c8d550a1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td @@ -20,10 +20,8 @@ include "mlir/IR/OpBase.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "stablehlo/dialect/StablehloOps.td" -include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" -include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.td" -include "tensorflow/compiler/mlir/quantization/stablehlo/passes/utils.td" -// TODO - b/303543789: Remove TF Quantizer util dependency. +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" +include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td" //===----------------------------------------------------------------------===// // Pattern rules for lifting ops with bias as functions diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td index fc5af302e794a8..9bc337b8d46949 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td @@ -19,10 +19,8 @@ include "mlir/Dialect/Arith/IR/ArithOps.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "stablehlo/dialect/StablehloOps.td" -include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" -include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.td" -include "tensorflow/compiler/mlir/quantization/stablehlo/passes/utils.td" -// TODO - b/303543789: Remove TF Quantizer util dependency. 
+include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" +include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td" //===----------------------------------------------------------------------===// // Pattern rules for lifting ops as functions @@ -56,4 +54,4 @@ def LiftDotGeneral : Pat< (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), - [(IsNotInLiftedFunc $res)], [], (addBenefit 1)>; \ No newline at end of file + [(IsNotInLiftedFunc $res)], [], (addBenefit 1)>; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h index 0b05069b265989..4973c515d96a58 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h @@ -28,7 +28,7 @@ namespace mlir::quant::stablehlo { // Creates a `QuantizePass` that quantizes ops according to surrounding qcast / // dcast ops. -std::unique_ptr> CreateQuantizePass( +std::unique_ptr> CreateQuantizePass( const quant::QuantizationSpecs& quantization_specs); // Creates a pass that quantizes weight component of StableHLO graph. @@ -39,7 +39,7 @@ std::unique_ptr> CreateQuantizeWeightPass( // Creates an instance of the StableHLO dialect PrepareQuantize pass without any // arguments. Preset method of SRQ is set to the quantization option by default. std::unique_ptr> CreatePrepareQuantizePass( - bool enable_per_channel_quantization = true, int bit_width = 8); + bool enable_per_channel_quantization = false, int bit_width = 8); // Adds generated pass default constructors or options definitions. #define GEN_PASS_DECL diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index 52dca7897ea05d..c69e72120538b3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -67,7 +67,7 @@ def ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass : Pass<"stablehlo- }]; } -def QuantizePass : Pass<"stablehlo-quantize", "mlir::func::FuncOp"> { +def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> { let summary = "Applies static-range quantization on ops."; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", @@ -103,3 +103,13 @@ def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-function "TF::TensorFlowDialect", ]; } + +def UnwrapXlaCallModuleOpPass : Pass<"stablehlo-unwrap-xla-call-module-op", "ModuleOp"> { + let summary = "Unwrap XlaCallModuleOps into inline functions if not used for quantizing fused patterns."; + let dependentDialects = ["TF::TensorFlowDialect"]; +} + +def PopulateShapePass : Pass<"populate-shape", "ModuleOp"> { + let summary = "Populate output shape with known information for CustomAggregatorOp and XlaCallModuleOp."; + let dependentDialects = ["TF::TensorFlowDialect"]; +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/populate_shape.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/populate_shape.cc new file mode 100644 index 00000000000000..0d4f0594f5c7d8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/populate_shape.cc @@ -0,0 +1,144 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" +#include "tensorflow/core/ir/types/dialect.h" + +namespace mlir::quant::stablehlo { + +#define GEN_PASS_DEF_POPULATESHAPEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" + +namespace { + +class PopulateShapeForCustomAggregatorOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + TF::CustomAggregatorOp op, TF::CustomAggregatorOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto input_shape_type = op.getInput().getType().dyn_cast(); + auto output_shape_type = op.getOutput().getType(); + + if (!input_shape_type.isa()) { + input_shape_type = adaptor.getInput().getType(); + } + + if (input_shape_type.isa() && + !output_shape_type.isa() && + TF::HasCompatibleElementTypes(input_shape_type, output_shape_type)) { + auto new_op = rewriter.create( + op->getLoc(), /*output=*/input_shape_type, + /*args=*/adaptor.getInput(), + /*Id=*/op.getId()); + new_op->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, new_op); + return success(); + } + return failure(); + } +}; + +class PopulateShapeForXlaCallModuleOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + TF::XlaCallModuleOp op, TF::XlaCallModuleOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->getNumResults() != 1) { + op->emitError("XlaCallModuleOp doesn't have 1 output."); + return failure(); + } + // Assume XlaCallModuleOp only has 1 output. 
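+    // Illustrative example: a result typed tensor<*xf32> whose Sout entry is
+    // #tf_type.shape<1x1024> is rebuilt below as tensor<1x1024xf32>; if the
+    // Sout entry itself has no rank, the op is left untouched.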
+ auto output_shape_type = op->getResultTypes()[0]; + if (!output_shape_type.isa()) { + auto output_shape_attr = op.getSout()[0].dyn_cast(); + if (!output_shape_attr.hasRank()) { + return failure(); + } + auto new_output_shape_type = tensorflow::GetTypeFromTFTensorShape( + output_shape_attr.getShape(), + getElementTypeOrSelf(op.getResultTypes()[0])); + auto new_op = rewriter.create( + op->getLoc(), /*output=*/new_output_shape_type, + /*args=*/adaptor.getOperands(), + /*version=*/op.getVersionAttr(), + /*module=*/op.getModuleAttr(), + /*Sout=*/op.getSoutAttr()); + new_op->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, new_op); + return success(); + } + return failure(); + } +}; + +class PopulateShapePass + : public impl::PopulateShapePassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PopulateShapePass) + + explicit PopulateShapePass() = default; + + private: + void runOnOperation() override; +}; + +void PopulateShapePass::runOnOperation() { + Operation *op = getOperation(); + MLIRContext *context = op->getContext(); + RewritePatternSet patterns(context); + ConversionTarget target(*context); + target.addDynamicallyLegalOp([](Operation *op) { + auto custom_aggregator_op = llvm::dyn_cast(op); + return custom_aggregator_op.getInput().getType().isa() && + custom_aggregator_op.getOutput().getType().isa(); + }); + target.addDynamicallyLegalOp([](Operation *op) { + if (op->getNumResults() != 1) return true; + return op->getResultTypes()[0].isa(); + }); + + patterns + .add( + context); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { + return signalPassFailure(); + } +} +} // namespace + +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc index 6da27d9e3c2823..24d15dfd6688d5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc @@ -15,11 +15,9 @@ limitations under the License. // Copied and modified from // //third_party/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc // This transformation pass applies quantization propagation on TF dialect. -#include #include #include -#include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project @@ -35,6 +33,7 @@ limitations under the License. 
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -134,50 +133,6 @@ class ConvertArithConstToStablehloConstOp } }; -std::unique_ptr GetStableHLOOpQuantSpec(Operation* op) { - auto spec = std::make_unique(); - if (auto call_op = dyn_cast_or_null(op)) { - auto entry_function = - call_op->getAttrOfType("_entry_function"); - StringRef function_name = entry_function.getValue(); - if (!function_name.startswith("composite_")) { - return spec; - } - if (function_name.contains("conv")) { - spec->coeff_op_quant_dim[1] = 3; - if (function_name.contains("with_bias")) { - spec->biases_params[2] = {{0, 1}, - quant::GetUniformQuantizedTypeForBias}; - } - } else if (function_name.contains("dot_general")) { - spec->coeff_op_quant_dim[1] = -1; - if (function_name.contains("with_bias")) { - spec->biases_params[2] = {{0, 1}, - quant::GetUniformQuantizedTypeForBias}; - } - } else if (function_name.contains("dot")) { - spec->coeff_op_quant_dim[1] = -1; - if (function_name.contains("with_bias")) { - spec->biases_params[2] = {{0, 1}, - quant::GetUniformQuantizedTypeForBias}; - } - } - for (auto quantizable_operand : spec->coeff_op_quant_dim) { - spec->quantizable_operands.insert(quantizable_operand.first); - } - } - return spec; -} - -std::unique_ptr GetStableHLOQuantScaleSpec(Operation* op) { - auto scale_spec = std::make_unique(); - if (llvm::isa( - op)) { - scale_spec->has_same_scale_requirement = true; - } - return scale_spec; -} - void PrepareQuantizePass::runOnOperation() { func::FuncOp func = getOperation(); MLIRContext* ctx = func.getContext(); @@ -185,8 +140,8 @@ void PrepareQuantizePass::runOnOperation() { // The function might contain more stats ops than required, and it will // introduce requantize if the calibration stats have conflicts. This tries to // remove all the redundant stats ops. - RemoveRedundantStatsOps(func, GetStableHLOOpQuantSpec, - GetStableHLOQuantScaleSpec); + RemoveRedundantStatsOps(func, GetStableHloOpQuantSpec, + GetStableHloQuantScaleSpec); RewritePatternSet patterns(ctx); // Convert quant stats to int8 quantization parameters. @@ -209,7 +164,7 @@ void PrepareQuantizePass::runOnOperation() { // values (tensors). ApplyQuantizationParamsPropagation( func, /*is_signed=*/true, bit_width_, !enable_per_channel_quantization_, - GetStableHLOOpQuantSpec, GetStableHLOQuantScaleSpec, + GetStableHloOpQuantSpec, GetStableHloQuantScaleSpec, /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false); // Restore constants as stablehlo::ConstantOp. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc new file mode 100644 index 00000000000000..6cd3be0cdc572c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -0,0 +1,432 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +#define DEBUG_TYPE "populate-quantization-patterns" + +namespace mlir::quant::stablehlo { + +namespace { + +using ::mlir::stablehlo::AddOp; +using ::mlir::stablehlo::ConvolutionOp; +using ::mlir::stablehlo::DotGeneralOp; +using ::mlir::stablehlo::DynamicBroadcastInDimOp; +using ::mlir::stablehlo::UniformQuantizeOp; + +constexpr StringRef kCompositeFuncPrefix = "composite_"; +constexpr StringRef kQuantizedFuncPrefix = "quantized_"; +constexpr StringRef kEntryFuncAttrName = "_entry_function"; + +// Returns true if `type` is a TensorType with quantized elements. +bool IsQuantizedTensorType(const Type type) { + return type.isa() && + type.cast().getElementType().isa(); +} + +// Returns true if an op has adjacent bias or activation that can be fused +// together into the quantization function. +// TODO: b/307620428 - Consider using matchAndRewrite to check and apply +// patterns at the same time. Also add check for fusible activation or +// fusible patterns with dynamic shape. +bool HasFusibleQuantizationPattern(Operation& op) { + if (isa(op.getNextNode())) { + return true; + } + return false; +} + +// Returns dynamically broadcasted user op of an input op. Returns null if +// the op is used multiple times or the user op is not dynamically broadcasted. 
+// Dynamic shapes usually has the following pattern. In the example below, +// the input operand would be stablehlo.gemm_style op, and return value would +// be stablehlo.add op. +// +// ``` +// %2 = stablehlo.gemm_style(%0, %1) +// %3 = shape.shape_of %2 +// %4 = stablehlo.dynamic_broadcast_in_dims %cst, %3 +// %5 = stablehlo.add %2, %4 +// ``` +Operation* GetDynamicallyBroadcastedUserOp(Operation& op) { + if (!op.hasOneUse()) { + LLVM_DEBUG(llvm::dbgs() + << "Target op is used multiple times and will not be checked " + "for dynamic shape case.\n"); + return nullptr; + } + Operation& shapeof_op = *op.getNextNode(); + if (!isa(shapeof_op)) { + return nullptr; + } + Operation& broadcast_in_dims_op = *shapeof_op.getNextNode(); + if (!isa(broadcast_in_dims_op)) { + return nullptr; + } + return broadcast_in_dims_op.getNextNode(); +} + +// Checks if all inputs and outputs are quantized. +bool HasQuantizedOperandOrOutput(Operation& call_op) { + SmallVector arg_types; + for (const Value arg : call_op.getOperands()) { + arg_types.push_back(arg.getType()); + } + + SmallVector output_types; + for (const Value output : call_op.getResults()) { + output_types.push_back(output.getType()); + } + + return absl::c_all_of(arg_types, IsQuantizedTensorType) && + absl::c_all_of(output_types, IsQuantizedTensorType); +} + +// Gets the corresponding quantized function name from the given function name. +// Example: "composite_dot_general_fn_1" => "quantized_dot_general_fn" +std::string GetQuantizedFunctionName(const StringRef func_name) { + return Twine(kQuantizedFuncPrefix) + .concat(func_name.rsplit(kCompositeFuncPrefix).second) + .str(); +} + +// Returns true if `xla_call_module_op` is quantized. To be considered +// quantized, it should meet three conditions: +// 1. At least one of the inputs or outputs should be a uniform quantized type. +// 2. `xla_call_module_op` should have the `kQuantTraitAttrName` attribute. +// 3. It should also have the `kEntryFuncAttrName` attribute, which points to +// the function that `xla_call_module_op` represents. +bool IsQuantizedXlaCallModuleOp(TF::XlaCallModuleOp xla_call_module_op) { + return HasQuantizedOperandOrOutput(*xla_call_module_op) && + xla_call_module_op->hasAttr(kQuantTraitAttrName) && + xla_call_module_op->hasAttr(kEntryFuncAttrName); +} + +// Returns the entry function, i.e. the callee of `xla_call_module_op`. +func::FuncOp GetEntryFuncOp(TF::XlaCallModuleOp xla_call_module_op, + SymbolTable symbol_table) { + const auto entry_function_symbol_ref = + xla_call_module_op->getAttrOfType(kEntryFuncAttrName); + + return dyn_cast_or_null( + symbol_table.lookup(entry_function_symbol_ref.getValue())); +} + +// Replaces the function type of `entry_func_op` to a quantized one, matching +// the input and output types of `xla_call_module_op`. +void SetQuantizedFunctionType(PatternRewriter& rewriter, + func::FuncOp entry_func_op, + TF::XlaCallModuleOp xla_call_module_op) { + SmallVector arg_types; + SmallVector arg_locs; + for (const Value arg : xla_call_module_op.getArgs()) { + arg_types.push_back(arg.getType()); + arg_locs.push_back(arg.getLoc()); + } + + SmallVector output_types; + for (const Value output : xla_call_module_op.getOutput()) { + output_types.push_back(output.getType()); + } + + entry_func_op.setFunctionType( + rewriter.getFunctionType(arg_types, output_types)); + + // Replace argument types and locs. 
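+    // Keep the entry block arguments in sync with the new signature;
+    // llvm::zip_equal below also asserts that the number of block arguments
+    // matches the number of operands of the XlaCallModuleOp.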
+ Block& entry = entry_func_op->getRegion(0).front(); + for (auto [arg, arg_type, arg_loc] : + llvm::zip_equal(entry.getArguments(), arg_types, arg_locs)) { + arg.setType(arg_type); + arg.setLoc(arg_loc); + } +} + +// Creates a UniformQuantize op and sets it as return op. +void CreateAndReturnUniformQuantizeOp(PatternRewriter& rewriter, Operation& op, + func::FuncOp entry_func_op, + const Type func_result_type) { + // Add i32 -> i8 requantization. + UniformQuantizeOp uniform_quant_op = rewriter.create( + op.getLoc(), func_result_type, op.getResults()); + cast(entry_func_op.getBody().front().getTerminator()) + .setOperand(0, uniform_quant_op); +} + +// An interface representing patterns that quantizes an entry function's body. +// The entry function's signatures should have already been quantized at the +// point of rewriting. +class EntryFuncBodyQuantizationPattern { + public: + virtual ~EntryFuncBodyQuantizationPattern() = default; + + // Returns `success()` if `entry_func_op`'s body is eligible for rewriting. At + // this point `entry_func_op`'s signature has not been reset with quantized + // types. + virtual LogicalResult match(func::FuncOp entry_func_op) const = 0; + + // Rewrites the `entry_func_op`'s body. + virtual void rewrite(func::FuncOp entry_func_op, + PatternRewriter& rewriter) const = 0; +}; + +// Gemm Style Op: glossary/gemm. +template +// Match for all gemm_style op and check for possible fusions. +LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { + // function must have input, filter, and optionally bias. + auto& operations = entry_func_op.getBody().front().getOperations(); + if (operations.size() != 2 && operations.size() != 3) { + return failure(); + } + if (!isa(operations.front())) { + return failure(); + } else if (GetDynamicallyBroadcastedUserOp(operations.front())) { + LLVM_DEBUG(llvm::dbgs() + << "Currently gemm style ops quantization only supports static " + " shapes.\n"); + return failure(); + } else if (!isa( + operations.front().getResult(0).getType())) { + return failure(); + } + return success(); +} + +// Gemm Style Op: glossary/gemm. +template +void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter) { + // Update the output type of the gemm_style op. + GemmStyleOp gemm_style_op = *entry_func_op.getOps().begin(); + + const Type input_type = entry_func_op.getArgumentTypes()[0]; + const Type filter_type = entry_func_op.getArgumentTypes()[1]; + const Type func_result_type = entry_func_op.getResultTypes()[0]; + + const double input_scale = + getElementTypeOrSelf(input_type).cast().getScale(); + const double filter_scale = + getElementTypeOrSelf(filter_type).cast().getScale(); + const double result_scale = input_scale * filter_scale; + + // Define the intermediate output type, which is an i32 quantized type. + // This is intermediate because the final output type of the entry_func_op + // should be an i8 quantized type. 
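+    // Worked example: with an input scale of 0.02 and a filter scale of 0.5,
+    // the accumulator type created below is an i32 quantized type with scale
+    // 0.02 * 0.5 = 0.01 and zero point 0; CreateAndReturnUniformQuantizeOp
+    // later requantizes it to the i8 result type of the entry function.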
+ const UniformQuantizedType gemm_style_quantized_element_type = + CreateI32F32UniformQuantizedType(gemm_style_op->getLoc(), + *rewriter.getContext(), result_scale, + /*zero_point=*/0); + + Value gemm_style_op_result = gemm_style_op->getResult(0); + auto gemm_style_op_result_type = + gemm_style_op_result.getType().cast(); + const ArrayRef gemm_style_shape = + gemm_style_op_result_type.getShape(); + + const TensorType new_gemm_style_op_result_type = + gemm_style_op_result_type.cloneWith(gemm_style_shape, + gemm_style_quantized_element_type); + gemm_style_op_result.setType(new_gemm_style_op_result_type); + + rewriter.setInsertionPointAfter(gemm_style_op); + + Operation& next_op = *gemm_style_op->getNextNode(); + // If an op is used multiple times, do not apply quantization of fused + // patterns to prevent removal of dependee ops. + const bool should_quantize_without_fusion = + HasFusibleQuantizationPattern(*gemm_style_op.getOperation()) && + !gemm_style_op->hasOneUse(); + + // TODO: b/307620428 - Add support for dynamic shapes. + if (should_quantize_without_fusion || !isa(next_op)) { + // no bias + CreateAndReturnUniformQuantizeOp(rewriter, *gemm_style_op, entry_func_op, + func_result_type); + return; + } + // bias fusion + Value bias_op = next_op.getOperand(1); + Value add_op_result = next_op.getResult(0); + const auto add_op_result_type = + add_op_result.getType().cast(); + const ArrayRef add_op_shape = add_op_result_type.getShape(); + // For quantized bias add case, lhs, rhs, and result have the same types. + const TensorType new_add_op_result_type = add_op_result_type.cloneWith( + add_op_shape, gemm_style_quantized_element_type); + add_op_result.setType(new_add_op_result_type); + + AddOp bias_add_op = + rewriter.create(gemm_style_op->getLoc(), gemm_style_op, bias_op); + + CreateAndReturnUniformQuantizeOp(rewriter, *bias_add_op, entry_func_op, + func_result_type); +} + +// Quantizes the entry function's body containing a `DotGeneralOp`. +class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeDotGeneralOpPattern() = default; + + LogicalResult match(func::FuncOp entry_func_op) const override { + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, + PatternRewriter& rewriter) const override { + RewriteGemmStyleOp(entry_func_op, rewriter); + } +}; + +// Quantizes the entry function's body containing a `ConvolutionOp`. +class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeConvolutionOpPattern() = default; + + LogicalResult match(func::FuncOp entry_func_op) const override { + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, + PatternRewriter& rewriter) const override { + RewriteGemmStyleOp(entry_func_op, rewriter); + } +}; + +// Converts `entry_func_op` to be quantized according to the respective +// inputs and outputs of `xla_call_module_op` that are possibly quantized. It +// signature (type) is reset to match that of `xla_call_module_op`. +// `entry_func_body_quantization_pattern` rewrites the function's body, based on +// the new signature. 
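+// For example, a callee named "composite_dot_general_fn_1" has its signature
+// and body rewritten with quantized types and is renamed to
+// "quantized_dot_general_fn".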
+void QuantizeEntryFuncOp( + MLIRContext& ctx, PatternRewriter& rewriter, + TF::XlaCallModuleOp xla_call_module_op, func::FuncOp entry_func_op, + const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { + SetQuantizedFunctionType(rewriter, entry_func_op, xla_call_module_op); + + body_rewrite_pattern.rewrite(entry_func_op, rewriter); + + // Rename the function to be clear that the function has been quantized. + const std::string quantized_function_name = + GetQuantizedFunctionName(entry_func_op.getSymName()); + entry_func_op.setSymName(quantized_function_name); +} + +// Replaces a quantized `xla_call_module_op` with a `func::CallOp`. The callee +// is expected to remain unquantized (thus having a signature mismatch), and it +// is also quantized accordingly. +void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( + MLIRContext& ctx, PatternRewriter& rewriter, + TF::XlaCallModuleOp xla_call_module_op, + const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { + ModuleOp module_op = xla_call_module_op->getParentOfType(); + SymbolTable symbol_table(module_op); + + func::FuncOp entry_func_op = GetEntryFuncOp(xla_call_module_op, symbol_table); + QuantizeEntryFuncOp(ctx, rewriter, xla_call_module_op, entry_func_op, + body_rewrite_pattern); + + // Replace the XlaCallModuleOp with a new CallOp. + rewriter.setInsertionPoint(xla_call_module_op); + rewriter.replaceOpWithNewOp(xla_call_module_op, entry_func_op, + xla_call_module_op.getArgs()); +} + +// Pattern that mainly does two things: +// +// 1. Replaces quantized `TF::XlaCallModuleOp` with a `func::CallOp`. +// 2. Quantizes the callee function. +// +// The inputs of this pattern assumes an invalid IR, where even if a +// `TF::XlaCallModuleOp` is quantized the callee remains unquantized. Step (2) +// not only replaces the input and output tensor types into quantized ones, but +// also rewrites the body with a quantized equivalent. +// +// `FuncBodyRewritePatternT` defines how a function body is quantized and +// rewritten. +template >> +class XlaCallModuleOpToCallOp : public OpRewritePattern { + public: + explicit XlaCallModuleOpToCallOp(MLIRContext& ctx) + : OpRewritePattern(&ctx) {} + + LogicalResult match(TF::XlaCallModuleOp op) const override { + ModuleOp module_op = op->getParentOfType(); + SymbolTable symbol_table(module_op); + + // Ignore unquantized ops. + if (!IsQuantizedXlaCallModuleOp(op)) return failure(); + + func::FuncOp entry_func_op = GetEntryFuncOp(op, symbol_table); + if (!entry_func_op) { + op->emitError("Failed to find a valid entry function."); + return failure(); + } + + return FuncBodyRewritePatternT().match(entry_func_op); + } + + void rewrite(TF::XlaCallModuleOp xla_call_module_op, + PatternRewriter& rewriter) const override { + ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( + *rewriter.getContext(), rewriter, xla_call_module_op, + FuncBodyRewritePatternT()); + } +}; + +} // namespace + +// TODO: b/307620428 - Increase fused op coverage for static range quantization. 
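+// Sketch of the intended usage, mirroring QuantizePass below (assumes a
+// ModuleOp `module_op` and an MLIRContext `ctx` are in scope):
+//   RewritePatternSet patterns(&ctx);
+//   PopulateFusedGemmStylePatterns(ctx, patterns);
+//   (void)applyPatternsAndFoldGreedily(module_op, std::move(patterns));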
+void PopulateFusedGemmStylePatterns(MLIRContext& ctx, + RewritePatternSet& patterns) { + patterns.add, + XlaCallModuleOpToCallOp>(ctx); +} + +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h new file mode 100644 index 00000000000000..79daa9ce8b48b8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h @@ -0,0 +1,379 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir::quant::stablehlo { + +// Checks if an op is quantizable in StableHLO quantizer. Argument op is not +// necessarily a StableHLO op. +bool IsOpQuantizableStableHlo(Operation* op); + +// A base rewrite pattern which matches any N-in-M-out operations with +// quantization parameters propagated to at least one of its operands. The +// quantization parameters are annotated by the QuantizeOp/DequantizeOp pairs. +// Each matched pattern are rewritten by its quantized alternatives. 
+// +// The concrete pattern, extends from this base pattern, can specify whether it +// allows dynamic range quantized operands and results for the operations in the +// current context. These "DynamicRangeQuantized" operands and results don't +// have quantization parameters propagated to, so will be in float in the +// quantized results. The concrete pattern should define the following two +// functions: +// +// bool AllowDynamicRangeQuantizedOperand(Operation&) const +// bool AllowDynamicRangeQuantizedResult(Operation&) const +// +// Full integer quantization disallows "DynamicRangeQuantized" operands or +// results. Dynamic range quantization allows "DynamicRangeQuantized" operands +// and results. +// +// Implementation of this pattern is mostly copied from QuantizationPattern in +// third_party/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h. +template +class StableHloQuantizationPattern : public RewritePattern { + public: + using BaseType = + StableHloQuantizationPattern; + + explicit StableHloQuantizationPattern( + MLIRContext* context, const mlir::quant::QuantPassSpec& quant_params) + // Set the score to a large number so it is always preferred. + : RewritePattern(RootOpT::getOperationName(), 300, context), + quant_params_(quant_params) {} + + private: + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + llvm::SmallVector quantizing_ops; + + // Collect all the ops to quantize, as the user / producer of the root op. + if constexpr (std::is_same_v) { + if (op->getNumResults() != 1) { + op->emitError("Dequantize op should have exactly one result."); + return failure(); + } + auto users = op->getResult(0).getUsers(); + quantizing_ops.append(users.begin(), users.end()); + } else if constexpr (std::is_same_v) { + if (op->getNumOperands() != 1) { + op->emitError("Quantize op should have exactly one operand."); + return failure(); + } + Value quantize_operand = op->getOperand(0); + if (QuantizedType::getQuantizedElementType(quantize_operand.getType())) { + // The input of the quantize op has already been quantized, i.e. + // rescale. + return failure(); + } + DenseFPElementsAttr attr; + if (matchPattern(quantize_operand, m_Constant(&attr))) { + // Const-> QuantizeOp pattern will be handled separately. + return failure(); + } + if (Operation* quantizing_op = quantize_operand.getDefiningOp()) { + quantizing_ops.push_back(quantizing_op); + } + } + + absl::flat_hash_set ops_blocklist = + quant_params_.quant_spec.ops_blocklist; + absl::flat_hash_set nodes_blocklist = + quant_params_.quant_spec.nodes_blocklist; + CustomMap custom_map = quant_params_.quant_spec.custom_map; + + // Rewrite the floating-point ops to the quantized version, by fusing + // preceding dequantize ops and succeding quantize ops. + for (Operation* quantizing_op : quantizing_ops) { + // If it is requantize op, we shouldn't rewrite this op. + if (llvm::isa(quantizing_op)) { + return failure(); + } + + // If the op is terminator, we shouldn't rewrite. + if (quantizing_op->hasTrait()) { + return failure(); + } + + if (!IsOpQuantizableStableHlo(quantizing_op) && + !static_cast(this)->IsQuantizableCustomOp( + *quantizing_op, custom_map)) { + return failure(); + } + + if (GetStableHloQuantScaleSpec(quantizing_op) + ->has_same_scale_requirement && + !IsConnectedWithQuantizedCompsiteFunction(quantizing_op)) { + return failure(); + } + + // Blocklist op is checked in advance for non-dynamic range quantization + // case. 
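+      // For instance, an ops_blocklist entry "stablehlo.dot_general" keeps
+      // every matched dot_general op in float precision.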
+ if (!quant_params_.quant_spec.weight_quantization && + (ops_blocklist.contains( + quantizing_op->getName().getStringRef().str()))) { + return failure(); + } + + if (!nodes_blocklist.empty()) { + if (auto name_loc = quantizing_op->getLoc().dyn_cast()) { + std::string sloc = name_loc.getName().str(); + if (!sloc.empty() && + (nodes_blocklist.find(sloc) != nodes_blocklist.end())) { + return failure(); + } + } + } + + // Collect all the quantized inputs and "clone" the matched op by these + // inputs. + SmallVector inputs; + inputs.reserve(quantizing_op->getNumOperands()); + for (auto operand : quantizing_op->getOperands()) { + Type operand_type = operand.getType(); + if (operand_type.isa()) { + inputs.push_back(operand); + continue; + } + + auto ele_type = operand.getType().cast().getElementType(); + if (auto dq_op = + dyn_cast_or_null(operand.getDefiningOp())) { + inputs.push_back(dq_op.getOperand()); + } else if (!ele_type.isF32()) { + // If the operand is an integer tensor, then it doesn't require the + // DequantizeOp in the pattern. + inputs.push_back(operand); + } else { + return failure(); + } + } + + // Collect all the quantized outputs and replace them by the results of + // the new quantized op. + llvm::SmallDenseMap outputs_replaced; + SmallVector output_types; + output_types.reserve(quantizing_op->getNumResults()); + for (const auto& enumerated_result : + llvm::enumerate(quantizing_op->getResults())) { + Value result = enumerated_result.value(); + Type result_type = result.getType(); + // Add this to the test coverage once we create test ops with none type + // results. + if (result_type.isa()) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result_type); + continue; + } + Type result_ele_type = + result.getType().cast().getElementType(); + // If the user is the QuantizeOp, it must be the only user. + if (result.hasOneUse() && + llvm::isa(*result.user_begin())) { + auto user = llvm::cast(*result.user_begin()); + outputs_replaced.insert( + {user.getResult(), enumerated_result.index()}); + output_types.push_back(user.getType()); + } else if (!result_ele_type.isF32()) { + // If the result is an integer tensor, then it doesn't require the + // D op in the pattern. + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else if (static_cast(this) + ->AllowDynamicRangeQuantizedResult(*quantizing_op, + custom_map)) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else { + return failure(); + } + } + + rewriter.setInsertionPointAfter(quantizing_op); + OperationState new_state(quantizing_op->getLoc(), + quantizing_op->getName().getStringRef(), inputs, + output_types, quantizing_op->getAttrs()); + for (int i = 0; i < quantizing_op->getNumRegions(); ++i) { + new_state.addRegion(); + } + Operation* quantized_op = rewriter.create(new_state); + if (quantizing_op->getNumRegions() != 0) { + for (const auto& indexed_regions : + llvm::enumerate(quantizing_op->getRegions())) { + Region& target_region = + quantized_op->getRegion(indexed_regions.index()); + IRMapping mapping; + indexed_regions.value().cloneInto(&target_region, mapping); + } + } + for (auto output : outputs_replaced) { + output.getFirst().replaceAllUsesWith( + quantized_op->getResult(output.getSecond())); + } + } + return success(); + } + + // Checks whether the operation is connnected with a quantized composite + // function. 
If not, the same-scale op will not be quantized. This decision is + // based on the current assumption that the performance gain of the same-scale + // op itself could not beat the overhead of the quantize and dequantize + // routines that need to be added around that op. When the assumption changes, + // this policy might change as well. + bool IsConnectedWithQuantizedCompsiteFunction( + Operation* same_scale_op) const { + for (const auto& operand : same_scale_op->getOperands()) { + auto dq_op = dyn_cast_or_null( + operand.getDefiningOp()); + if (!dq_op) continue; + + Operation* preceding_op = dq_op.getArg().getDefiningOp(); + if (!preceding_op) continue; + + // Check whether the preceding op is a quantized composite function. + if (llvm::isa(preceding_op)) { + auto call_op = llvm::cast(preceding_op); + if (!IsQuantizedCompositeFunction(call_op)) continue; + return true; + } + + // Check whether the preceding op is a quantized same-scale op. + if (GetStableHloQuantScaleSpec(preceding_op) + ->has_same_scale_requirement) { + for (auto result : preceding_op->getResults()) { + auto element_type = getElementTypeOrSelf(result.getType()); + if (element_type.isa()) { + return true; + } + } + } + } + + for (const auto& result : same_scale_op->getResults()) { + // If the user is the Quantize op, it must be the only user. + if (!result.hasOneUse() || + !llvm::isa(*result.user_begin())) { + continue; + } + + auto q_op = llvm::cast(*result.user_begin()); + for (auto following_op : q_op->getUsers()) { + // Check whether the following op is a quantized composite function. + if (llvm::isa(following_op)) { + auto call_op = llvm::cast(following_op); + if (!IsQuantizedCompositeFunction(call_op)) continue; + return true; + } + + // Check whether the following op is a quantized same-scale op. + if (GetStableHloQuantScaleSpec(following_op) + ->has_same_scale_requirement) { + for (auto operand : following_op->getOperands()) { + auto element_type = getElementTypeOrSelf(operand.getType()); + if (element_type.isa()) { + return true; + } + } + } + } + } + + return false; + } + + // Checks whether the op calls a composite function and all of its inputs and + // outputs are quantized. + bool IsQuantizedCompositeFunction(TF::XlaCallModuleOp call_op) const { + if (!call_op->hasAttr(kQuantTraitAttrName)) { + return false; + } + + const auto function_name = call_op->getAttrOfType( + TF::kStablehloEntryFunctionAttrName); + if (!function_name || !function_name.getValue().startswith("composite_")) { + return false; + } + + bool has_quantized_types = false; + for (Value input : call_op.getArgs()) { + if (auto type = input.getType().dyn_cast()) { + if (type.getElementType().isa()) { + return false; + } + if (type.getElementType().isa()) { + has_quantized_types = true; + } + } + } + for (Value output : call_op.getOutput()) { + if (auto type = output.getType().dyn_cast()) { + if (type.getElementType().isa()) { + return false; + } + if (type.getElementType().isa()) { + has_quantized_types = true; + } + } + } + return has_quantized_types; + } + + QuantPassSpec quant_params_; +}; + +// Gemm Style Op: glossary/gemm. +// Populates rewrite patterns that quantize fused gemm-style patterns +// (dot_general and convolution wrapped in quantized XlaCallModuleOps).
+void PopulateFusedGemmStylePatterns(MLIRContext& ctx, + RewritePatternSet& patterns); + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index 16e7ad1cfd7010..d629b26b4f0fc8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -20,10 +20,13 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project @@ -31,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h" namespace mlir::quant::stablehlo { @@ -42,31 +46,32 @@ namespace { // Base struct for quantization. 
template struct StableHloQuantizationBase - : public QuantizationPattern { + : public StableHloQuantizationPattern { explicit StableHloQuantizationBase(MLIRContext* ctx, const QuantPassSpec& quant_params) - : QuantizationPattern(ctx, quant_params) {} + : StableHloQuantizationPattern( + ctx, quant_params) {} - static bool IsQuantizableCustomOp(Operation* op, + static bool IsQuantizableCustomOp(Operation& op, const CustomMap& custom_op_map) { return false; } static bool AllowDynamicRangeQuantizedOperand( - Operation* quantized_op, const CustomMap& custom_op_map) { + Operation& quantized_op, const CustomMap& custom_op_map) { return false; } - static bool AllowDynamicRangeQuantizedResult(Operation* quantized_op, + static bool AllowDynamicRangeQuantizedResult(Operation& quantized_op, const CustomMap& custom_op_map) { return false; } - static bool IsWeightOnlyOp(Operation* quantized_op, + static bool IsWeightOnlyOp(Operation& quantized_op, absl::flat_hash_set& ops_blocklist, bool weight_only_quantization, const CustomMap& custom_op_map) { @@ -112,7 +117,7 @@ class QuantizePass : public impl::QuantizePassBase { }; void QuantizePass::runOnOperation() { - func::FuncOp func = getOperation(); + ModuleOp module_op = getOperation(); MLIRContext& ctx = getContext(); NumericVerifySpec numeric_verify_spec; @@ -125,20 +130,21 @@ void QuantizePass::runOnOperation() { RewritePatternSet patterns(&ctx); patterns.add( &ctx, quant_params); + PopulateFusedGemmStylePatterns(ctx, patterns); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { // There are cases where no rewrites happen even if a pattern matches, // causing this to result in a convergence failure. Consider this as a // best-effort. // TODO: b/305469508 - Make QuantizationPattern converge if there are no // patterns that are rewritable. - func.emitWarning("Failed to converge pattern at QuantizePass."); + module_op.emitWarning("Failed to converge pattern at QuantizePass."); } } } // namespace -std::unique_ptr> CreateQuantizePass( +std::unique_ptr> CreateQuantizePass( const QuantizationSpecs& quantization_specs) { return std::make_unique(quantization_specs); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index cf0c44f779a9ae..01230798bdcf8c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -20,9 +20,11 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -50,6 +52,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#define DEBUG_TYPE "quantize-composite-functions" + namespace mlir::quant::stablehlo { #define GEN_PASS_DEF_QUANTIZECOMPOSITEFUNCTIONSPASS @@ -58,7 +62,10 @@ namespace mlir::quant::stablehlo { namespace { using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod; +using ::mlir::stablehlo::AddOp; +using ::mlir::stablehlo::ConvolutionOp; using ::mlir::stablehlo::DotGeneralOp; +using ::mlir::stablehlo::DynamicBroadcastInDimOp; using ::mlir::stablehlo::UniformQuantizeOp; using ::tensorflow::quantization::RunPassesOnModuleOp; @@ -79,248 +86,6 @@ class QuantizeCompositeFunctionsPass void runOnOperation() override; }; -// Returns true if `type` is a TensorType with quantized elements. -bool IsQuantizedTensorType(const Type type) { - return type.isa() && - type.cast().getElementType().isa(); -} - -// Checks if all inputs and outputs are quantized. -bool HasQuantizedOperandOrOutput(Operation* call_op) { - SmallVector arg_types; - for (const Value arg : call_op->getOperands()) { - arg_types.push_back(arg.getType()); - } - - SmallVector output_types; - for (const Value output : call_op->getResults()) { - output_types.push_back(output.getType()); - } - - return absl::c_all_of(arg_types, IsQuantizedTensorType) && - absl::c_all_of(output_types, IsQuantizedTensorType); -} - -// Get the corresponding quantized function name from the given function name. -// Example: "composite_dot_general_fn_1" => "quantized_dot_general_fn" -std::string GetQuantizedFunctionName(const StringRef func_name) { - return Twine(kQuantizedFuncPrefix) - .concat(func_name.rsplit(kCompositeFuncPrefix).second) - .str(); -} - -// Returns true if `xla_call_module_op` is quantized. To be considered -// quantized, it should meet three conditions: -// 1. At least one of the inputs or outputs should be a uniform quantized type. -// 2. `xla_call_module_op` should have the `kQuantTraitAttrName` attribute. -// 3. It should also have the `kEntryFuncAttrName` attribute, which points to -// the function that `xla_call_module_op` represents. -bool IsQuantizedXlaCallModuleOp(TF::XlaCallModuleOp xla_call_module_op) { - return HasQuantizedOperandOrOutput(xla_call_module_op) && - xla_call_module_op->hasAttr(kQuantTraitAttrName) && - xla_call_module_op->hasAttr(kEntryFuncAttrName); -} - -// Returns the entry function, i.e. the callee of `xla_call_module_op`. -func::FuncOp GetEntryFuncOp(TF::XlaCallModuleOp xla_call_module_op, - SymbolTable symbol_table) { - auto entry_function_symbol_ref = - xla_call_module_op->getAttrOfType(kEntryFuncAttrName); - - // Don't match if there are no DotGeneralOp. - // if (target_func_op.getOps().empty()) return {}; - return dyn_cast_or_null( - symbol_table.lookup(entry_function_symbol_ref.getValue())); -} - -// Replaces the function type of `entry_func_op` to a quantized one, matching -// the input and output types of `xla_call_module_op`. 
-void SetQuantizedFunctionType(PatternRewriter& rewriter, - func::FuncOp entry_func_op, - TF::XlaCallModuleOp xla_call_module_op) { - SmallVector arg_types; - SmallVector arg_locs; - for (const Value arg : xla_call_module_op.getArgs()) { - arg_types.push_back(arg.getType()); - arg_locs.push_back(arg.getLoc()); - } - - SmallVector output_types; - for (const Value output : xla_call_module_op.getOutput()) { - output_types.push_back(output.getType()); - } - - entry_func_op.setFunctionType( - rewriter.getFunctionType(arg_types, output_types)); - - // Replace argument types and locs. - Block& entry = entry_func_op->getRegion(0).front(); - for (auto [arg, arg_type, arg_loc] : - llvm::zip_equal(entry.getArguments(), arg_types, arg_locs)) { - arg.setType(arg_type); - arg.setLoc(arg_loc); - } -} - -// An interface representing patterns that quantizes an entry function's body. -// The entry function's signatures should have already been quantized at the -// point of rewriting. -class EntryFuncBodyQuantizationPattern { - public: - virtual ~EntryFuncBodyQuantizationPattern() = default; - - // Returns `success()` if `entry_func_op`'s body is eligible for rewriting. At - // this point `entry_func_op`'s signature has not been reset with quantized - // types. - virtual LogicalResult match(func::FuncOp entry_func_op) const = 0; - - // Rewrites the `entry_func_op`'s body. - virtual void rewrite(func::FuncOp entry_func_op, - PatternRewriter& rewriter) const = 0; -}; - -// Quantizes the entry function's body containing a `DotGeneralOp`. -class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { - public: - explicit QuantizeDotGeneralOpPattern(MLIRContext& ctx) : ctx_(&ctx) {} - - LogicalResult match(func::FuncOp entry_func_op) const override { - auto& operations = entry_func_op.getBody().front().getOperations(); - return success(operations.size() == 2 && - isa(operations.front())); - } - - void rewrite(func::FuncOp entry_func_op, - PatternRewriter& rewriter) const override { - // Update the output type of the dot_general op. - auto dot_general_op = *entry_func_op.getOps().begin(); - - const Type input_type = entry_func_op.getArgumentTypes()[0]; - const Type rhs_type = entry_func_op.getArgumentTypes()[1]; - const Type func_result_type = entry_func_op.getResultTypes()[0]; - - const double input_scale = getElementTypeOrSelf(input_type) - .cast() - .getScale(); - const double rhs_scale = - getElementTypeOrSelf(rhs_type).cast().getScale(); - - // Define the intermediate output type, which is an i32 quantized type. - // This is intermediate because the final output type of the entry_func_op - // should be an i8 quantized type. - const UniformQuantizedType output_quantized_element_type = - CreateI32F32UniformQuantizedType(dot_general_op->getLoc(), *ctx_, - input_scale * rhs_scale, - /*zero_point=*/0); - - Value dot_general_op_result = dot_general_op->getResult(0); - const auto dot_general_op_result_type = - dot_general_op_result.getType().cast(); - const ArrayRef shape = dot_general_op_result_type.getShape(); - - const TensorType new_dot_general_op_result_type = - dot_general_op_result_type.cloneWith(shape, - output_quantized_element_type); - dot_general_op_result.setType(new_dot_general_op_result_type); - - // Add i32 -> i8 requantization. 
- rewriter.setInsertionPointAfter(dot_general_op); - auto uniform_quant_op = rewriter.create( - dot_general_op->getLoc(), func_result_type, - dot_general_op->getResults()); - - auto return_op = - cast(entry_func_op.getBody().front().getTerminator()); - return_op.setOperand(0, uniform_quant_op); - } - - private: - MLIRContext* ctx_ = nullptr; -}; - -// Converts `entry_func_op` to be quantized according to the respective -// inputs and outputs of `xla_call_module_op` that are possibly quantized. It -// signature (type) is reset to match that of `xla_call_module_op`. -// `entry_func_body_quantization_pattern` rewrites the function's body, based on -// the new signature. -void QuantizeEntryFuncOp( - MLIRContext& ctx, PatternRewriter& rewriter, - TF::XlaCallModuleOp xla_call_module_op, func::FuncOp entry_func_op, - const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { - SetQuantizedFunctionType(rewriter, entry_func_op, xla_call_module_op); - - body_rewrite_pattern.rewrite(entry_func_op, rewriter); - - // Rename the function to be clear that the function has been quantized. - const std::string quantized_function_name = - GetQuantizedFunctionName(entry_func_op.getSymName()); - entry_func_op.setSymName(quantized_function_name); -} - -// Replaces a quantized `xla_call_module_op` with a `func::CallOp`. The callee -// is expected to remain unquantized (thus having a signature mismatch), and it -// is also quantized accordingly. -void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( - MLIRContext& ctx, PatternRewriter& rewriter, - TF::XlaCallModuleOp xla_call_module_op, - const EntryFuncBodyQuantizationPattern& body_rewrite_pattern) { - auto module_op = xla_call_module_op->getParentOfType(); - SymbolTable symbol_table(module_op); - - func::FuncOp entry_func_op = GetEntryFuncOp(xla_call_module_op, symbol_table); - QuantizeEntryFuncOp(ctx, rewriter, xla_call_module_op, entry_func_op, - body_rewrite_pattern); - - // Replace the XlaCallModuleOp with a new CallOp. - rewriter.setInsertionPoint(xla_call_module_op); - rewriter.replaceOpWithNewOp(xla_call_module_op, entry_func_op, - xla_call_module_op.getArgs()); -} - -// Pattern that mainly does two things: -// -// 1. Replaces quantized `TF::XlaCallModuleOp` with a `func::CallOp`. -// 2. Quantizes the callee function. -// -// The inputs of this pattern assumes an invalid IR, where even if a -// `TF::XlaCallModuleOp` is quantized the callee remains unquantized. Step (2) -// not only replaces the input and output tensor types into quantized ones, but -// also rewrites the body with a quantized equivalent. -// -// `FuncBodyRewritePatternT` defines how a function body is quantized and -// rewritten. -template >> -class XlaCallModuleOpToCallOp : public OpRewritePattern { - public: - explicit XlaCallModuleOpToCallOp(MLIRContext& ctx) - : OpRewritePattern(&ctx) {} - - LogicalResult match(TF::XlaCallModuleOp op) const override { - auto module_op = op->getParentOfType(); - SymbolTable symbol_table(module_op); - - // Ignore unquantized ops. 
- if (!IsQuantizedXlaCallModuleOp(op)) return failure(); - - func::FuncOp entry_func_op = GetEntryFuncOp(op, symbol_table); - if (!entry_func_op) { - op->emitError("Failed to find a valid entry function."); - return failure(); - } - - return FuncBodyRewritePatternT(*getContext()).match(entry_func_op); - } - - void rewrite(TF::XlaCallModuleOp xla_call_module_op, - PatternRewriter& rewriter) const override { - ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( - *rewriter.getContext(), rewriter, xla_call_module_op, - FuncBodyRewritePatternT(*getContext())); - } -}; - void QuantizeCompositeFunctionsPass::runOnOperation() { MLIRContext& ctx = getContext(); @@ -334,7 +99,9 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { pm.enableVerifier(false); pm.addNestedPass(CreatePrepareQuantizePass()); - pm.addNestedPass(CreateQuantizePass(quant_specs)); + // QuantizePass modifies FuncOps referenced outside of its given scope + // and therefore requires a module-level context. + pm.addPass(CreateQuantizePass(quant_specs)); pm.addNestedPass(createPostQuantizePass()); ModuleOp module_op = getOperation(); @@ -343,14 +110,6 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { !pm_run_status.ok()) { signalPassFailure(); } - - // TODO - b/307839649: Move this as a separate pass. - RewritePatternSet patterns(&ctx); - patterns.add>(ctx); - - if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { - signalPassFailure(); - } } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc index 5bf8ba7ec07657..c870e7be4087a7 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc @@ -12,14 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project @@ -28,6 +31,8 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" @@ -41,6 +46,13 @@ namespace mlir::quant::stablehlo { namespace { constexpr StringRef kQuantizeTargetOpAttr = "tf_quant.composite_function"; +constexpr StringRef kStablehloModuleAttrsAttrName = "_stablehlo_module_attrs"; +constexpr StringRef kUsesShapePolymorphismAttr = "jax.uses_shape_polymorphism"; + +// Default version number for native serialization. +constexpr int64_t kDefaultVersion = 9; +// Default platform for XlaCallModuleOp. +constexpr StringRef kPlatformCpu = "CPU"; class ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass : public impl:: @@ -125,7 +137,7 @@ class LiveOuts { // Delete the current op from liveouts and moves on to the parent ops. void update(Operation& op) { for (Value result_value : op.getResults()) { - liveouts_.erase(result_value); + liveouts_.remove(result_value); } for (Value operand : op.getOperands()) { liveouts_.insert(operand); @@ -136,19 +148,20 @@ class LiveOuts { void snapshot_previous_state() { prev_liveouts_ = liveouts_; } // Return the current live values. - const DenseSet& get() const { return liveouts_; } + const SetVector& get() const { return liveouts_; } // Return the previous live values. - const DenseSet& get_previous() const { return prev_liveouts_; } + const SetVector& get_previous() const { return prev_liveouts_; } private: - DenseSet liveouts_; - DenseSet prev_liveouts_; + // Use SerVector to ensure deterministic traversal order. + SetVector liveouts_; + SetVector prev_liveouts_; }; // Creates the tf.XlaCallModuleOp from attributes. -void CreateXlaCallModuleOp(ArrayRef inputs, ArrayRef outputs, - ArrayRef result_types, +void CreateXlaCallModuleOp(ValueRange inputs, ValueRange outputs, + TypeRange result_types, ArrayRef reverse_subgraph, func::FuncOp stablehlo_func_op, ModuleOp module_op) { MLIRContext* ctx = module_op.getContext(); @@ -163,19 +176,26 @@ void CreateXlaCallModuleOp(ArrayRef inputs, ArrayRef outputs, tf_type::ShapeAttr::get(ctx, result_type.cast())); } auto empty_array_attr = ArrayAttr::get(ctx, {}); - // TODO - b/303363466: Allow XlaCallModuleOp with versions >5. + // TODO - b/310291615: Support platforms = ["TPU"]. + auto platforms = ArrayAttr::get(ctx, {StringAttr::get(ctx, kPlatformCpu)}); + auto xla_call_module_op = builder.create( module_op.getLoc(), /*output=*/result_types, /*args=*/inputs, - /*version=*/5, /*module=*/"", + /*version=*/kDefaultVersion, /*module=*/"", /*Sout=*/ArrayAttr::get(ctx, shape_attrs), - /*dim_args_spec=*/empty_array_attr, - /*platforms=*/empty_array_attr, + /*dim_args_spec=*/empty_array_attr, platforms, /*function_list=*/empty_array_attr, /*has_token_input_output=*/false, /*disabled_checks=*/empty_array_attr); xla_call_module_op->setAttr(TF::kStablehloEntryFunctionAttrName, SymbolRefAttr::get(stablehlo_func_op)); + // Set jax.uses_shape_polymorphism=true to enable shape refinement at runtime. + // This is needed for native serialization version >= 8. 
+  xla_call_module_op->setAttr(
+      kStablehloModuleAttrsAttrName,
+      builder.getDictionaryAttr(builder.getNamedAttr(
+          kUsesShapePolymorphismAttr, builder.getBoolAttr(true))));
  for (auto [original_output_value, xla_call_module_op_result_value] :
       llvm::zip_equal(outputs, xla_call_module_op->getResults())) {
@@ -251,18 +271,18 @@ void ReplaceStablehloOpsWithXlaCallModuleOp(
// Contains the actual logic for updating states and replacing StableHLO ops
// with tf.XlaCallModuleOps.
void UpdateStatesAndReplaceStablehloOps(
-    const DenseSet& operands, const DenseSet& defined_values,
+    const SetVector& operands, const SetVector& defined_values,
    const LiveOuts& liveouts, ModuleOp module_op,
    ArrayRef reverse_subgraph, const int stablehlo_func_id,
    func::FuncOp main_func, const bool is_last_subgraph = false) {
-  DenseSet inputs = operands;
+  SetVector inputs = operands;
  for (Value defined_value : defined_values) {
-    inputs.erase(defined_value);
+    inputs.remove(defined_value);
  }
-  DenseSet outputs = liveouts.get_previous();
+  SetVector outputs = liveouts.get_previous();
  for (Value live_value : liveouts.get()) {
-    outputs.erase(live_value);
+    outputs.remove(live_value);
  }
  if (is_last_subgraph) {
@@ -270,7 +290,7 @@ void UpdateStatesAndReplaceStablehloOps(
    // throughout (functions as an invisible op above the very first op that
    // returns the arguments).
    for (const BlockArgument arg : main_func.getArguments()) {
-      outputs.erase(arg);
+      outputs.remove(arg);
    }
  }
@@ -298,20 +318,65 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOps(
  // statement is not included in any subgraph (e.g. XlaCallModuleOp) and is
  // untouched.
  SmallVector reverse_main_func_block_ops;
+  SetVector ops_to_add;
  for (Operation& main_func_block_op :
       llvm::reverse(main_func_block.without_terminator())) {
    reverse_main_func_block_ops.push_back(&main_func_block_op);
+    ops_to_add.insert(&main_func_block_op);
  }
  // Create a separate subgraph invoked with XlaCallModuleOp per each
  // set of StableHLO ops in the main func block.
  SmallVector reverse_subgraph;
-  DenseSet operands;
-  DenseSet defined_values;
+  SetVector operands;
+  SetVector defined_values;
+
+  // Adds `op` to the current subgraph.
+  auto add_to_subgraph = [&](Operation* op) {
+    // Move on to the parent ops.
+    liveouts.update(*op);
+    ops_to_add.remove(op);
+
+    if (!IsStablehloOp(op)) {
+      // Always update the liveouts when the subgraph isn't being continued.
+      liveouts.snapshot_previous_state();
+      return;
+    }
+
+    reverse_subgraph.push_back(op);
+    defined_values.insert(op->getResults().begin(), op->getResults().end());
+    operands.insert(op->getOperands().begin(), op->getOperands().end());
+  };
  int stablehlo_func_id = -1;
  for (Operation* op : reverse_main_func_block_ops) {
+    if (!ops_to_add.contains(op)) continue;
+    // When hitting a non-StableHLO op, e.g. tf.CustomAggregatorOp, start
+    // recursively tracing the defining ops of the current subgraph's operands.
+    // This makes sure that all dependencies needed for shape inference are
+    // included in the subgraph. Tracing stops when hitting a non-StableHLO op
+    // or an op with multiple uses; in the latter case we have to stop because
+    // otherwise the other users of that op would become dangling references.
+    // TODO: b/311239049 - Consider rewriting this using BFS.
    if (!IsStablehloOp(op)) {
+      bool should_add_op = true;
+      while (should_add_op) {
+        should_add_op = false;
+        Operation* defining_op = nullptr;
+        for (Value v : operands) {
+          if (defined_values.contains(v)) continue;
+          // Check if the op has a branch and skip it if so.
+          if (v.getDefiningOp() && IsStablehloOp(v.getDefiningOp()) &&
+              v.getDefiningOp()->hasOneUse()) {
+            defining_op = v.getDefiningOp();
+            should_add_op = true;
+            break;
+          }
+        }
+        if (should_add_op) {
+          add_to_subgraph(defining_op);
+        }
+      }
      // Create an XlaCallModuleOp if reverse_subgraph isn't empty.
      if (!reverse_subgraph.empty()) {
        UpdateStatesAndReplaceStablehloOps(operands, defined_values, liveouts,
@@ -324,20 +389,7 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOps(
        defined_values.clear();
      }
    }
-
-    // Move on to the parent ops.
-    liveouts.update(*op);
-
-    if (!IsStablehloOp(op)) {
-      // Always update the liveouts when the subgraph isn't being continued.
-      liveouts.snapshot_previous_state();
-      continue;
-    }
-
-    reverse_subgraph.push_back(op);
-
-    defined_values.insert(op->getResults().begin(), op->getResults().end());
-    operands.insert(op->getOperands().begin(), op->getOperands().end());
+    add_to_subgraph(op);
  }
  // Create the last subgraph if it isn't empty.
@@ -348,6 +400,37 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOps(
  }
}
+// Duplicates small constants for each use.
+//
+// In the subsequent graph partitioning, constants for shape inference need to
+// be in the same subgraph. But graph partitioning stops at ops with multiple
+// uses. So here we duplicate small constants for each use so that if a
+// constant is useful for shape inference in multiple subgraphs, it can be
+// included in each subgraph. If duplicate constants are accidentally created
+// in the same subgraph, they can be easily removed with a canonicalizer pass.
+//
+// We set a size limit since constants needed for shape inference are no
+// larger than the tensor rank. This avoids duplicating large constants.
+void DuplicateSmallConstantOps(ModuleOp module_op, func::FuncOp main_func) {
+  OpBuilder builder(main_func.getContext());
+  for (auto constant_op :
+       main_func.getBody().getOps()) {
+    builder.setInsertionPointAfter(constant_op);
+    if (constant_op.getResult().use_empty() ||
+        constant_op.getResult().hasOneUse())
+      continue;
+    // Do not duplicate the constant op if its size is too large.
+    // 32 is chosen to be larger than all constants useful for shape inference,
+    // while small enough not to significantly increase the model size.
+    if (constant_op.getValue().getNumElements() > 32) continue;
+    while (!constant_op.getResult().hasOneUse()) {
+      auto new_constant_op = builder.clone(*constant_op.getOperation());
+      constant_op.getResult().getUses().begin()->assign(
+          dyn_cast(new_constant_op));
+    }
+  }
+}
+
void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass::
    runOnOperation() {
  ModuleOp module_op = getOperation();
@@ -355,14 +438,15 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass::
  func::FuncOp main_func = GetMainFunc(module_op);
  if (!main_func) return;
+  DuplicateSmallConstantOps(module_op, main_func);
  ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOps(module_op, main_func);
  // TODO - b/298966126: Currently quantizable functions are identified in TF
-  // Quantizer via the tf_quant.composite_function UnitAttr attached to func
-  // ops. We remove this attribute as this interferes with VHLO conversion.
+  // Quantizer via the tf_quant.composite_function UnitAttr attached to
+  // func ops. We remove this attribute as this interferes with VHLO conversion.
  // Remove this temporary hack.
for (auto func_op : module_op.getOps()) { - func_op->removeAttr(kQuantizeTargetOpAttr); + func_op->removeAttr(kFusedFunctionAttr); } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/restore_function_name.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/restore_function_name.cc index 545d36b625b532..57b6a2a07a04d1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/restore_function_name.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/restore_function_name.cc @@ -24,7 +24,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" @@ -54,12 +54,12 @@ class RestoreFunctionNamePass void RestoreFunctionNameFromXlaCallModuleOp(TF::XlaCallModuleOp& call_op, SymbolTable& symbol_table) { - if (!call_op->hasAttr(mlir::quant::kOriginalStablehloEntryFunctionAttrName)) { + if (!call_op->hasAttr(kOriginalStablehloEntryFunctionAttrName)) { return; } auto original_function_name = call_op->getAttrOfType( - mlir::quant::kOriginalStablehloEntryFunctionAttrName); + kOriginalStablehloEntryFunctionAttrName); auto current_function_name = call_op->getAttrOfType( TF::kStablehloEntryFunctionAttrName); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/unwrap_xla_call_module_op.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/unwrap_xla_call_module_op.cc new file mode 100644 index 00000000000000..a65694a7a7287f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/unwrap_xla_call_module_op.cc @@ -0,0 +1,122 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" + +namespace mlir::quant::stablehlo { + +#define GEN_PASS_DEF_UNWRAPXLACALLMODULEOPPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" + +namespace { + +// Unwraps XlaCallModule ops without quantizable trait that call function with +// '_from_xla_call_module' trait. +class UnwrapXlaCallModuleOpPass + : public impl::UnwrapXlaCallModuleOpPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnwrapXlaCallModuleOpPass) + + explicit UnwrapXlaCallModuleOpPass() = default; + + private: + void runOnOperation() override; +}; + +void UnwrapXlaCallModuleOp(TF::XlaCallModuleOp call_op, + SymbolTable& symbol_table) { + // Do not inline lifted quantized functions used for fusing patterns. + // TODO - b/310539922: Remove reference to TF/TFL utils. + if (call_op->hasAttr(kQuantTraitAttrName)) { + return; + } + + auto function_name = call_op + ->getAttrOfType( + TF::kStablehloEntryFunctionAttrName) + .getValue(); + func::FuncOp func_op = symbol_table.lookup(function_name); + + // We should not unwrap if the function is not from + // ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass. 
+ if (!func_op->hasAttr(TF::kFromXlaCallModuleAttrName)) { + return; + } + + MLIRContext* context = call_op.getContext(); + OpBuilder builder(context); + builder.setInsertionPointAfter(call_op); + + IRMapping arg_mapper; + for (auto [func_arg, operand] : + llvm::zip_equal(func_op.getArguments(), call_op.getOperands())) { + arg_mapper.map(func_arg, operand); + } + + Region& function_body = func_op.getBody(); + IRMapping new_op_mapper; + for (Operation& op : function_body.getOps()) { + if (llvm::isa(op)) { + for (auto [call_result, return_value] : + llvm::zip_equal(call_op.getResults(), op.getOperands())) { + Value new_result = new_op_mapper.lookup(return_value); + + call_result.replaceAllUsesWith(new_result); + } + continue; + } + + Operation& new_op = *builder.clone(op, arg_mapper); + for (auto [result, new_result] : + llvm::zip_equal(op.getResults(), new_op.getResults())) { + new_op_mapper.map(result, new_result); + } + } + + call_op.erase(); +} + +void UnwrapXlaCallModuleOpPass::runOnOperation() { + ModuleOp module_op = getOperation(); + SymbolTable symbol_table(module_op); + + for (auto func_op : module_op.getOps()) { + Region& function_body = func_op.getBody(); + + function_body.walk([&](TF::XlaCallModuleOp call_op) { + UnwrapXlaCallModuleOp(call_op, symbol_table); + }); + } +} + +} // namespace + +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/utils.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/utils.td deleted file mode 100644 index 744637d58d8760..00000000000000 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/utils.td +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2023 The StableHLO Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -//===----------------------------------------------------------------------===// -// Helper functions. -//===----------------------------------------------------------------------===// - -// Checks whether the value of a constant equals the given float, regardless -// of the tensor dimension. -class FloatValueEquals : Constraint>; - -// Fetches the default or null attribute, used for pattern matching. -def DefaultOrNullAttr : NativeCodeCall<"DefaultOrNullAttr($_builder, $0)">; - -// Returns true if the given op is a StableHLO constant op. 
-def IsStableHLOConstantOp : Constraint($0.getDefiningOp())">>; - diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD new file mode 100644 index 00000000000000..00503b7ce45a0d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -0,0 +1,110 @@ +load("//tensorflow:pytype.default.bzl", "pytype_strict_library") +load( + "//tensorflow:tensorflow.default.bzl", + "tf_py_strict_test", + "tf_python_pybind_extension", +) +load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allowlist.bzl", "internal_visibility_allowlist") + +package_group( + name = "internal_visibility_allowlist_package", + packages = [ + "//tensorflow/compiler/mlir/lite/...", + "//tensorflow/compiler/mlir/quantization/...", + "//tensorflow/compiler/mlir/tf2xla/transforms/...", + "//tensorflow/lite/...", + "//third_party/cloud_tpu/inference_converter/...", # TPU Inference Converter V1 + ] + internal_visibility_allowlist(), +) + +package( + # copybara:uncomment default_applicable_licenses = ["@stablehlo//:license"], + default_visibility = [ + ":internal_visibility_allowlist_package", + "//tensorflow:__pkg__", + ], + licenses = ["notice"], +) + +pytype_strict_library( + name = "quantization", + srcs = ["quantization.py"], + deps = [ + ":pywrap_quantization", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib_py", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:representative_dataset", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:save_model", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/saved_model:loader", + ], +) + +pytype_strict_library( + name = "quantize_model_test_base", + testonly = 1, + srcs = ["integration_test/quantize_model_test_base.py"], + tags = ["no_pip"], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/module", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/saved_model:save", + "//tensorflow/python/types:core", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +tf_py_strict_test( + name = "quantize_model_test", + srcs = ["integration_test/quantize_model_test.py"], + deps = [ + ":quantization", + ":quantize_model_test_base", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:representative_dataset", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/saved_model:load", + "//tensorflow/python/saved_model:tag_constants", + "@absl_py//absl/testing:parameterized", + ], +) + +tf_python_pybind_extension( + name = "pywrap_quantization", + srcs = ["pywrap_quantization.cc"], + pytype_srcs = ["pywrap_quantization.pyi"], + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:debugger", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", + 
"//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:assign_ids", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:statistics", + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:type_casters", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@pybind11", + "@pybind11_abseil//pybind11_abseil:absl_casters", + "@pybind11_abseil//pybind11_abseil:import_status_module", + "@pybind11_abseil//pybind11_abseil:status_casters", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py new file mode 100644 index 00000000000000..a59d4a988c9b79 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -0,0 +1,349 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import itertools +from typing import Optional, Sequence + +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.mlir.quantization.stablehlo.python import quantization +from tensorflow.compiler.mlir.quantization.stablehlo.python.integration_test import quantize_model_test_base +from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test +from tensorflow.python.saved_model import load +from tensorflow.python.saved_model import tag_constants + +# Type aliases for quantization method protobuf enums. +_PresetMethod = quant_opts_pb2.QuantizationMethod.PresetMethod + + +def parameter_combinations(test_parameters): + """Generate all combinations of test parameters.""" + real_parameters = [] + for parameters in test_parameters: + keys = parameters.keys() + for curr in itertools.product(*parameters.values()): + real_parameters.append(dict(zip(keys, curr))) + return real_parameters + + +# Test cases for Static Range Quantization. 
+# Tries to run all test cases in both the graph mode (default in TF1) and the
+# eager mode (default in TF2) to ensure support for when TF2 is disabled.
+class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest):
+
+  @parameterized.parameters(
+      parameter_combinations([{
+          'activation_fn': [None],
+          'has_bias': [True, False],
+          'dim_sizes': [
+              # tf.MatMul cases.
+              ([None, 1024], [1024, 3]),  # dynamic batch dim.
+              ([1, 1024], [1024, 3]),
+              # tf.BatchMatMul cases.
+              ([10, 1, 1024], [10, 1024, 3]),
+              ([2, 3, 1, 1024], [2, 3, 1024, 3]),
+          ],
+      }])
+  )
+  @test_util.run_in_graph_and_eager_modes
+  def test_matmul_ptq_model(
+      self,
+      activation_fn: Optional[ops.Operation],
+      has_bias: bool,
+      dim_sizes: Sequence[int],
+  ):
+    target_opset = quant_opts_pb2.STABLEHLO
+
+    lhs_dim_size, rhs_dim_size = dim_sizes
+    input_shape = (*lhs_dim_size,)
+    filter_shape = (*rhs_dim_size,)
+    static_input_shape = [dim if dim is not None else 2 for dim in input_shape]
+    model = self._create_matmul_model(
+        input_shape,
+        filter_shape,
+        self._input_saved_model_path,
+        has_bias,
+        activation_fn,
+    )
+
+    rng = np.random.default_rng(seed=1235)
+    input_data = ops.convert_to_tensor(
+        rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype(
+            np.float32
+        )
+    )
+
+    def data_gen() -> repr_dataset.RepresentativeDataset:
+      for _ in range(100):
+        yield {
+            'input_tensor': rng.uniform(
+                low=0.0, high=1.0, size=static_input_shape
+            ).astype(np.float32)
+        }
+
+    dataset_path = self.create_tempfile('tfrecord').full_path
+    path_map = {'serving_default': dataset_path}
+    repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save(
+        {'serving_default': data_gen()}
+    )
+
+    config = quant_opts_pb2.QuantizationOptions(
+        quantization_method=quant_opts_pb2.QuantizationMethod(
+            preset_method=_PresetMethod.METHOD_STATIC_RANGE_INT8
+        ),
+        tags={tag_constants.SERVING},
+        signature_keys=['serving_default'],
+        op_set=target_opset,
+        representative_datasets={
+            'serving_default': quant_opts_pb2.RepresentativeDatasetFile(
+                tfrecord_file_path=dataset_path
+            )
+        },
+        calibration_options=quant_opts_pb2.CalibrationOptions(
+            calibration_method=quant_opts_pb2.CalibrationOptions.CALIBRATION_METHOD_MIN_MAX
+        ),
+    )
+    quantization.quantize_saved_model(
+        self._input_saved_model_path,
+        self._output_saved_model_path,
+        config,
+    )
+
+    expected_outputs = model.matmul(input_data)
+
+    root = load.load(self._output_saved_model_path)
+    self.assertCountEqual(root.signatures.keys(), {'serving_default'})
+
+    new_outputs = root.signatures['serving_default'](
+        input_tensor=ops.convert_to_tensor(input_data)
+    )
+    # Tests that the quantized graph outputs similar values. The atol value is
+    # arbitrary.
+    # TODO: b/309674337 - Fix the large numerical errors.
+ self.assertAllClose(new_outputs, expected_outputs, atol=0.3) + + @parameterized.parameters( + parameter_combinations([{ + 'same_scale_op': [ + 'concatenate', + 'gather', + 'pad', + 'reshape', + 'select', + 'slice', + 'transpose', + ], + }]) + ) + @test_util.run_in_graph_and_eager_modes + def test_matmul_and_same_scale_ptq_model( + self, + same_scale_op: str, + ): + target_opset = quant_opts_pb2.STABLEHLO + + input_shape = (2, 3, 1, 1024) + filter_shape = (2, 3, 1024, 3) + static_input_shape = [dim if dim is not None else 2 for dim in input_shape] + + model = self._create_matmul_and_same_scale_model( + input_shape, + filter_shape, + self._input_saved_model_path, + same_scale_op, + ) + + rng = np.random.default_rng(seed=1235) + input_data = ops.convert_to_tensor( + rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( + np.float32 + ) + ) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(100): + yield { + 'input_tensor': rng.uniform( + low=0.0, high=1.0, size=static_input_shape + ).astype(np.float32) + } + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + config = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + preset_method=_PresetMethod.METHOD_STATIC_RANGE_INT8 + ), + tags={tag_constants.SERVING}, + signature_keys=['serving_default'], + op_set=target_opset, + representative_datasets={ + 'serving_default': quant_opts_pb2.RepresentativeDatasetFile( + tfrecord_file_path=dataset_path + ) + }, + calibration_options=quant_opts_pb2.CalibrationOptions( + calibration_method=quant_opts_pb2.CalibrationOptions.CALIBRATION_METHOD_MIN_MAX + ), + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + expected_outputs = model.matmul_and_same_scale(input_data) + + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + + new_outputs = root.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + # Tests that the quantized graph outputs similar values. The rtol value is + # arbitrary. + # TODO: b/309674337 - Fix the large numerical errors. + self.assertAllClose(new_outputs, expected_outputs, rtol=0.3) + + @parameterized.named_parameters( + { + 'testcase_name': 'none', + 'activation_fn': None, + 'has_bias': False, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.STABLEHLO, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + ) + @test_util.run_in_graph_and_eager_modes + def test_conv_ptq_model( + self, + activation_fn: Optional[ops.Operation], + has_bias: bool, + has_batch_norm: bool, + target_opset: quant_opts_pb2.OpSet, + input_shape_dynamic: bool, + enable_per_channel_quantization: bool, + dilations: Sequence[int] = None, + ): + input_shape = (None, None, None, 3) if input_shape_dynamic else (1, 3, 4, 3) + filter_shape = (2, 3, 3, 2) + strides = (1, 1, 1, 1) + model = self._create_conv2d_model( + input_shape, + filter_shape, + self._input_saved_model_path, + has_bias, + has_batch_norm, + activation_fn, + strides, + dilations, + ) + + # Generate model input data. 
+ rng = np.random.default_rng(seed=1224) + static_input_shape = [dim if dim is not None else 2 for dim in input_shape] + input_data = ops.convert_to_tensor( + rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( + np.float32 + ) + ) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(100): + yield { + 'input_tensor': rng.uniform( + low=0.0, high=1.0, size=static_input_shape + ).astype(np.float32) + } + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + tags = {tag_constants.SERVING} + + config = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + preset_method=_PresetMethod.METHOD_STATIC_RANGE_INT8 + ), + tags=tags, + signature_keys=['serving_default'], + op_set=target_opset, + representative_datasets={ + 'serving_default': quant_opts_pb2.RepresentativeDatasetFile( + tfrecord_file_path=dataset_path + ) + }, + enable_per_channel_quantization=enable_per_channel_quantization, + calibration_options=quant_opts_pb2.CalibrationOptions( + calibration_method=quant_opts_pb2.CalibrationOptions.CALIBRATION_METHOD_MIN_MAX + ), + ) + + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + expected_outputs = model.conv2d(input_data) + + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + + new_outputs = root.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + # Tests that the quantized graph outputs similar values. The rtol value is + # arbitrary. + self.assertAllClose(new_outputs, expected_outputs, rtol=0.04) + + def test_when_preset_not_srq_raise_error(self): + self._create_matmul_model( + input_shape=(1, 1024), + weight_shape=(1024, 3), + saved_model_path=self._input_saved_model_path, + ) + + config = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + preset_method=_PresetMethod.METHOD_NO_QUANTIZE + ), + tags={tag_constants.SERVING}, + signature_keys=['serving_default'], + op_set=quant_opts_pb2.STABLEHLO, + ) + + with self.assertRaisesRegex(ValueError, 'only supports static-range PTQ'): + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py new file mode 100644 index 00000000000000..86f7fadb671e1e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py @@ -0,0 +1,299 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Base test class for quantize_model tests."""
+from typing import Mapping, Sequence, Optional
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow  # pylint: disable=unused-import
+
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.module import module
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import save as saved_model_save
+from tensorflow.python.types import core
+
+
+class QuantizedModelTest(test.TestCase, parameterized.TestCase):
+  """Base test class for StableHLO quant tests."""
+
+  def setUp(self) -> None:
+    super().setUp()
+
+    # Many test cases for quantization involve creating and saving the input
+    # model and saving the output quantized model. These two member
+    # attributes can be used to specify the paths for such models,
+    # respectively. These paths will be cleaned up after each test case.
+    self._input_saved_model_path = self.create_tempdir('input').full_path
+    self._output_saved_model_path = self.create_tempdir('output').full_path
+    # Extra output path occasionally used for comparing two different
+    # quantized models.
+    self._output_saved_model_path_2 = self.create_tempdir('output2').full_path
+
+  def _create_matmul_model(
+      self,
+      input_shape: Sequence[int],
+      weight_shape: Sequence[int],
+      saved_model_path: str,
+      has_bias: bool = False,
+      activation_fn: Optional[ops.Operation] = None,
+      bias_size: Optional[int] = None,
+      use_biasadd: bool = True,
+  ) -> module.Module:
+    class MatmulModel(module.Module):
+      """A simple model with a single matmul.
+
+      Bias and activation function are optional.
+      """
+
+      def __init__(
+          self,
+          weight_shape: Sequence[int],
+          bias_size: Optional[int] = None,
+          activation_fn: Optional[ops.Operation] = None,
+          use_biasadd: bool = True,
+      ) -> None:
+        """Initializes a MatmulModel.
+
+        Args:
+          weight_shape: Shape of the weight tensor.
+          bias_size: If None, do not use bias. Else, use given size as bias.
+          activation_fn: The activation function to be used. No activation
+            function if None.
+          use_biasadd: If True, use BiasAdd for adding bias, else use AddV2.
+        """
+        self.bias_size = bias_size
+        self.activation_fn = activation_fn
+        self.use_biasadd = use_biasadd
+        self.filters = np.random.uniform(low=-1.0, high=1.0, size=weight_shape)
+
+        if bias_size is not None:
+          self.bias = np.random.uniform(low=-1.0, high=1.0, size=bias_size)
+
+      def has_bias(self) -> bool:
+        return self.bias_size is not None
+
+      def has_reshape(self) -> bool:
+        return self.has_bias() and self.bias_size != self.filters.shape[-1]
+
+      @def_function.function
+      def matmul(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]:
+        """Performs a matrix multiplication.
+
+        Depending on self.has_bias and self.activation_fn, it may add a bias
+        term or go through the activation function.
+
+        Args:
+          input_tensor: Input tensor to matmul with the filter.
+
+        Returns:
+          A map of: output key -> output result.
+        """
+        out = math_ops.matmul(input_tensor, self.filters, name='sample/matmul')
+
+        return {'output': out}
+
+    # If bias_size is not explicitly given, it should default to width of weight.
+ if bias_size is None and has_bias: + bias_size = weight_shape[-1] + + # Verify that when bias_size is not None, has_bias should be True. + # And if bias_size is None, has_bias should be False. + assert (bias_size is None) != has_bias + + model = MatmulModel(weight_shape, bias_size, activation_fn) + saved_model_save.save( + model, + saved_model_path, + signatures=model.matmul.get_concrete_function( + tensor_spec.TensorSpec( + shape=input_shape, dtype=dtypes.float32, name='input_tensor' + ) + ), + ) + return model + + def _create_matmul_and_same_scale_model( + self, + input_shape: Sequence[int], + weight_shape: Sequence[int], + saved_model_path: str, + same_scale_op: str, + ) -> module.Module: + class MatmulAndSameScaleModel(module.Module): + """A simple model with a same-scale op. + + Op name in StableHLO dialect is given as a string. + """ + + def __init__( + self, + weight_shape: Sequence[int], + same_scale_op: str, + ) -> None: + """Initializes a MatmulModel. + + Args: + weight_shape: Shape of the weight tensor. + same_scale_op: Name of the same-scale op to be tested. Raises error + when an unknown name is given. + """ + self.filters = np.random.uniform(low=-1.0, high=1.0, size=weight_shape) + self.same_scale_op = same_scale_op + + @def_function.function + def matmul_and_same_scale( + self, input_tensor: core.Tensor + ) -> Mapping[str, core.Tensor]: + """Performs a matrix multiplication. + + Args: + input_tensor: Input tensor to matmul with the filter. + + Returns: + A map of: output key -> output result. + """ + out = math_ops.matmul(input_tensor, self.filters, name='sample/matmul') + + if self.same_scale_op == 'concatenate': + ones = array_ops.ones_like(out) + out = array_ops.concat([out, ones], 0) + elif self.same_scale_op == 'gather': + out = array_ops.gather(out, indices=[0], axis=0) + elif self.same_scale_op == 'pad': + paddings = array_ops.ones( + (array_ops.rank(out), 2), dtype=dtypes.int32 + ) + out = array_ops.pad(out, paddings, 'CONSTANT') + elif self.same_scale_op == 'reshape': + out = array_ops.reshape(out, (array_ops.size(out), -1)) + elif self.same_scale_op == 'select': + rng = np.random.default_rng(seed=1234) + condition = ops.convert_to_tensor( + rng.uniform(low=0.0, high=1.0, size=out.shape) < 0.5 + ) + ones = array_ops.ones_like(out) + out = math_ops.select(condition, out, ones) + elif self.same_scale_op == 'slice': + begin = array_ops.zeros( + (array_ops.rank(out)), dtype=dtypes.int32 + ) + size = array_ops.ones( + (array_ops.rank(out)), dtype=dtypes.int32 + ) + out = array_ops.slice(out, begin, size) + elif self.same_scale_op == 'transpose': + out = array_ops.transpose(out) + else: + raise NotImplementedError( + '{} is not implemented for integration test.'.format( + self.same_scale_op + ) + ) + + return {'output': out} + + model = MatmulAndSameScaleModel(weight_shape, same_scale_op) + saved_model_save.save( + model, + saved_model_path, + signatures=model.matmul_and_same_scale.get_concrete_function( + tensor_spec.TensorSpec( + shape=input_shape, dtype=dtypes.float32, name='input_tensor' + ) + ), + ) + return model + + def _create_conv2d_model( + self, + input_shape: Sequence[int], + filter_shape: Sequence[int], + saved_model_path: str, + has_bias: bool = False, + has_batch_norm: bool = False, + activation_fn: Optional[ops.Operation] = None, + strides: Sequence[int] = (1, 1, 1, 1), + dilations: Sequence[int] = (1, 1, 1, 1), + padding: str = 'SAME', + ) -> module.Module: + class ConvModel(module.Module): + """A simple model with a single conv2d, bias and relu.""" + + 
def __init__(self): + self.out_channel_size = filter_shape[-1] + + # This ensures filters will have different value range per out channel + self.filters = np.stack( + [ + np.random.uniform( + low=-(i + 1), high=(i + 1), size=filter_shape[:-1] + ).astype('f4') + for i in range(self.out_channel_size) + ], + axis=-1, + ) + + self.bias = np.random.uniform( + low=0, high=10, size=(self.out_channel_size) + ).astype('f4') + + @def_function.function + def conv2d(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: + """Performs a 2D convolution operation. + + Args: + input_tensor: Input tensor to perform convolution on. + + Returns: + A map of: output key -> output result. + """ + scale = [1.0] * self.out_channel_size + offset = [0.5] * self.out_channel_size + mean, variance = scale, offset + out = nn_ops.conv2d( + input_tensor, + self.filters, + strides=strides, + dilations=dilations, + padding=padding, + data_format='NHWC', + name='sample/conv', + ) + if has_batch_norm: + # Fusing is supported for non-training case. + out, _, _, _, _, _ = nn_ops.fused_batch_norm_v3( + out, scale, offset, mean, variance, is_training=False + ) + return {'output': out} + + model = ConvModel() + saved_model_save.save( + model, + saved_model_path, + signatures=model.conv2d.get_concrete_function( + tensor_spec.TensorSpec( + shape=input_shape, dtype=dtypes.float32, name='input_tensor' + ) + ), + ) + return model diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc new file mode 100644 index 00000000000000..08fac877481157 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc @@ -0,0 +1,187 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "pybind11/cast.h" // from @pybind11 +#include "pybind11/detail/common.h" // from @pybind11 +#include "pybind11/pybind11.h" // from @pybind11 +#include "pybind11/pytypes.h" // from @pybind11 +#include "pybind11/stl.h" // from @pybind11 // IWYU pragma: keep +#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // IWYU pragma: keep +#include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil +#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace py = pybind11; + +namespace { + +using ::stablehlo::quantization::AddCalibrationStatistics; +using ::stablehlo::quantization::AssignIdsToCustomAggregatorOps; +using ::stablehlo::quantization::EnableDebugging; +using ::stablehlo::quantization::io::CreateTmpDir; +using ::tensorflow::SignatureDef; +using ::tensorflow::quantization::ExportedModel; +using ::tensorflow::quantization::PyFunctionLibrary; +using ::tensorflow::quantization::QuantizationOptions; + +} // namespace + +PYBIND11_MODULE(pywrap_quantization, m) { + // Supports absl::Status type conversions. + pybind11::google::ImportStatusModule(); + + m.doc() = "StableHLO Quantization APIs."; + + m.def( + // If the function signature changes, likely its corresponding .pyi type + // hinting should also change. 
+ // LINT.IfChange + "static_range_ptq", + [](const absl::string_view src_saved_model_path, + const absl::string_view dst_saved_model_path, + const QuantizationOptions& quantization_options, + const std::vector& signature_keys, + const absl::flat_hash_map& + signature_def_map, + const absl::flat_hash_map& function_aliases, + const PyFunctionLibrary& py_function_library, + py::object representative_dataset) -> absl::Status { + // LINT.ThenChange(pywrap_quantization.pyi:static_range_ptq) + std::unordered_set tags; + tags.insert(quantization_options.tags().begin(), + quantization_options.tags().end()); + + absl::StatusOr exported_model = + QuantizePtqModelPreCalibration(src_saved_model_path, signature_keys, + tags, quantization_options, + function_aliases); + if (!exported_model.ok()) return exported_model.status(); + + AssignIdsToCustomAggregatorOps(*exported_model->mutable_graph_def()); + + const absl::StatusOr precalibrated_saved_model_dir = + CreateTmpDir(); + if (!precalibrated_saved_model_dir.ok()) { + throw py::value_error(absl::StrFormat( + "Failed to create tmp dir for precalibrated saved model: %s", + precalibrated_saved_model_dir.status().ToString())); + } + + py_function_library.SaveExportedModel( + *precalibrated_saved_model_dir, *exported_model, + src_saved_model_path, tags, signature_def_map); + + py_function_library.RunCalibration( + *precalibrated_saved_model_dir, signature_keys, tags, + quantization_options.calibration_options(), + quantization_options.force_graph_mode_calibration(), + representative_dataset); + + if (absl::Status status = AddCalibrationStatistics( + *exported_model->mutable_graph_def(), + quantization_options.calibration_options(), + py_function_library); + !status.ok()) { + LOG(WARNING) << "Some CustomAggregator ops do not have min or max " + "values. Parts of the graph are not quantized. " + << status; + } + + if (quantization_options.has_debugger_options()) { + EnableDebugging(*exported_model, + quantization_options.debugger_options(), + py_function_library, src_saved_model_path, tags, + signature_def_map); + } + + const absl::StatusOr calibrated_saved_model_path = + CreateTmpDir(); + if (!calibrated_saved_model_path.ok()) { + throw py::value_error(absl::StrFormat( + "Failed to create tmp dir for calibrated saved model: %s", + calibrated_saved_model_path.status().ToString())); + } + + py_function_library.SaveExportedModel( + *calibrated_saved_model_path, *exported_model, src_saved_model_path, + tags, signature_def_map); + + const absl::flat_hash_map + function_aliases_after_calibration( + exported_model->function_aliases().begin(), + exported_model->function_aliases().end()); + + const absl::StatusOr post_calibrated_exported_model = + QuantizePtqModelPostCalibration( + *calibrated_saved_model_path, signature_keys, tags, + quantization_options, function_aliases_after_calibration); + if (!post_calibrated_exported_model.ok()) { + return post_calibrated_exported_model.status(); + } + + // Remove the `tpu` tag from the debug quantized saved model as it is + // for CPU. Note the 'tpu' value should be the same as `TPU` defined in + // tensorflow/python/saved_model/tag_constants.py. 
+        if (quantization_options.has_debugger_options()) {
+          tags.erase("tpu");
+        }
+        py_function_library.SaveExportedModel(
+            dst_saved_model_path, *post_calibrated_exported_model,
+            *calibrated_saved_model_path, tags, signature_def_map);
+
+        return absl::OkStatus();
+      },
+      R"pbdoc(
+      Runs static-range post-training quantization (PTQ) on a SavedModel at
+      `src_saved_model_path` and saves the resulting model to
+      `dst_saved_model_path`.
+
+      The user should pass a serialized `QuantizationOptions` for the
+      `quantization_options_serialized` argument, and a signature key ->
+      serialized `SignatureDef` mapping for the `signature_def_map_serialized`
+      argument.
+
+      `function_aliases` maps actual function names to the function aliases, as
+      defined by the `MetaGraphDef::MetaInfoDef::function_aliases` from the
+      input SavedModel.
+
+      Raises a `StatusNotOk` exception if the run was unsuccessful.
+      )pbdoc",
+      py::arg("saved_model_path"), py::arg("dst_saved_model_path"),
+      py::arg("quantization_options_serialized"), py::kw_only(),
+      py::arg("signature_keys"), py::arg("signature_def_map_serialized"),
+      py::arg("function_aliases"), py::arg("py_function_library"),
+      py::arg("representative_dataset"));
+}
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi
new file mode 100644
index 00000000000000..1870115a4aa847
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi
@@ -0,0 +1,33 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Any
+
+from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib
+from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd
+
+# LINT.IfChange(static_range_ptq)
+def static_range_ptq(
+    src_saved_model_path: str,
+    dst_saved_model_path: str,
+    quantization_options_serialized: bytes,
+    *,
+    signature_keys: list[str],
+    signature_def_map_serialized: dict[str, bytes],
+    function_aliases: dict[str, str],
+    py_function_library: py_function_lib.PyFunctionLibrary,
+    representative_dataset: rd.RepresentativeDatasetOrMapping,
+) -> Any: ...  # Status
+
+# LINT.ThenChange()
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py
new file mode 100644
index 00000000000000..fab36e2005110f
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py
@@ -0,0 +1,100 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""StableHLO Quantizer.""" +from typing import Mapping + +from tensorflow.compiler.mlir.quantization.stablehlo.python import pywrap_quantization +from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib +from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd +from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.saved_model import loader_impl + +# Mapping of signature def key -> SignatureDef. +_SignatureDefMap = Mapping[str, meta_graph_pb2.SignatureDef] + + +def _serialize_signature_def_map( + signature_def_map: _SignatureDefMap, +) -> dict[str, bytes]: + """Serializes SignatureDef values in `signature_def_map`. + + Args: + signature_def_map: Signature key -> SignatureDef mapping. + + Returns: + Signature def map where the values (`SignatureDef`) are serialized. + """ + signature_def_map_serialized = {} + for key, signature_def in signature_def_map.items(): + signature_def_map_serialized[key] = signature_def.SerializeToString() + + return signature_def_map_serialized + + +# TODO: b/310594193 - Export API to pip package. +def quantize_saved_model( + src_saved_model_path: str, + dst_saved_model_path: str, + config: quant_opts_pb2.QuantizationOptions, +) -> None: + """Quantizes a saved model. + + Args: + src_saved_model_path: Path to the directory for the source SavedModel. + dst_saved_model_path: Path to the directory for the destination SavedModel. + config: Quantization configuration. + + Raises: + ValueError: When `config` was not configured for static-range PTQ + single representative dataset. + """ + if not ( + config.quantization_method.preset_method + == quant_opts_pb2.QuantizationMethod.PresetMethod.METHOD_STATIC_RANGE_INT8 + and len(config.representative_datasets) == 1 + ): + raise ValueError( + '`quantize_saved_model` currently only supports static-range PTQ with a' + ' single signature.' 
+ ) + + signature_def_map = save_model.get_signatures_from_saved_model( + src_saved_model_path, + list(config.signature_keys), + set(config.tags), + ) + + loader = loader_impl.SavedModelLoader(src_saved_model_path) + function_aliases = loader.get_meta_graph_def_from_tags( + config.tags + ).meta_info_def.function_aliases + + representative_dataset = rd.RepresentativeDatasetLoader( + config.representative_datasets + ).load() + + signature_def_map_serialized = _serialize_signature_def_map(signature_def_map) + pywrap_quantization.static_range_ptq( + src_saved_model_path, + dst_saved_model_path, + quantization_options_serialized=config.SerializeToString(), + signature_keys=list(config.signature_keys), + signature_def_map_serialized=signature_def_map_serialized, + function_aliases=dict(function_aliases), + py_function_library=py_function_lib.PyFunctionLibrary(), + representative_dataset=representative_dataset, + ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto new file mode 100644 index 00000000000000..c28e95da07004f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto @@ -0,0 +1,49 @@ +// Protobuf messages for configuring StableHLO Quantizer. +syntax = "proto3"; + +package stablehlo.quantization; + +option cc_enable_arenas = true; + +// Represents a single TFRecord file. See +// https://www.tensorflow.org/tutorials/load_data/tfrecord for details on the +// TFRecord format. +// Next ID: 2 +message TfRecordFile { + string path = 1; +} + +// Configures a single representative dataset used to calibrate a single +// function. +// Next ID: 3 +message RepresentativeDatasetConfig { + oneof file { + // Represents representative dataset saved as a .tfrecord file format. + TfRecordFile tf_record = 1; + } + + // [TF SavedModel] Identifies a SignatureDef which represents a single + // logical function in a graph. + optional string signature_key = 2; +} + +// Preset config for static-range post-training quantization (PTQ). +// Minimal user input about representative datasets is required. Representative +// datasets are required for static-range PTQ to retrieve quantization +// statistics via calibration. +// Next ID: 2 +message StaticRangePtqPreset { + // Configures representative dataset. Each item corresponds to a + // representative dataset used to calibrate a function. + repeated RepresentativeDatasetConfig representative_datasets = 1; +} + +// Quantization configuration for StableHLO Quantizer. This is the primary +// message containing all configurable options. +// Next ID: 2 +message QuantizationConfig { + oneof preset { + // Performs best-effort static-range post-training quantization (PTQ). 
+ StaticRangePtqPreset static_range_ptq_preset = 1; + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD index 4c078033215618..6fc15864fb0f8b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD @@ -1,6 +1,6 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -46,3 +46,24 @@ tf_cc_test( "@local_tsl//tsl/platform:protobuf", ], ) + +tf_cc_test( + name = "stablehlo_op_quant_spec_test", + srcs = ["stablehlo_op_quant_spec_test.cc"], + deps = [ + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common:test_base", + "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/core:test", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:QuantOps", + "@stablehlo//:stablehlo_ops", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir index 65c8497aa9a41a..713c55281ff0e3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir @@ -95,12 +95,15 @@ func.func @uniform_quantize_and_dequantize_type_exensions(%arg0: tensor (d0 : compressed) }> + +// CHECK: #[[$SV:.*]] = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> // CHECK-LABEL: func @uniform_quantize_and_dequantize_sparse_tensor_encoding -func.func @uniform_quantize_and_dequantize_sparse_tensor_encoding(%arg0: tensor (d0 : compressed) }>>) -> () { - // CHECK: %[[QUANTIZED:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor (d0 : compressed) }>>) -> tensor (d0 : compressed) }>> - %0 = mhlo.uniform_quantize %arg0 : (tensor (d0 : compressed) }>>) -> tensor, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>> - // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] : (tensor (d0 : compressed) }>>, tensor) -> tensor (d0 : compressed) }>> - %1 = mhlo.uniform_dequantize %0 : (tensor, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>) -> tensor (d0 : compressed) }>> +func.func @uniform_quantize_and_dequantize_sparse_tensor_encoding(%arg0: tensor) -> () { + // CHECK: %[[QUANTIZED:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor, #SV> + // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] : (tensor, tensor) -> tensor + %1 = mhlo.uniform_dequantize %0 : (tensor, #SV>) -> tensor return } @@ -341,6 +344,91 @@ func.func @requantize_merged_zp_zero( // ----- +// CHECK-LABEL: func @requantize_per_channel +func.func @requantize_per_channel( + %arg0: tensor<2x2x!quant.uniform> + ) -> 
tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[2.000000e+00, 5.000000e-01]> : tensor<2xf32> + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-5.000000e+00, -2.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even %[[VAL4]] : tensor<2x2xf32> + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor<2x2xf32>) -> tensor<2x2xi8> + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @requantize_per_channel_to_per_tensor +func.func @requantize_per_channel_to_per_tensor( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-5.000000e+00, -1.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even %[[VAL4]] : tensor<2x2xf32> + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor<2x2xf32>) -> tensor<2x2xi8> + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @requantize_per_tensor_to_per_channel +func.func @requantize_per_tensor_to_per_channel( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[1.000000e+00, 5.000000e-01]> : tensor<2xf32> + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-1.000000e+00, -2.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even 
%[[VAL4]] : tensor<2x2xf32> + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor<2x2xf32>) -> tensor<2x2xi8> + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +func.func @requantize_per_channel_change_axis( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // expected-error@+2 {{Cannot requantize while changing quantization_axis}} + // expected-error@+1 {{failed to legalize operation 'mhlo.uniform_quantize' that was explicitly marked illegal}} + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + // CHECK-LABEL: func @dot func.func @dot(%arg0: tensor<2x2x!quant.uniform>, %arg1: tensor<2x2x!quant.uniform> @@ -493,7 +581,6 @@ func.func @dot_dynamic_result_dim( // CHECK-SAME: broadcast_dimensions = dense<1> // CHECK-SAME: (tensor, tensor<2xi64>) -> tensor - %0 = "mhlo.dot" (%arg0, %arg1) : ( tensor>, tensor<2x?x!quant.uniform> @@ -503,6 +590,39 @@ func.func @dot_dynamic_result_dim( // ----- +// CHECK-LABEL: func @dot_dynamic_batch_dim +func.func @dot_dynamic_batch_dim( + %arg0: tensor>, + %arg1: tensor<2x2x!quant.uniform> + ) -> tensor> { + // CHECK: "mhlo.dot_general" + // CHECK-SAME: lhs_contracting_dimensions = [1] + // CHECK-SAME: rhs_contracting_dimensions = [0] + // CHECK-SAME: (tensor, tensor<2x2xi8>) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [1] + // CHECK-SAME: (tensor, tensor) -> tensor + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: broadcast_dimensions = dense<0> + // CHECK-SAME: (tensor, tensor<2xi64>) -> tensor + + // CHECK: mhlo.reduce + // CHECK-SAME: applies mhlo.add across dimensions = [0] + // CHECK-SAME: (tensor<2x2xi32>, tensor) -> tensor<2xi32> + // CHECK: mhlo.dynamic_broadcast_in_dim + // CHECK-SAME: broadcast_dimensions = dense<1> + // CHECK-SAME: (tensor<2xi32>, tensor<2xi64>) -> tensor + + %0 = "mhlo.dot" (%arg0, %arg1) : ( + tensor>, + tensor<2x2x!quant.uniform> + ) -> tensor> + return %0 : tensor> +} + +// ----- + // CHECK-LABEL: func @dot_general func.func @dot_general( %arg0: tensor<2x5x6x!quant.uniform>, @@ -1113,6 +1233,27 @@ func.func @conv2d_static( // ----- +// CHECK-LABEL: func @conv2d_default_attr +func.func @conv2d_default_attr( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x128x!quant.uniform> + ) -> tensor<128x26x26x128x!quant.uniform> { + // CHECK: mhlo.convolution + // CHECK-NOT: quant.uniform + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1x!quant.uniform>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x128x!quant.uniform> + return %0 : tensor<128x26x26x128x!quant.uniform> +} + +// ----- + // CHECK-LABEL: func @conv2d_static_padding func.func @conv2d_static_padding( %arg0: tensor<128x28x28x1x!quant.uniform>, @@ -1660,6 +1801,21 @@ func.func @broadcast( // ----- +// CHECK-LABEL: func @broadcast_per_channel +func.func @broadcast_per_channel( + %arg0: tensor<2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // CHECK: "mhlo.broadcast_in_dim" + // CHECK-SAME: broadcast_dimensions = dense<3> : tensor<1xi64> + // CHECK-SAME: (tensor<2xi32>) -> tensor<128x26x26x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<3> : tensor<1xi64>}: ( + 
tensor<2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + // CHECK-LABEL: func @max func.func @max( %arg0: tensor<1x2x!quant.uniform> @@ -1675,6 +1831,21 @@ func.func @max( // ----- +// CHECK-LABEL: func @max_per_channel +func.func @max_per_channel( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: mhlo.maximum + // CHECK-SAME: tensor<1x2xi8> + %0 = "mhlo.maximum"(%arg0, %arg0) : ( + tensor<1x2x!quant.uniform>, + tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> + return %0 : tensor<1x2x!quant.uniform> +} + +// ----- + // CHECK-LABEL: func @min func.func @min( %arg0: tensor<1x2x!quant.uniform> @@ -1690,6 +1861,21 @@ func.func @min( // ----- +// CHECK-LABEL: func @min_per_channel +func.func @min_per_channel( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: mhlo.minimum + // CHECK-SAME: tensor<1x2xi8> + %0 = "mhlo.minimum"(%arg0, %arg0) : ( + tensor<1x2x!quant.uniform>, + tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> + return %0 : tensor<1x2x!quant.uniform> +} + +// ----- + // CHECK-LABEL: func @function(%arg0: tensor<1x2xi8>) -> tensor<1x2xi8> func.func @function( %arg0: tensor<1x2x!quant.uniform> @@ -1700,27 +1886,124 @@ func.func @function( // ----- -func.func @min_mix_uq_type1( - %arg0: tensor<1x2x!quant.uniform>, - %arg1: tensor<1x2x!quant.uniform> - ) -> tensor<1x2x!quant.uniform> { - // expected-error@+1 {{failed to legalize operation 'mhlo.minimum' that was explicitly marked illegal}} - %0 = "mhlo.minimum"(%arg0, %arg1) : ( - tensor<1x2x!quant.uniform>, - tensor<1x2x!quant.uniform> - ) -> tensor<1x2x!quant.uniform> - return %0 : tensor<1x2x!quant.uniform> +// CHECK-LABEL: func @concatenate +func.func @concatenate( + %arg0: tensor<3x2x!quant.uniform:f32, 5.000000e-03>>, + %arg1: tensor<1x2x!quant.uniform:f32, 5.000000e-03>> + ) -> tensor<4x2x!quant.uniform:f32, 5.000000e-03>> { + // CHECK: mhlo.concatenate + // CHECK-SAME: (tensor<3x2xi8>, tensor<1x2xi8>) -> tensor<4x2xi8> + %0 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : ( + tensor<3x2x!quant.uniform:f32, 5.000000e-03>>, + tensor<1x2x!quant.uniform:f32, 5.000000e-03>> + ) -> tensor<4x2x!quant.uniform:f32, 5.000000e-03>> + return %0 : tensor<4x2x!quant.uniform:f32, 5.000000e-03>> } // ----- -func.func @min_mix_uq_type2( - %arg0: tensor<1x2x!quant.uniform> - ) -> tensor<1x2x!quant.uniform> { - // expected-error@+1 {{failed to legalize operation 'mhlo.minimum' that was explicitly marked illegal}} - %0 = "mhlo.minimum"(%arg0, %arg0) : ( - tensor<1x2x!quant.uniform>, - tensor<1x2x!quant.uniform> - ) -> tensor<1x2x!quant.uniform> - return %0 : tensor<1x2x!quant.uniform> +// CHECK-LABEL: func @pad +func.func @pad( + %arg0: tensor<2x3x!quant.uniform:f32, 5.000000e-03>>, + %arg1: tensor:f32, 5.000000e-03>> + ) -> tensor<5x9x!quant.uniform:f32, 5.000000e-03>> { + // CHECK: mhlo.pad + // CHECK-SAME: (tensor<2x3xi8>, tensor) -> tensor<5x9xi8> + %0 = "mhlo.pad"(%arg0, %arg1) { + edge_padding_low = dense<[0, 1]> : tensor<2xi64>, + edge_padding_high = dense<[2, 1]> : tensor<2xi64>, + interior_padding = dense<[1, 2]> : tensor<2xi64> + }: ( + tensor<2x3x!quant.uniform:f32, 5.000000e-03>>, + tensor:f32, 5.000000e-03>> + ) -> tensor<5x9x!quant.uniform:f32, 5.000000e-03>> + return %0 : tensor<5x9x!quant.uniform:f32, 5.000000e-03>> +} + +// ----- + +// CHECK-LABEL: func @reshape +func.func @reshape( + %arg0: tensor<1x3x!quant.uniform> + ) -> tensor<3x1x!quant.uniform> { + 
// CHECK: mhlo.reshape + // CHECK-SAME: (tensor<1x3xi8>) -> tensor<3x1xi8> + %0 = "mhlo.reshape"(%arg0) : ( + tensor<1x3x!quant.uniform> + ) -> tensor<3x1x!quant.uniform> + return %0 : tensor<3x1x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @select +func.func @select( + %arg0: tensor<1x3xi1>, + %arg1: tensor<1x3x!quant.uniform>, + %arg2: tensor<1x3x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> { + // CHECK: mhlo.select + // CHECK-SAME: tensor<1x3xi8> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : ( + tensor<1x3xi1>, + tensor<1x3x!quant.uniform>, + tensor<1x3x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @transpose +func.func @transpose( + %arg0: tensor<3x1x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> { + // CHECK: mhlo.transpose + // CHECK-SAME: (tensor<3x1xi8>) -> tensor<1x3xi8> + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : ( + tensor<3x1x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @gather +func.func @gather( + %arg0: tensor<3x4x2x!quant.uniform>, + %arg1: tensor<2x3x2xi64> + ) -> tensor<2x3x2x2x!quant.uniform> { + // CHECK: mhlo.gather + // CHECK-SAME: (tensor<3x4x2xi8>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi8> + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>, + indices_are_sorted = false + } : ( + tensor<3x4x2x!quant.uniform>, + tensor<2x3x2xi64> + ) -> tensor<2x3x2x2x!quant.uniform> + return %0 : tensor<2x3x2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @slice +func.func @slice( + %arg0: tensor<3x4x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK: mhlo.slice + // CHECK-SAME: (tensor<3x4xi8>) -> tensor<2x2xi8> + %0 = "mhlo.slice"(%arg0) { + start_indices = dense<[1, 2]> : tensor<2xi64>, + limit_indices = dense<[3, 4]> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } : ( + tensor<3x4x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/populate_shape.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/populate_shape.mlir new file mode 100644 index 00000000000000..05f10405356ba9 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/populate_shape.mlir @@ -0,0 +1,44 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -populate-shape --mlir-print-ir-after-all | FileCheck %s + +// CHECK-LABEL: @populate_shape_for_custom_aggregator +func.func @populate_shape_for_custom_aggregator(%input: tensor) { + // CHECK: %[[OUTPUT:.*]] = "tf.CustomAggregator"(%[[INPUT:.*]]) <{id = "49d53b0"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor + %0 = "tf.CustomAggregator"(%input) <{id = "49d53b0"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor<*xf32> + func.return +} + +// ---- + +// CHECK-LABEL: @populate_shape_for_xla_call_module +func.func @populate_shape_for_xla_call_module(%input: tensor) { + %cst = 
"tf.Const"() {value = dense<3.00000000e-1> : tensor<1x1x64x256xf32>} : () -> tensor<1x1x64x256xf32> + // CHECK: %[[OUTPUT:.*]] = "tf.XlaCallModule"(%[[INPUT:.*]], %[[CST:.*]]) <{Sout = [#tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> tensor + %0 = "tf.XlaCallModule"(%input, %cst) <{Sout = [#tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> tensor<*xf32> + func.return +} + +// ---- + +// CHECK-LABEL: @populate_shape_for_chain_of_ops +func.func @populate_shape_for_chain_of_ops(%input: tensor) { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<1x1x64x256xf32>} : () -> tensor<1x1x64x256xf32> + // CHECK: %[[VAL_0:.*]] = "tf.CustomAggregator"(%[[INPUT:.*]]) <{id = "49d53b0"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor + // CHECK: %[[VAL_1:.*]] = "tf.XlaCallModule"(%[[VAL_0:.*]], %[[CST:.*]]) <{Sout = [#tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> tensor + // CHECK: %[[VAL_2:.*]] = "tf.CustomAggregator"(%[[VAL_1:.*]]) <{id = "49d53b1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor + %0 = "tf.CustomAggregator"(%input) <{id = "49d53b0"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor<*xf32> + %1 = "tf.XlaCallModule"(%0, %cst) <{Sout = [#tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<*xf32>, tensor<1x1x64x256xf32>) -> tensor<*xf32> + %2 = "tf.CustomAggregator"(%1) <{id = "49d53b1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> + func.return +} + +// ---- + +// CHECK-LABEL: @populate_shape_for_xla_call_module_failure_not_single_output +func.func @populate_shape_for_xla_call_module_failure_not_single_output(%input: tensor) { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<1x1x64x256xf32>} : () -> tensor<1x1x64x256xf32> + // expected-error @+2 {{XlaCallModuleOp doesn't have 1 
output.}} + %0, %1 = "tf.XlaCallModule"(%input, %cst) <{Sout = [#tf_type.shape, #tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> (tensor<*xf32>, tensor<*xf32>) + // expected-error @+1 {{XlaCallModuleOp doesn't have 1 output.}} + "tf.XlaCallModule"(%input, %cst) <{Sout = [], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> () + func.return +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir index 8f38f889f28e33..a873f30a20cff8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir @@ -105,3 +105,36 @@ func.func @merge_consecutive_qcast(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, % %6 = "quantfork.stats"(%5) {layerStats = dense<[-1.5726943, 4.6875381]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> func.return %3, %6 : tensor<*xf32>, tensor<*xf32> } + +// ----- + +// CHECK-LABEL: func @skip_nan_inf_constant +// CHECK-SAME: (%[[ARG_0:.*]]: tensor) -> tensor +func.func @skip_nan_inf_constant(%arg0: tensor) -> tensor { + // CHECK: %[[cst0:.*]] = stablehlo.constant + // CHECK: %[[cst1:.*]] = stablehlo.constant + // CHECK: %[[cst2:.*]] = stablehlo.constant + // CHECK: %[[cst3:.*]] = stablehlo.constant + // CHECK-NOT: %[[q0:.*]] = "quantfork.qcast"(%[[cst0]]) + // CHECK-NOT: %[[q1:.*]] = "quantfork.qcast"(%[[cst1]]) + // CHECK: %[[q2:.*]] = "quantfork.qcast"(%[[cst2]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq2:.*]] = "quantfork.dcast"(%[[q2]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[q3:.*]] = "quantfork.qcast"(%[[cst3]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq3:.*]] = "quantfork.dcast"(%[[q3]]) + // CHECK-SAME: quant.uniform + %0 = stablehlo.constant dense<0xFF800000> : tensor + %1 = stablehlo.constant dense<0x7FC00000> : tensor + %2 = stablehlo.constant dense<6.000000e+00> : tensor + %3 = stablehlo.constant dense<0.000000e+00> : tensor + %4 = "stablehlo.add"(%0, %1) : (tensor, tensor) -> tensor + %5 = stablehlo.clamp %3, %arg0, %2 : (tensor, tensor, tensor) -> tensor + %6 = "stablehlo.reduce_window"(%5, %4) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %7 : tensor + }) {padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor, tensor) -> tensor + return %6 : tensor +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir index d1bfea7a236448..e794dded354da9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir @@ -1,5 +1,8 @@ // RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize 
-verify-each=false | FileCheck %s +// Tests for PopulateFusedGemmStylePatterns are handled in +// quantize_composite_functions for module-level evaluation of functions. + // CHECK-LABEL: quantize_simple_xla_call_module func.func private @quantize_simple_xla_call_module(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { %0 = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> @@ -40,3 +43,27 @@ func.func private @quantize_simple_xla_call_module_no_operand() -> tensor<1x3xf3 // CHECK: %[[XLACALLMODULE_0:.*]] = "tf.XlaCallModule"() <{{{.*}}}> {{{.*}}} : () -> tensor<1x3x!quant.uniform> // CHECK: %[[DCAST_0:.*]] = "quantfork.dcast"(%[[XLACALLMODULE_0]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> // CHECK: "func.return"(%[[DCAST_0]]) : (tensor<1x3xf32>) -> () + +// ----- + +// Tests for emitting an error when there is no corresponding entry +// function to quantize (@composite_dot_general_fn). + +module attributes {tf_saved_model.semantics} { +// The following pattern does not converge because of a bug in QuantizePass. +// TODO - b/305469508: Fix the QuantizePass to avoid this warning. +// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} + func.func private @error_when_no_entry_function(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<1.000000e+00> : tensor<2x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 5.000000e-03>> + %2 = "quantfork.dcast"(%1) : (tensor<2x3x!quant.uniform:f32, 5.000000e-03>>) -> tensor<2x3xf32> + %3 = "quantfork.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %4 = "quantfork.dcast"(%3) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// expected-error @+2 {{Failed to find a valid entry function}} +// expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} + %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %6 = "quantfork.qcast"(%5) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %7 = "quantfork.dcast"(%6) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %7 : tensor<1x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir index 97ea1f30be81ba..b6efc8cab0060a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir @@ -1,14 +1,17 @@ // RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ // RUN: -stablehlo-quantize-composite-functions | FileCheck %s + +// Tests that basic dot_general is properly quantized. + +// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. // TODO - b/305469508: Fix the QuantizePass to avoid this warning. 
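Editor's note: the CHECK lines in the rewritten tests below match `!quant.uniform` result types whose scale and zero-point parameters are derived from calibration. As a reading aid only, and not part of this patch, the following sketch shows the affine mapping such a type encodes; the function names and the parameter values are illustrative, not taken from the tests.

def quantize(x: float, scale: float, zero_point: int) -> int:
  # Scale, shift by the zero point, round (the passes emit
  # round-nearest-even), and clamp to the int8 storage range.
  q = round(x / scale) + zero_point
  return max(-128, min(127, q))


def dequantize(q: int, scale: float, zero_point: int) -> float:
  # Inverse mapping back to the expressed (float) type.
  return (q - zero_point) * scale


# Illustrative parameters only.
assert quantize(0.9, scale=0.0035, zero_point=-128) == 127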
-// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} - func.func private @quantize_dot_general(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<3x3xf32>} : () -> tensor<3x3xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> + func.func private @quantize_dot_general(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> return %2 : tensor<1x3xf32> } @@ -16,99 +19,199 @@ module attributes {tf_saved_model.semantics} { // calls the quantized entry function. 
// CHECK-LABEL: func.func private @quantize_dot_general -// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} -// CHECK: %[[CONST_0:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<3x3xi8>} : () -> tensor<3x3x!quant.uniform:f32, {{.*}}> -// CHECK: %[[UNIFORM_QUANTIZE_0:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> -// CHECK: %[[CALL_0:.*]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x!quant.uniform>, tensor<3x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK: %[[UNIFORM_QUANTIZE_0:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL_0:.*]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> // CHECK: %[[UNIFORM_DEQUANTIZE_0:.*]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> // CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> - func.func private @composite_dot_general_fn(%arg0: tensor<1x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> return %0 : tensor<1x3xf32> } // Checks that the entry function is quantized for dot_general. Quantized // dot_general outputs an i32 quantized tensor, followed by requantization to // i8 quantized tensor. -// CHECK: func.func private @quantized_dot_general_fn(%[[ARG_2:.*]]: tensor<1x3x!quant.uniform>, %[[ARG_3:.*]]: tensor<3x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} -// CHECK: %[[DOT_GENERAL_0:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]], contracting_dims = [1] x [0] : (tensor<1x3x!quant.uniform>, tensor<3x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK: func.func private @quantized_dot_general_fn(%[[ARG_2:.*]]: tensor<1x2x!quant.uniform>, %[[ARG_3:.*]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[DOT_GENERAL_0:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> // CHECK: %[[UNIFORM_QUANTIZE_1:.*]] = stablehlo.uniform_quantize %[[DOT_GENERAL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> // CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> } // ----- -// Tests error when there are no corresponding entry function to quantize -// (@composite_dot_general_fn). +// Tests that fused pattern for dot_general + bias is properly quantized. 
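Editor's note: the comments above describe i32 accumulation followed by requantization to i8, and the next test checks the same flow with a fused bias. As an editorial aside, not part of the patch, here is a rough reference-semantics sketch of that fused computation; the function and parameter names are invented for illustration, and the real scales come from the `quantfork.stats` calibration.

import numpy as np

def quantized_dot_general_with_bias(x_q, x_zp, w_q, bias_q,
                                    x_scale, w_scale, out_scale, out_zp):
  # Accumulate in int32; weights are symmetric (zero point 0), and the bias
  # constant is pre-quantized to int32 with scale = x_scale * w_scale.
  acc = (x_q.astype(np.int32) - x_zp) @ w_q.astype(np.int32) + bias_q
  # Requantize the int32 accumulator to the int8 output type.
  out = np.round(acc * (x_scale * w_scale) / out_scale) + out_zp
  return np.clip(out, -128, 127).astype(np.int8)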
+// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. // TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} - func.func private @error_when_no_entry_function(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<3x3xf32>} : () -> tensor<3x3xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> -// expected-error @+2 {{Failed to find a valid entry function}} -// expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> + func.func private @quantize_dot_general_with_bias(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_fn, _original_entry_function = "composite_dot_general_with_bias_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> return %2 : tensor<1x3xf32> } + +// CHECK-LABEL: func.func private @quantize_dot_general_with_bias +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x3xi32>} : () -> tensor<1x3x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL_0:.*]] = call @quantized_dot_general_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.*]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + +// CHECK: func.func private @quantized_dot_general_with_bias_fn(%[[ARG_2:.*]]: 
tensor<1x2x!quant.uniform>, %[[ARG_3:.*]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_4:.*]]: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_dot_general_with_bias_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +// CHECK: %[[DOT_GENERAL_0:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[ADD_0:.*]] = stablehlo.add %[[DOT_GENERAL_0]], %[[ARG_4]] : tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.*]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> + } // ----- -// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. +// Tests that fused pattern for dot_general + bias with dynamic shape is +// not quantized. +// TODO: b/307620428 - Add support for fused bias with dynamic shapes. +// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} module attributes {tf_saved_model.semantics} { - func.func private @not_quantized_without_stats(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<3x3xf32>} : () -> tensor<3x3xf32> - %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> - return %1 : tensor<1x3xf32> +// The following pattern does not converge because of a bug in QuantizePass. +// TODO - b/305469508: Fix the QuantizePass to avoid this warning. 
+ func.func private @quantize_dot_general_with_bias_dynamic(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + // expected-error@+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values, but got}} + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape], _entry_function = @composite_dot_general_with_bias_dynamic_fn, _original_entry_function = "composite_dot_general_with_bias_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3xf32>, tensor<3xf32>) -> tensor + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor } -// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is -// not quantized. -// CHECK-LABEL: func.func private @not_quantized_without_stats -// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} -// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<3.000000e-01> : tensor<3x3xf32> -// CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[ARG_1]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> -// CHECK: return %[[XLA_CALL_MODULE_0]] + func.func private @composite_dot_general_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3xf32>, %arg2: tensor<3xf32>) -> tensor attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x3xf32>) -> tensor + %1 = shape.shape_of %0 : tensor -> tensor<2xindex> + %2 = stablehlo.dynamic_broadcast_in_dim %arg2, %1, dims = [1] : (tensor<3xf32>, tensor<2xindex>) -> tensor + %3 = stablehlo.add %0, %2 : tensor + return %3 : tensor + } +} - func.func private @composite_dot_general_fn(%arg0: tensor<1x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> +// ----- + +// Tests that basic convolution is properly quantized. + +// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} +module attributes {tf_saved_model.semantics} { +// The following pattern does not converge because of a bug in QuantizePass. +// TODO - b/305469508: Fix the QuantizePass to avoid this warning. 
+ func.func private @quantize_convolution(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64, _entry_function = @composite_convolution_fn, _original_entry_function = "composite_convolution_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> } -// Check that the composite_dot_general_fn is untouched. +// Checks that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. -// CHECK: func.func private @composite_dot_general_fn(%[[ARG_2:.*]]: tensor<1x3xf32>, %[[ARG_3:.*]]: tensor<3x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} -// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]] -// CHECK: return %[[DOT_GENERAL]] +// CHECK-LABEL: func.func private @quantize_convolution +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK: %[[UNIFORM_QUANTIZE_0:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.*]] = call @quantized_convolution_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.*]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_convolution_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +// Checks that the entry function is quantized for convolution. Quantized +// convolution outputs an i32 quantized tensor, followed by requantization to +// i8 quantized tensor. 
+ +// CHECK: func.func private @quantized_convolution_fn(%[[ARG_2:.*]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_3:.*]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[CONVOLUTION_0:.*]] = stablehlo.convolution(%[[ARG_2]], %[[ARG_3]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.*]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> } // ----- -// Tests that a fusion pattern for dot_general is not yet supported. Further op -// coverage will be provided in the future. -// TODO - b/307620428: Increase op coverage to cover this test case. +// Tests that fused pattern for convolution + bias is properly quantized. +// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} module attributes {tf_saved_model.semantics} { // The following pattern does not converge because of a bug in QuantizePass. // TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} - func.func private @dot_general_fn_fusion_not_quantized(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<3x3xf32>} : () -> tensor<3x3xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> -// expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> - return %2 : tensor<1x3xf32> + func.func private @quantize_convolution_with_bias(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x3x4x2xf32>} : () -> tensor<1x3x4x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_convolution_with_bias_fn, _original_entry_function = "composite_convolution_with_bias_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> } - 
func.func private @composite_dot_general_fn(%arg0: tensor<1x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x3xf32>, tensor<3x3xf32>) -> tensor<1x3xf32> - %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> +// CHECK-LABEL: func.func private @quantize_convolution_with_bias +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.*]] = stablehlo.uniform_quantize %[[ARG_1]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.*]] = call @quantized_convolution_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.*]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK: func.func private @quantized_convolution_with_bias_fn(%[[ARG_2:.*]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_3:.*]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_4:.*]]: tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_convolution_with_bias_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3x4x2xf32> + return %1 : tensor<1x3x4x2xf32> + } +// CHECK: %[[CONVOLUTION_0:.*]] = stablehlo.convolution(%[[ARG_2]], %[[ARG_3]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[ADD_0:.*]] = stablehlo.add %[[CONVOLUTION_0]], %[[ARG_4]] : tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.*]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} + +// ----- + +// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. 
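Editor's note: the test below relies on the absence of `quantfork.stats` ops, which carry the calibration `layerStats = [min, max]` that the pass turns into scale and zero-point parameters; without them there is nothing to quantize. The sketch that follows is an editorial illustration of that derivation for asymmetric int8, not part of the patch, and TensorFlow's exact nudging and rounding may differ; the function name is invented.

def derive_qparams(stat_min: float, stat_max: float,
                   qmin: int = -128, qmax: int = 127):
  # Include zero in the range so that exact zeros stay representable.
  stat_min, stat_max = min(stat_min, 0.0), max(stat_max, 0.0)
  scale = (stat_max - stat_min) / (qmax - qmin)
  zero_point = int(round(qmin - stat_min / scale))
  return scale, zero_point


# layerStats = [6.0e-6, 0.9], as used by the tests above, gives
# roughly (0.0035, -128).
print(derive_qparams(6.0e-6, 0.9))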
+ +module attributes {tf_saved_model.semantics} { + func.func private @not_quantized_without_stats(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> return %1 : tensor<1x3xf32> } +// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is +// not quantized. + +// CHECK-LABEL: func.func private @not_quantized_without_stats +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> +// CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[ARG_1]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[XLA_CALL_MODULE_0]] + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// Check that the composite_dot_general_fn is untouched. + +// CHECK: func.func private @composite_dot_general_fn(%[[ARG_2:.*]]: tensor<1x2xf32>, %[[ARG_3:.*]]: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %[[ARG_2]], %[[ARG_3]] +// CHECK: return %[[DOT_GENERAL]] } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir new file mode 100644 index 00000000000000..7878bccf9d7e61 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir @@ -0,0 +1,261 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize -verify-each=false | FileCheck %s + +// CHECK-LABEL: same_scale_after_composite +func.func @same_scale_after_composite() -> tensor<3x1xf32> { + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: () -> tensor<1x3x!quant.uniform> + // CHECK: %[[RESHAPE:.*]] = "stablehlo.reshape"(%[[CALL]]) : (tensor<1x3x!quant.uniform>) -> tensor<3x1x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[RESHAPE]]) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : 
(tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = stablehlo.reshape %2 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + return %5 : tensor<3x1xf32> +} + +// ----- + +// CHECK-LABEL: same_scale_indirectly_connected +func.func @same_scale_indirectly_connected() -> tensor<1x3xf32> { + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: () -> tensor<1x3x!quant.uniform> + // CHECK: %[[RESHAPE:.*]] = "stablehlo.reshape"(%[[CALL]]) : (tensor<1x3x!quant.uniform>) -> tensor<3x1x!quant.uniform> + // CHECK: %[[TRANSPOSE:.*]] = "stablehlo.transpose"(%[[RESHAPE]]) {permutation = array} : (tensor<3x1x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[TRANSPOSE]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = stablehlo.reshape %2 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + %6 = "stablehlo.transpose"(%5) {permutation = array} : (tensor<3x1xf32>) -> tensor<1x3xf32> + %7 = "quantfork.qcast"(%6) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %8 = "quantfork.dcast"(%7) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %8 : tensor<1x3xf32> +} + +// ----- + +// CHECK-LABEL: same_scale_not_connected_to_composite +func.func @same_scale_not_connected_to_composite() -> tensor<3x1xf32> { + // CHECK: %[[CST:.*]] = stablehlo.constant + // CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[CST]]) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[DQ1:.*]] = "quantfork.dcast"(%[[Q1]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[DQ1]] + // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[RESHAPE]]) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + // CHECK: %[[DQ2:.*]] = "quantfork.dcast"(%[[Q2]]) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + // CHECK: return %[[DQ2]] + + %0 = stablehlo.constant dense<1.000000e+00> : tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = stablehlo.reshape %2 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + return %5 : tensor<3x1xf32> +} + +// ----- + +// 
CHECK-LABEL: concatenate_and_composite +// CHECK: %[[ARG0:.*]]: tensor<3x2xf32> +// CHECK-SAME: %[[ARG1:.*]]: tensor<1x2xf32> +func.func @concatenate_and_composite(%arg0: tensor<3x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<4x5xf32> { + // CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[ARG0]]) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform:f32, 5.000000e-03>> + // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[ARG1]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform:f32, 5.000000e-03>> + // CHECK: %[[PAD:.*]] = "stablehlo.concatenate"(%[[Q1]], %[[Q2]]) {dimension = 0 : i64} + // CHECK-SAME: (tensor<3x2x!quant.uniform:f32, 5.000000e-03>>, tensor<1x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<4x2x!quant.uniform:f32, 5.000000e-03>> + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"(%[[PAD]]) + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: (tensor<4x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<4x5x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[CALL]]) : (tensor<4x5x!quant.uniform>) -> tensor<4x5xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "quantfork.qcast"(%arg0) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform:f32, 5.000000e-03>> + %1 = "quantfork.dcast"(%0) : (tensor<3x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<3x2xf32> + %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform:f32, 5.000000e-03>> + %3 = "quantfork.dcast"(%2) : (tensor<1x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<1x2xf32> + %4 = "stablehlo.concatenate"(%1, %3) { + dimension = 0 : i64 + } : (tensor<3x2xf32>, tensor<1x2xf32>) -> tensor<4x2xf32> + %5 = "quantfork.qcast"(%4) {volatile} : (tensor<4x2xf32>) -> tensor<4x2x!quant.uniform:f32, 5.000000e-03>> + %6 = "quantfork.dcast"(%5) : (tensor<4x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<4x2xf32> + %7 = "tf.XlaCallModule"(%6) {Sout = [#tf_type.shape<4x5>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<4x2xf32>) -> tensor<4x5xf32> + %8 = "quantfork.qcast"(%7) {volatile} : (tensor<4x5xf32>) -> tensor<4x5x!quant.uniform> + %9 = "quantfork.dcast"(%8) : (tensor<4x5x!quant.uniform>) -> tensor<4x5xf32> + return %9 : tensor<4x5xf32> +} + +// ----- + +// CHECK-LABEL: composite_and_convert +func.func @composite_and_convert() -> tensor<1x3xf32> { + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: () -> tensor<1x3x!quant.uniform> + // CHECK: %[[CONVERT:.*]] = "stablehlo.convert"(%[[CALL]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[CONVERT]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, 
module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = stablehlo.convert %2 : (tensor<1x3xf32>) -> (tensor<1x3xf32>) + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %5 : tensor<1x3xf32> +} + +// ----- + +// CHECK-LABEL: pad_and_composite +// CHECK: %[[ARG0:.*]]: tensor<2x3xf32> +// CHECK-SAME: %[[ARG1:.*]]: tensor +func.func @pad_and_composite(%arg0: tensor<2x3xf32>, %arg1: tensor) -> tensor<5x6xf32> { + // CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[ARG0]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 5.000000e-03>> + // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[ARG1]]) {volatile} : (tensor) -> tensor:f32, 5.000000e-03>> + // CHECK: %[[PAD:.*]] = "stablehlo.pad"(%[[Q1]], %[[Q2]]) + // CHECK-SAME: {edge_padding_high = array, edge_padding_low = array, interior_padding = array} + // CHECK-SAME: (tensor<2x3x!quant.uniform:f32, 5.000000e-03>>, tensor:f32, 5.000000e-03>>) -> tensor<5x9x!quant.uniform:f32, 5.000000e-03>> + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"(%[[PAD]]) + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: (tensor<5x9x!quant.uniform:f32, 5.000000e-03>>) -> tensor<5x6x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[CALL]]) : (tensor<5x6x!quant.uniform>) -> tensor<5x6xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "quantfork.qcast"(%arg0) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 5.000000e-03>> + %1 = "quantfork.dcast"(%0) : (tensor<2x3x!quant.uniform:f32, 5.000000e-03>>) -> tensor<2x3xf32> + %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor) -> tensor:f32, 5.000000e-03>> + %3 = "quantfork.dcast"(%2) : (tensor:f32, 5.000000e-03>>) -> tensor + %4 = "stablehlo.pad"(%1, %3) { + edge_padding_low = array, + edge_padding_high = array, + interior_padding = array + }: (tensor<2x3xf32>, tensor) -> tensor<5x9xf32> + %5 = "quantfork.qcast"(%4) {volatile} : (tensor<5x9xf32>) -> tensor<5x9x!quant.uniform:f32, 5.000000e-03>> + %6 = "quantfork.dcast"(%5) : (tensor<5x9x!quant.uniform:f32, 5.000000e-03>>) -> tensor<5x9xf32> + %7 = "tf.XlaCallModule"(%6) {Sout = [#tf_type.shape<5x6>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<5x9xf32>) -> tensor<5x6xf32> + %8 = "quantfork.qcast"(%7) {volatile} : (tensor<5x6xf32>) -> tensor<5x6x!quant.uniform> + %9 = "quantfork.dcast"(%8) : (tensor<5x6x!quant.uniform>) -> tensor<5x6xf32> + return %9 : tensor<5x6xf32> +} + +// ----- + +// CHECK-LABEL: composite_and_select +// CHECK: %[[ARG0:.*]]: tensor<1x3xi1> +// CHECK-SAME: %[[ARG1:.*]]: tensor<1x3xf32> +func.func @composite_and_select(%arg0: tensor<1x3xi1>, %arg1: tensor<1x3xf32>) -> tensor<1x3xf32> { + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // 
CHECK-SAME: () -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[ARG1]]) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[SELECT:.*]] = "stablehlo.select"(%[[ARG0]], %[[CALL]], %[[Q1]]) : (tensor<1x3xi1>, tensor<1x3x!quant.uniform>, tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%2) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = "quantfork.qcast"(%arg1) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %4 = "quantfork.dcast"(%3) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %7 = stablehlo.select %arg0, %2, %4 : (tensor<1x3xi1>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %8 = "quantfork.qcast"(%7) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %9 = "quantfork.dcast"(%8) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %9 : tensor<1x3xf32> +} + +// ----- + +// CHECK-LABEL: composite_and_broadcast_in_dim +func.func @composite_and_broadcast_in_dim() -> tensor<2x3x2xf32> { + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: () -> tensor<1x3x!quant.uniform> + // CHECK: %[[BROADCAST:.*]] = "stablehlo.broadcast_in_dim"(%[[CALL]]) + // CHECK-SAME: (tensor<1x3x!quant.uniform>) -> tensor<2x3x2x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[BROADCAST]]) : (tensor<2x3x2x!quant.uniform>) -> tensor<2x3x2xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = "stablehlo.broadcast_in_dim"(%2) { + broadcast_dimensions = dense<[2, 1]>: tensor<2xi64> + } : (tensor<1x3xf32>) -> tensor<2x3x2xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<2x3x2xf32>) -> tensor<2x3x2x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<2x3x2x!quant.uniform>) -> tensor<2x3x2xf32> + return %5 : tensor<2x3x2xf32> +} + +// ----- + +// CHECK-LABEL: composite_and_gather +// CHECK: %[[ARG0:.*]]: tensor<2x3x2xi64> +func.func @composite_and_gather(%arg0: tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> { + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // 
CHECK-SAME: () -> tensor<3x4x2x!quant.uniform> + // CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%[[CALL]], %[[ARG0]]) + // CHECK-SAME: (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi64>) -> tensor<2x3x2x2x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[GATHER]]) : (tensor<2x3x2x2x!quant.uniform>) -> tensor<2x3x2x2xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<3x4x2>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<3x4x2xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<3x4x2xf32>) -> tensor<3x4x2x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<3x4x2x!quant.uniform>) -> tensor<3x4x2xf32> + %3 = "stablehlo.gather"(%2, %arg0) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>, + indices_are_sorted = false + } : (tensor<3x4x2xf32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<2x3x2x2x!quant.uniform>) -> tensor<2x3x2x2xf32> + return %5 : tensor<2x3x2x2xf32> +} + +// ----- + +// CHECK-LABEL: composite_and_slice +func.func @composite_and_slice() -> tensor<2x2xf32> { + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: () -> tensor<3x4x!quant.uniform> + // CHECK: %[[SLICE:.*]] = "stablehlo.slice"(%[[CALL]]) + // CHECK-SAME: (tensor<3x4x!quant.uniform>) -> tensor<2x2x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[SLICE]]) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<3x4>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<3x4xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<3x4xf32>) -> tensor<3x4x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<3x4x!quant.uniform>) -> tensor<3x4xf32> + %3 = "stablehlo.slice"(%2) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<3x4xf32>) -> tensor<2x2xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> + return %5 : tensor<2x2xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir index 3d04c72dec7f7e..745d44282c9e0f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir +++ 
b/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir @@ -1,6 +1,9 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-replace-stablehlo-ops-in-main-function-with-xla-call-module-ops | FileCheck %s +// RUN: stablehlo-quant-opt %s -split-input-file \ +// RUN: -stablehlo-replace-stablehlo-ops-in-main-function-with-xla-call-module-ops \ +// RUN: | FileCheck %s -// Modules with "main" or "serving_default" should properly run this pass and convert subgraphs into XLACallModuleOp. +// Modules with "main" or "serving_default" should properly run this pass and +// convert subgraphs into XLACallModuleOp. module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { @@ -20,23 +23,23 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> %1 = stablehlo.constant dense<1.000000e+03> : tensor<1x3xf32> %2 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - %3 = "tf.XlaCallModule"(%2, %0, %1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %3 = "tf.XlaCallModule"(%2, %0, %1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> %4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> %5 = stablehlo.constant dense<1.000000e+03> : tensor<3x64xf32> %6 = stablehlo.constant dense<1.000000e+03> : tensor<1x64xf32> %7 = "tf.CustomAggregator"(%4) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> - %8 = "tf.XlaCallModule"(%7, %5, %6) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> + %8 = "tf.XlaCallModule"(%7, %5, %6) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x3xf32>, 
tensor<3x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> %9 = "tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x64xf32>) -> tensor<1x64xf32> return %9 : tensor<1x64xf32> } // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}}> {_entry_function = @_stablehlo_main_1 // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - // CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable"} // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_0]]) // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}}> {_entry_function = @_stablehlo_main_0 // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]] = "tf.CustomAggregator"(%[[CUSTOM_AGGREGATOR_1]]) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> - // CHECK: %[[XLA_CALL_MODULE_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1" + // CHECK: %[[XLA_CALL_MODULE_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _tfl_quant_trait = "fully_quantizable"} // CHECK: %[[CUSTOM_AGGREGATOR_3:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_1:.*]]) // CHECK: return %[[CUSTOM_AGGREGATOR_3]] : tensor<1x64xf32> // CHECK: } @@ -63,6 +66,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // ----- +// Tests that the subgraph in serving_default excluding the tf.Identity is +// converted to a single XlaCallModuleOp. 
+ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1654 : i32}, tf_saved_model.semantics} { // CHECK: func private @_stablehlo_main_0 @@ -85,8 +91,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p return %5 : tensor<1x1024xf32> } - // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"(%arg0) <{Sout = [#tf_type.shape<1x1024>] - // CHECK-SAME: _entry_function = @_stablehlo_main_0 + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"(%arg0) <{Sout = [#tf_type.shape<1x1024>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP]]) // CHECK: return %[[IDENTITY]] // CHECK } @@ -95,8 +100,10 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // ----- +// Tests that the first stablehlo.constant is converted to XlaCallModuleOp. + module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { - // CHECK: func private @_stablehlo_main_ + // CHECK: func private @_stablehlo_main_0 // CHECK: %[[CONSTANT:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> // CHECK: return %[[CONSTANT:.*]] // CHECK: } @@ -105,12 +112,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p func.func @serving_default(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> %1 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - %2 = "tf.XlaCallModule"(%1, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %2 = "tf.XlaCallModule"(%1, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> %3 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> return %3 : tensor<1x3xf32> } - // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_ + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() 
<{Sout = [#tf_type.shape<1024x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}} // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) @@ -127,7 +134,8 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // ----- -// Tests to confirm that the StableHLO graph is not replaced if "main" or "serving_default" function is in the module. +// Tests to confirm that the StableHLO graph is not replaced if "main" or +// "serving_default" function is not in the module. module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { // CHECK-NOT: func private @_stablehlo_main_ @@ -136,14 +144,14 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p func.func @random_name(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> %1 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - %2 = "tf.XlaCallModule"(%1, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %2 = "tf.XlaCallModule"(%1, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> %3 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> return %3 : tensor<1x3xf32> } // CHECK: %[[CONSTANT:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - // CHECK: %[[XLA_CALL_MODULE:.*]] = 
"tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[XLA_CALL_MODULE_EXTRACTED_FROM_SUBGRAPH:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[XLA_CALL_MODULE_EXTRACTED_FROM_SUBGRAPH:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: return %[[CUSTOM_AGGREGATOR_1]] // CHECK: } @@ -155,3 +163,97 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p return %0 : tensor<1x3xf32> } } + +// ----- + +// Tests where StableHLO graph in main has a small constant to be duplicated. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func private @_stablehlo_main_1() -> tensor<1024x3xf32> attributes {_from_xla_call_module} + // CHECK: %[[CONSTANT1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT1:.*]] + // CHECK: } + + // CHECK: func private @_stablehlo_main_0 + // CHECK-SAME: %[[INPUT1:.*]]: tensor<1024x3xf32>, %[[INPUT2:.*]]: tensor<1024x3xf32> + // CHECK: %[[CONSTANT2:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[ADD:.*]] = stablehlo.add %[[INPUT1]], %[[CONSTANT2]] : tensor<1024x3xf32> + // CHECK: %[[MUL:.*]] = stablehlo.multiply %[[INPUT1]], %[[INPUT2]] : tensor<1024x3xf32> + // CHECK: return %[[ADD]], %[[MUL]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1024x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1024x3xf32> {tf_saved_model.index_path = ["output1"]}, tensor<1024x3xf32> {tf_saved_model.index_path = ["output2"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %1 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + %2 = "tf.XlaCallModule"(%1, %0) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %3 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> + %4 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %5 = stablehlo.add %3, %4 : tensor<1024x3xf32> + %6 = stablehlo.multiply %3, %0 : tensor<1024x3xf32> + return %5, %6 : tensor<1024x3xf32>, tensor<1024x3xf32> + } + + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: 
%[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_2:.*]]:2 = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_2]]#0, %[[SUBGRAPH_2]]#1 + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-NOT: tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +// Tests where StableHLO graph in main has branches. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func private @_stablehlo_main_1(%[[INPUT:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: %[[CONSTANT1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<3x3xf32> + // CHECK: %[[ADD:.*]] = stablehlo.add %[[CONSTANT1]], %[[INPUT]] : tensor<3x3xf32> + // CHECK: return %[[ADD:.*]] + // CHECK: } + + // CHECK: func private @_stablehlo_main_0 + // CHECK-SAME: (%[[INPUT1:.*]]: tensor<3x3xf32>, %[[INPUT2:.*]]: tensor<3x3xf32>) + // CHECK-SAME: -> tensor<3x3xf32> + // CHECK: %[[CONSTANT2:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<3x3xf32> + // CHECK: %[[ADD:.*]] = stablehlo.add %[[INPUT1]], %[[INPUT2]] : tensor<3x3xf32> + // CHECK: %[[MUL:.*]] = stablehlo.multiply %[[ADD]], %[[CONSTANT2]] : tensor<3x3xf32> + // CHECK: return %[[MUL]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<3x3xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<3x3xf32> {tf_saved_model.index_path = ["output1"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<3x3xf32> + %1 = stablehlo.add %0, %arg0 : tensor<3x3xf32> + %2 = "tf.CustomAggregator"(%1) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<3x3xf32>) -> tensor<3x3xf32> + %3 = "tf.XlaCallModule"(%2, %2) {Sout = [#tf_type.shape<3x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + %4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, 
max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<3x3xf32>) -> tensor<3x3xf32> + %5 = stablehlo.add %4, %1 : tensor<3x3xf32> + %6 = stablehlo.multiply %5, %0 : tensor<3x3xf32> + return %6 : tensor<3x3xf32> + } + + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"(%[[INPUT:.*]]) <{Sout = [#tf_type.shape<3x3>], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[SUBGRAPH_1]]) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[CUSTOM_AGGREGATOR_1]]) <{Sout = [#tf_type.shape<3x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_2:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<3x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_2]] + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-NOT: tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc new file mode 100644 index 00000000000000..8c0c2e5fc06116 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc @@ -0,0 +1,177 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" + +#include +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/test_base.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/platform/test.h" + +namespace mlir::quant::stablehlo { +namespace { + +using ::mlir::quant::common::QuantizationTestBase; + +class IsOpQuantizableStableHloTest : public QuantizationTestBase {}; + +// Quantizable ops: constants +// Non-quantizable ops: normal StableHLO ops and terminators +constexpr absl::string_view module_constant_add = R"mlir( + module { + func.func @constant_add() -> (tensor<3x2xf32>) { + %cst1 = stablehlo.constant dense<2.4> : tensor<3x2xf32> + %cst2 = stablehlo.constant dense<5.7> : tensor<3x2xf32> + %add = stablehlo.add %cst1, %cst2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> + func.return %add : tensor<3x2xf32> + } + } +)mlir"; + +// Quantizable ops: XlaCallModule op with "fully_quantizable" attribute and +// same-scale StableHLO ops +// Non-quantizable ops: quantize/dequantize ops +constexpr absl::string_view module_composite_same_scale = R"mlir( + module { + func.func @same_scale_after_composite() -> tensor<3x1xf32> { + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = stablehlo.reshape %2 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + return %5 : tensor<3x1xf32> + } + } +)mlir"; + +// Non-quantizable ops: XlaCallModule op without "fully_quantizable" attribute +constexpr absl::string_view module_composite_no_attr = R"mlir( + module { + func.func @composite_without_attr() -> tensor<1x3xf32> { + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @non_quantizable_composite, _original_entry_function = "non_quantizable_composite", _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } + } +)mlir"; + +TEST_F(IsOpQuantizableStableHloTest, ConstantOpQuantizable) { + OwningOpRef 
<ModuleOp> module_op_ref = + ParseModuleOpString(module_constant_add); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "constant_add"); + Operation* constant_op = + FindOperationOfType<mlir::stablehlo::ConstantOp>(test_func); + bool is_constant_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(constant_op); + + EXPECT_TRUE(is_constant_quantizable); +} + +TEST_F(IsOpQuantizableStableHloTest, TerminatorOpNotQuantizable) { + OwningOpRef<ModuleOp> module_op_ref = + ParseModuleOpString(module_constant_add); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "constant_add"); + Operation* return_op = FindOperationOfType<func::ReturnOp>(test_func); + bool is_return_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(return_op); + + EXPECT_FALSE(is_return_quantizable); +} + +TEST_F(IsOpQuantizableStableHloTest, SameScaleOpQuantizable) { + OwningOpRef<ModuleOp> module_op_ref = + ParseModuleOpString(module_composite_same_scale); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); + Operation* reshape_op = + FindOperationOfType<mlir::stablehlo::ReshapeOp>(test_func); + bool is_reshape_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(reshape_op); + + EXPECT_TRUE(is_reshape_quantizable); +} + +TEST_F(IsOpQuantizableStableHloTest, NonSameScaleOpNotQuantizable) { + OwningOpRef<ModuleOp> module_op_ref = + ParseModuleOpString(module_constant_add); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "constant_add"); + Operation* add_op = FindOperationOfType<mlir::stablehlo::AddOp>(test_func); + bool is_add_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(add_op); + + EXPECT_FALSE(is_add_quantizable); +} + +TEST_F(IsOpQuantizableStableHloTest, ValidXlaCallModuleOpQuantizable) { + OwningOpRef<ModuleOp> module_op_ref = + ParseModuleOpString(module_composite_same_scale); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); + Operation* xla_call_module_op = + FindOperationOfType<TF::XlaCallModuleOp>(test_func); + bool is_xla_call_module_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(xla_call_module_op); + + EXPECT_TRUE(is_xla_call_module_quantizable); +} + +TEST_F(IsOpQuantizableStableHloTest, InvalidXlaCallModuleOpNotQuantizable) { + OwningOpRef<ModuleOp> module_op_ref = + ParseModuleOpString(module_composite_no_attr); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "composite_without_attr"); + Operation* xla_call_module_op = + FindOperationOfType<TF::XlaCallModuleOp>(test_func); + bool is_xla_call_module_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(xla_call_module_op); + + EXPECT_FALSE(is_xla_call_module_quantizable); +} + +TEST_F(IsOpQuantizableStableHloTest, QuantizeDequantizeOpNotQuantizable) { + OwningOpRef<ModuleOp> module_op_ref = + ParseModuleOpString(module_composite_same_scale); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); + Operation* quantize_op = + FindOperationOfType<quantfork::QuantizeCastOp>(test_func); + Operation* dequantize_op = + FindOperationOfType<quantfork::DequantizeCastOp>(test_func); + bool is_quantize_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(quantize_op); + bool is_dequantize_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(dequantize_op); + + EXPECT_FALSE(is_quantize_quantizable); + EXPECT_FALSE(is_dequantize_quantizable); +} + +} // namespace +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/unwrap_xla_call_module_op.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/unwrap_xla_call_module_op.mlir new file mode 100644 index
00000000000000..dde460411168d0 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/unwrap_xla_call_module_op.mlir @@ -0,0 +1,53 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-unwrap-xla-call-module-op | FileCheck %s + +// Tests if XlaCallModule op without quantizable trait that calls function with +// '_from_xla_call_module' trait is unwrapped. +// Tests if XlaCallModule op with quantizable trait is not unwrapped. +// Tests if XlaCallModule op without quantizable trait that calls function +// without '_from_xla_call_module' trait is not unwrapped. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1682 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: @main_00 + // CHECK: %[[ARG0:.*]]: tensor<10x1x1024xf32> + func.func private @main_00(%arg0: tensor<10x1x1024xf32>) -> tensor<6x5xf32> attributes {tf._original_func_name = "main_0"} { + %0 = "tf.Const"() <{value = dense<1.000000e+00> : tensor<10x1024x3xf32>}> : () -> tensor<10x1024x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<10x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %2 = "tf.XlaCallModule"(%1) <{Sout = [#tf_type.shape<3x10>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @main_0, _stablehlo_module_attrs = {}, device = ""} : (tensor<10x1x3xf32>) -> tensor<3x10xf32> + %3 = "tf.XlaCallModule"(%2) <{Sout = [#tf_type.shape<6x5>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @main_1, _stablehlo_module_attrs = {}, device = ""} : (tensor<3x10xf32>) -> tensor<6x5xf32> + return %3 : tensor<6x5xf32> + } + // CHECK: %[[CST:.*]] = "tf.Const"() + // CHECK-NEXT: %[[CALL1:.*]] = "tf.XlaCallModule"(%[[ARG0]], %[[CST]]) + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1 + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-NOT: "tf.XlaCallModule" + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %[[CALL1]] : (tensor<10x1x3xf32>) -> tensor<3x10xf32> + // CHECK-NEXT: %[[CALL2:.*]] = "tf.XlaCallModule"(%[[RESHAPE]]) + // CHECK-SAME: _entry_function = @main_1 + // CHECK-NOT: _tfl_quant_trait = "fully_quantizable" + // CHECK-NEXT: return %[[CALL2]] + + // CHECK: @composite_dot_general_fn_1 + func.func private @composite_dot_general_fn_1(%arg0: tensor<10x1x1024xf32>, %arg1: tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + return %0 : tensor<10x1x3xf32> + } + // CHECK: %[[DOT:.*]] = stablehlo.dot_general + // CHECK-NEXT: return %[[DOT]] + + // CHECK: @main_0 + func.func private @main_0(%arg0: tensor<10x1x3xf32>) -> tensor<3x10xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.reshape %arg0 : (tensor<10x1x3xf32>) -> tensor<3x10xf32> + return %0 : tensor<3x10xf32> + } + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape 
+ // CHECK-NEXT: return %[[RESHAPE]] + + // CHECK: @main_1 + func.func private @main_1(%arg0: tensor<3x10xf32>) -> tensor<6x5xf32> { + %0 = stablehlo.reshape %arg0 : (tensor<3x10xf32>) -> tensor<6x5xf32> + return %0 : tensor<6x5xf32> + } + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape + // CHECK-NEXT: return %[[RESHAPE]] +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc index 3afc42e21d1f6e..a55b1a88e5d964 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc @@ -13,14 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/InitAllPasses.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" @@ -29,8 +33,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/register.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "tensorflow/core/ir/types/dialect.h" int main(int argc, char** argv) { tensorflow::InitMlir y(&argc, &argv); @@ -39,13 +46,15 @@ int main(int argc, char** argv) { mlir::registerTensorFlowPasses(); mlir::quant::stablehlo::registerPasses(); mlir::quant::stablehlo::registerBridgePasses(); + mlir::stablehlo::registerPasses(); + mlir::mhlo::registerAllMhloPasses(); mlir::DialectRegistry registry; registry.insert(); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc index bfd9de9ca60d25..eecc96b04be9eb 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc @@ -16,20 +16,24 @@ limitations under the License. 
#include +#include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#define DEBUG_TYPE "uniform-quantized-types" + namespace mlir { namespace quant { UniformQuantizedType CreateI8F32UniformQuantizedType(const Location loc, MLIRContext& context, - const float scale, - const int8_t zero_point) { + const double scale, + const int64_t zero_point) { return UniformQuantizedType::getChecked( loc, /*flags=*/QuantizationFlags::Signed, /*storageType=*/IntegerType::get(&context, /*width=*/8), @@ -38,8 +42,8 @@ UniformQuantizedType CreateI8F32UniformQuantizedType(const Location loc, } UniformQuantizedType CreateI32F32UniformQuantizedType( - const Location loc, MLIRContext& context, const float scale, - const int32_t zero_point) { + const Location loc, MLIRContext& context, const double scale, + const int64_t zero_point) { return UniformQuantizedType::getChecked( loc, /*flags=*/QuantizationFlags::Signed, /*storageType=*/IntegerType::get(&context, /*width=*/32), @@ -49,8 +53,8 @@ UniformQuantizedType CreateI32F32UniformQuantizedType( } UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( - const Location loc, MLIRContext& context, const ArrayRef scales, - const ArrayRef zero_points, const int quantization_dimension) { + const Location loc, MLIRContext& context, const ArrayRef scales, + const ArrayRef zero_points, const int quantization_dimension) { return UniformQuantizedPerAxisType::getChecked( loc, /*flags=*/QuantizationFlags::Signed, /*storageType=*/IntegerType::get(&context, /*width=*/8), @@ -60,5 +64,106 @@ UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( /*storageTypeMax=*/llvm::maxIntN(8)); } +bool IsStorageTypeI8(const QuantizedType quantized_type) { + const Type storage_type = quantized_type.getStorageType(); + return storage_type.isInteger(/*width=*/8); +} + +bool IsStorageTypeI32(const QuantizedType quantized_type) { + const Type storage_type = quantized_type.getStorageType(); + return storage_type.isInteger(/*width=*/32); +} + +bool IsExpressedTypeF32(const QuantizedType quantized_type) { + const Type expressed_type = quantized_type.getExpressedType(); + return expressed_type.isa(); +} + +bool IsI8F32UniformQuantizedType(const Type type) { + const UniformQuantizedType quantized_type = + type.dyn_cast_or_null(); + if (!quantized_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI8(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. Got: " + << quantized_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_type << ".\n"); + return false; + } + + return true; +} + +bool IsI8F32UniformQuantizedPerAxisType(const Type type) { + const UniformQuantizedPerAxisType quantized_per_axis_type = + type.dyn_cast_or_null(); + if (!quantized_per_axis_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI8(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. 
Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + return true; +} + +bool IsI32F32UniformQuantizedType(const Type type) { + const UniformQuantizedType quantized_type = + type.dyn_cast_or_null(); + if (!quantized_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI32(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i32 storage type. Got: " + << quantized_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_type << ".\n"); + return false; + } + + return true; +} + +// Determines whether the storage type of a quantized type is supported by +// `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. +bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type) { + if (storage_type.getWidth() == 8 || + (storage_type.isSigned() && storage_type.getWidth() == 16)) { + return true; + } + LLVM_DEBUG(llvm::dbgs() + << "Uniform quantize / dequantize op only supports ui8, i8 or " + "i16 for the storage type of uniform quantized type. Got: " + << storage_type << ".\n"); + return false; +} + } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h index 68774b2ecb876b..d04dc5a5761b8f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h @@ -18,8 +18,10 @@ limitations under the License. #include #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { @@ -32,8 +34,8 @@ namespace quant { // values can be non-zero values. UniformQuantizedType CreateI8F32UniformQuantizedType(Location loc, MLIRContext& context, - float scale, - int8_t zero_point); + double scale, + int64_t zero_point); // Creates a `UniformQuantizedType` with the given `scale` and `zero_point` // values. The produced type has f32 as its expressed type and i32 as its @@ -42,8 +44,8 @@ UniformQuantizedType CreateI8F32UniformQuantizedType(Location loc, // non-zero values. UniformQuantizedType CreateI32F32UniformQuantizedType(Location loc, MLIRContext& context, - float scale, - int32_t zero_point); + double scale, + int64_t zero_point); // Creates a `UniformQuantizedPerAxisType` with the given `scales` and // `zero_points` values. The produced type has f32 as its expressed type and @@ -51,8 +53,30 @@ UniformQuantizedType CreateI32F32UniformQuantizedType(Location loc, // storage value, i.e. [-128, 127]. Assumes asymmetric quantization, meaning the // zero point values can be non-zero values. 
UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( - Location loc, MLIRContext& context, ArrayRef scales, - ArrayRef zero_points, int quantization_dimension); + Location loc, MLIRContext& context, ArrayRef scales, + ArrayRef zero_points, int quantization_dimension); + +bool IsStorageTypeI8(QuantizedType quantized_type); + +bool IsStorageTypeI32(QuantizedType quantized_type); + +bool IsExpressedTypeF32(QuantizedType quantized_type); + +// Returns true iff `type` is a uniform quantized type whose storage type is +// 8-bit integer and expressed type is f32. +bool IsI8F32UniformQuantizedType(Type type); + +// Returns true iff `type` is a uniform quantized per-axis (per-channel) type +// whose storage type is 8-bit integer and expressed type is f32. +bool IsI8F32UniformQuantizedPerAxisType(Type type); + +// Returns true iff `type` is a uniform quantized type whose storage type is +// 32-bit integer and expressed type is f32. +bool IsI32F32UniformQuantizedType(Type type); + +// Determines whether the storage type of a quantized type is supported by +// `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. +bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type); } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc index 0888bfa8d22908..f33b322cfbd9e4 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -30,8 +31,10 @@ namespace quant { namespace { using ::testing::ElementsAreArray; +using ::testing::NotNull; +using ::testing::Test; -class CreateI8F32UniformQuantizedTypeTest : public ::testing::Test { +class CreateI8F32UniformQuantizedTypeTest : public Test { protected: CreateI8F32UniformQuantizedTypeTest() : ctx_() { ctx_.loadDialect(); @@ -40,7 +43,7 @@ class CreateI8F32UniformQuantizedTypeTest : public ::testing::Test { MLIRContext ctx_; }; -TEST_F(CreateI8F32UniformQuantizedTypeTest, HasI8StorageType) { +TEST_F(CreateI8F32UniformQuantizedTypeTest, I8StorageTypeSucceeds) { const UniformQuantizedType quantized_type = CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -48,7 +51,7 @@ TEST_F(CreateI8F32UniformQuantizedTypeTest, HasI8StorageType) { EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(8)); } -TEST_F(CreateI8F32UniformQuantizedTypeTest, HasF32ExpressedType) { +TEST_F(CreateI8F32UniformQuantizedTypeTest, F32ExpressedTypeSucceeds) { const UniformQuantizedType quantized_type = CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -56,7 +59,7 @@ TEST_F(CreateI8F32UniformQuantizedTypeTest, HasF32ExpressedType) { EXPECT_TRUE(quantized_type.getExpressedType().isF32()); } -TEST_F(CreateI8F32UniformQuantizedTypeTest, IsSigned) { +TEST_F(CreateI8F32UniformQuantizedTypeTest, SignedQuantizedTypeSucceeds) { const UniformQuantizedType quantized_type = 
CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -64,7 +67,7 @@ TEST_F(CreateI8F32UniformQuantizedTypeTest, IsSigned) { EXPECT_TRUE(quantized_type.isSigned()); } -TEST_F(CreateI8F32UniformQuantizedTypeTest, SotrageTypeMinMaxEqualToI8MinMax) { +TEST_F(CreateI8F32UniformQuantizedTypeTest, StorageTypeMinMaxEqualToI8MinMax) { const UniformQuantizedType quantized_type = CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -82,7 +85,7 @@ TEST_F(CreateI8F32UniformQuantizedTypeTest, HasScaleAndZeroPointProperlySet) { EXPECT_EQ(quantized_type.getZeroPoint(), 99); } -class CreateI32F32UniformQuantizedTypeTest : public ::testing::Test { +class CreateI32F32UniformQuantizedTypeTest : public Test { protected: CreateI32F32UniformQuantizedTypeTest() : ctx_() { ctx_.loadDialect(); @@ -91,7 +94,7 @@ class CreateI32F32UniformQuantizedTypeTest : public ::testing::Test { MLIRContext ctx_; }; -TEST_F(CreateI32F32UniformQuantizedTypeTest, HasI32StorageType) { +TEST_F(CreateI32F32UniformQuantizedTypeTest, I32StorageTypeSucceeds) { const UniformQuantizedType quantized_type = CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -99,7 +102,7 @@ TEST_F(CreateI32F32UniformQuantizedTypeTest, HasI32StorageType) { EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(32)); } -TEST_F(CreateI32F32UniformQuantizedTypeTest, HasF32ExpressedType) { +TEST_F(CreateI32F32UniformQuantizedTypeTest, F32ExpressedTypeSucceeds) { const UniformQuantizedType quantized_type = CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -107,7 +110,7 @@ TEST_F(CreateI32F32UniformQuantizedTypeTest, HasF32ExpressedType) { EXPECT_TRUE(quantized_type.getExpressedType().isF32()); } -TEST_F(CreateI32F32UniformQuantizedTypeTest, IsSigned) { +TEST_F(CreateI32F32UniformQuantizedTypeTest, SignedQuantizedTypeSucceeds) { const UniformQuantizedType quantized_type = CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -116,7 +119,7 @@ TEST_F(CreateI32F32UniformQuantizedTypeTest, IsSigned) { } TEST_F(CreateI32F32UniformQuantizedTypeTest, - SotrageTypeMinMaxEqualToI32MinMax) { + StorageTypeMinMaxEqualToI32MinMax) { const UniformQuantizedType quantized_type = CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -136,7 +139,7 @@ TEST_F(CreateI32F32UniformQuantizedTypeTest, HasScaleAndZeroPointProperlySet) { EXPECT_EQ(quantized_type.getZeroPoint(), 1111); } -class CreateI8F32UniformQuantizedPerAxisTypeTest : public ::testing::Test { +class CreateI8F32UniformQuantizedPerAxisTypeTest : public Test { protected: CreateI8F32UniformQuantizedPerAxisTypeTest() : ctx_() { ctx_.loadDialect(); @@ -145,34 +148,35 @@ class CreateI8F32UniformQuantizedPerAxisTypeTest : public ::testing::Test { MLIRContext ctx_; }; -TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, HasI8StorageType) { +TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, I8StorageTypeSucceeds) { const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, /*quantization_dimension=*/0); EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(8)); } -TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, HasF32ExpressedType) { 
+TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, F32ExpressedTypeSucceeds) { const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, /*quantization_dimension=*/0); EXPECT_TRUE(quantized_type.getExpressedType().isF32()); } -TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, IsSigned) { +TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, + SignedQuantizedTypeSucceeds) { const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, /*quantization_dimension=*/0); EXPECT_TRUE(quantized_type.isSigned()); @@ -183,8 +187,8 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, /*quantization_dimension=*/0); EXPECT_EQ(quantized_type.getStorageTypeMin(), -128); @@ -196,8 +200,8 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{1.0, 1.0}, - /*zero_points=*/SmallVector{0, 0}, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, /*quantization_dimension=*/3); EXPECT_EQ(quantized_type.getQuantizedDimension(), 3); @@ -208,14 +212,182 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, const UniformQuantizedPerAxisType quantized_type = CreateI8F32UniformQuantizedPerAxisType( UnknownLoc::get(&ctx_), ctx_, - /*scales=*/SmallVector{8.0, 9.0}, - /*zero_points=*/SmallVector{98, 99}, + /*scales=*/SmallVector{8.0, 9.0}, + /*zero_points=*/SmallVector{98, 99}, /*quantization_dimension=*/0); EXPECT_THAT(quantized_type.getScales(), ElementsAreArray({8.0, 9.0})); EXPECT_THAT(quantized_type.getZeroPoints(), ElementsAreArray({98, 99})); } +class IsI8F32UniformQuantizedTypeTest : public Test { + protected: + IsI8F32UniformQuantizedTypeTest() { + ctx_.loadDialect(); + } + + MLIRContext ctx_; + OpBuilder builder_{&ctx_}; +}; + +TEST_F(IsI8F32UniformQuantizedTypeTest, I8F32UniformQuantizedTypeSucceeds) { + const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsI8F32UniformQuantizedType(qi8_type)); +} + +TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedTypeSucceeds) { + const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_THAT(qi8_type.dyn_cast_or_null(), NotNull()); +} + +TEST_F(IsI8F32UniformQuantizedTypeTest, StorageTypeI8Succeeds) { + const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsStorageTypeI8(qi8_type)); +} + +TEST_F(IsI8F32UniformQuantizedTypeTest, 
ExpressedTypeF32Succeeds) { + const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsExpressedTypeF32(qi8_type)); +} + +class IsI8F32UniformQuantizedPerAxisTypeTest : public Test { + protected: + IsI8F32UniformQuantizedPerAxisTypeTest() { + ctx_.loadDialect(); + } + + MLIRContext ctx_; + OpBuilder builder_{&ctx_}; +}; + +TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, + I8F32UniformQuantizedPerAxisTypeSucceeds) { + const UniformQuantizedPerAxisType qi8_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, + /*storageTypeMax=*/255); + EXPECT_TRUE(IsI8F32UniformQuantizedPerAxisType(qi8_per_axis_type)); + EXPECT_FALSE(IsI8F32UniformQuantizedType(qi8_per_axis_type)); +} + +TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedPerAxisTypeSucceeds) { + const UniformQuantizedPerAxisType qi8_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, + /*storageTypeMax=*/255); + EXPECT_THAT(qi8_per_axis_type.dyn_cast_or_null(), + NotNull()); +} + +TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, StorageTypeI8Succeeds) { + const UniformQuantizedPerAxisType qi8_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, + /*storageTypeMax=*/255); + EXPECT_TRUE(IsStorageTypeI8(qi8_per_axis_type)); +} + +TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, ExpressedTypeF32Succeeds) { + const UniformQuantizedPerAxisType qi8_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, + /*storageTypeMax=*/255); + EXPECT_TRUE(IsExpressedTypeF32(qi8_per_axis_type)); +} + +class IsI32F32UniformQuantizedTypeTest : public Test { + protected: + IsI32F32UniformQuantizedTypeTest() { + ctx_.loadDialect(); + } + + MLIRContext ctx_; + OpBuilder builder_{&ctx_}; +}; + +TEST_F(IsI32F32UniformQuantizedTypeTest, I32F32UniformQuantizedTypeSucceeds) { + const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI32Type(), builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsI32F32UniformQuantizedType(qi32_type)); +} + +TEST_F(IsI32F32UniformQuantizedTypeTest, UniformQuantizedTypeSucceeds) { + const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_THAT(qi32_type.dyn_cast_or_null(), NotNull()); +} + +TEST_F(IsI32F32UniformQuantizedTypeTest, StorageTypeI32Succeeds) { + const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI32Type(), builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsStorageTypeI32(qi32_type)); +} + +TEST_F(IsI32F32UniformQuantizedTypeTest, ExpressedTypeF32Succeeds) { + 
const UniformQuantizedType qi32_per_axis_type = + quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsExpressedTypeF32(qi32_per_axis_type)); +} + +class IsSupportedByTfliteQuantizeOrDequantizeOpsTest : public Test { + protected: + IsSupportedByTfliteQuantizeOrDequantizeOpsTest() { + ctx_.loadDialect(); + } + + MLIRContext ctx_; + OpBuilder builder_{&ctx_}; +}; + +TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, StorageTypeI8Succeeds) { + auto qi8_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getIntegerType(8, /*isSigned=*/true), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsSupportedByTfliteQuantizeOrDequantizeOps( + dyn_cast_or_null(qi8_type.getStorageType()))); +} + +TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, StorageTypeI16Succeeds) { + auto qi16_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getIntegerType(16, /*isSigned=*/true), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsSupportedByTfliteQuantizeOrDequantizeOps( + dyn_cast_or_null(qi16_type.getStorageType()))); +} + +TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, StorageTypeUI8Succeeds) { + auto qi8_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getIntegerType(8, /*isSigned=*/false), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsSupportedByTfliteQuantizeOrDequantizeOps( + dyn_cast_or_null(qi8_type.getStorageType()))); +} + } // namespace } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils_test.cc index 4dcdb637e1b430..a864ee556ff5af 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils_test.cc @@ -24,14 +24,23 @@ limitations under the License. 
namespace mlir::quant::stablehlo { namespace { -TEST(UtilsTest, IsStablehloOp) { - MLIRContext ctx; - OpBuilder b(&ctx); - ctx.loadDialect(); +using ::testing::Test; +class StablehloTypeUtilsTest : public Test { + protected: + StablehloTypeUtilsTest() { + ctx_.loadDialect(); + } + + MLIRContext ctx_; + OpBuilder builder_{&ctx_}; +}; + +TEST_F(StablehloTypeUtilsTest, ValidStablehloOpSucceeds) { mlir::stablehlo::ConstantOp constant_op = - b.create(b.getUnknownLoc(), - b.getI32IntegerAttr(0)); + builder_.create( + builder_.getUnknownLoc(), builder_.getI32IntegerAttr(0)); EXPECT_TRUE(IsStablehloOp(constant_op)); constant_op->erase(); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc index 03495d3ddae7aa..87d71438cf4e7c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc @@ -90,7 +90,7 @@ std::unique_ptr CreateContext() { return context; } -TEST(GetDenseAttrFromTensorProtoAttrTest, Qint8ToUQ8) { +TEST(GetDenseAttrFromTensorProtoAttrTest, Qint8ToUQ8Succeeds) { auto context = CreateContext(); TensorType result_tensor_type = RankedTensorType::get( {2, 2}, quant::UniformQuantizedType::get( @@ -109,7 +109,7 @@ TEST(GetDenseAttrFromTensorProtoAttrTest, Qint8ToUQ8) { EXPECT_EQ(dense_attr->getValues()[3], 4); } -TEST(GetDenseAttrFromTensorProtoAttrTest, Qint8ToInt8) { +TEST(GetDenseAttrFromTensorProtoAttrTest, Qint8ToInt8Succeeds) { auto context = CreateContext(); TensorType result_tensor_type = RankedTensorType::get({2, 2}, IntegerType::get(context.get(), 8)); @@ -125,7 +125,7 @@ TEST(GetDenseAttrFromTensorProtoAttrTest, Qint8ToInt8) { EXPECT_EQ(dense_attr->getValues()[3], 4); } -TEST(GetDenseAttrFromTensorProtoAttrTest, Qint32ToUQ32) { +TEST(GetDenseAttrFromTensorProtoAttrTest, Qint32ToUQ32Succeeds) { auto context = CreateContext(); TensorType result_tensor_type = RankedTensorType::get( {2, 2}, @@ -145,7 +145,7 @@ TEST(GetDenseAttrFromTensorProtoAttrTest, Qint32ToUQ32) { EXPECT_EQ(dense_attr->getValues()[3], 4); } -TEST(GetDenseAttrFromTensorProtoAttrTest, Qint32ToInt32) { +TEST(GetDenseAttrFromTensorProtoAttrTest, Qint32ToInt32Succeeds) { auto context = CreateContext(); TensorType result_tensor_type = RankedTensorType::get({2, 2}, IntegerType::get(context.get(), 32)); @@ -161,7 +161,7 @@ TEST(GetDenseAttrFromTensorProtoAttrTest, Qint32ToInt32) { EXPECT_EQ(dense_attr->getValues()[3], 4); } -TEST(GetDenseAttrFromTensorProtoAttrTest, UnsupportedQint16) { +TEST(GetDenseAttrFromTensorProtoAttrTest, UnsupportedQint16Fails) { auto context = CreateContext(); TensorType result_tensor_type = RankedTensorType::get({2, 2}, IntegerType::get(context.get(), 16)); @@ -170,7 +170,7 @@ TEST(GetDenseAttrFromTensorProtoAttrTest, UnsupportedQint16) { GetDenseAttrFromTensorProtoAttr(GetQint16Tensor(), result_tensor_type))); } -TEST(IsTFQintTypeTest, IsTFQintType) { +TEST(IsTFQintTypeTest, ValidTFQintTypeSucceeds) { auto context = CreateContext(); EXPECT_TRUE(IsTFQintType(TF::Qint8Type::get(context.get()))); @@ -183,7 +183,7 @@ TEST(IsTFQintTypeTest, IsTFQintType) { EXPECT_FALSE(IsTFQintType(TF::Float8E5M2RefType::get(context.get()))); } -TEST(GetIntTypeFromTFQintTest, GetIntTypeFromTFQint) { +TEST(GetIntTypeFromTFQintTest, ChecksIntTypesFromTFQint) { auto context = CreateContext(); auto type = GetIntTypeFromTFQint(TF::Qint8Type::get(context.get())); diff --git 
a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index f5c170122977ba..c973a4fed16bb0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -51,24 +51,6 @@ genrule( tools = ["gen_quantized_function_library"], ) -cc_library( - name = "pass_utils", - srcs = [ - "passes/utils.cc", - ], - hdrs = [ - "passes/utils.h", - ], - compatible_with = get_compatible_with_portable(), - deps = [ - ":quantization_options_proto_cc", - "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", - "//tensorflow/compiler/mlir/tensorflow", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - ], -) - cc_library( name = "manipulate_model_attr", srcs = [ @@ -117,12 +99,11 @@ td_library( "passes/quantize_composite_functions.td", "passes/replace_cast_hacks_with_tf_xla_ops.td", "passes/tf_quant_ops.td", - "passes/utils.td", ], compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", - "//tensorflow/compiler/mlir/quantization/tensorflow/utils:lift_as_function_call_utils_td_files", + "//tensorflow/compiler/mlir/quantization/common:quant_td_files", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", "@llvm-project//mlir:ArithOpsTdFiles", "@llvm-project//mlir:FuncTdFiles", @@ -411,7 +392,6 @@ cc_library( ":lift_quantizable_spots_as_functions_inc_gen", ":manipulate_model_attr", ":optimize_inc_gen", - ":pass_utils", ":post_quantize_inc_gen", ":prepare_lifting_inc_gen", ":prepare_quantize_inc_gen", @@ -425,6 +405,8 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:const_op_size", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:constant_fold", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:quantization_unit_loc", @@ -432,7 +414,6 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow/ops:tf_op_quant_spec", "//tensorflow/compiler/mlir/quantization/tensorflow/ops:tf_quantize_op", "//tensorflow/compiler/mlir/quantization/tensorflow/utils:fake_quant_utils", - "//tensorflow/compiler/mlir/quantization/tensorflow/utils:lift_as_function_call_utils", "//tensorflow/compiler/mlir/quantization/tensorflow/utils:tf_to_uniform_attribute_utils", "//tensorflow/compiler/mlir/quantization/tensorflow/utils:tf_to_xla_attribute_utils", "//tensorflow/compiler/mlir/tensorflow", @@ -543,6 +524,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo:mhlo_passes", ], @@ -616,5 +598,6 @@ tf_cc_binary( "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:ShapeDialect", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD index aec612c95b7b62..34260f6e75e1c4 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD @@ -12,6 
+12,7 @@ load( "get_compatible_with_portable", "tf_kernel_library", "tf_py_strict_test", + "tf_python_pybind_extension", ) load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load( @@ -35,7 +36,6 @@ cc_library( srcs = ["calibrator_singleton.cc"], hdrs = ["calibrator_singleton.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:private"], deps = [ ":calibration_statistics_collector_average_min_max", ":calibration_statistics_collector_base", @@ -223,9 +223,9 @@ tf_py_strict_test( deps = [ ":calibration_statistics_proto_py", ":gen_custom_aggregator_op_wrapper", + ":pywrap_calibration", "//tensorflow:tensorflow_py", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", - "//tensorflow/compiler/mlir/quantization/tensorflow/python:pywrap_quantize_model", "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", @@ -233,3 +233,17 @@ tf_py_strict_test( "//tensorflow/python/platform:client_testlib", ], ) + +tf_python_pybind_extension( + name = "pywrap_calibration", + srcs = ["pywrap_calibration.cc"], + pytype_srcs = ["pywrap_calibration.pyi"], + deps = [ + ":calibration_statistics_proto_cc", + ":calibrator_singleton", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@pybind11", + "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc index d3b0475a3dbc74..95e89a7c573c91 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" +#include #include #include #include @@ -107,6 +108,11 @@ std::optional CalibratorSingleton::GetStatistics( return instance.id_to_collector_[id_str]->GetStatistics(); } +int64_t CalibratorSingleton::IssueNewId() { + CalibratorSingleton& instance = GetInstance(); + return instance.next_id_++; +} + void CalibratorSingleton::AssignIfNotExists( std::string id_str, const CalibrationOptions& calib_opts) { CalibratorSingleton& instance = GetInstance(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h index 138352bfcf3d53..38432b01a5a3da 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATOR_SINGLETON_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATOR_SINGLETON_H_ +#include +#include #include #include #include @@ -35,6 +37,7 @@ namespace calibrator { using tensorflow::quantization::CalibrationOptions; +// TODO: b/315084876 - Move to stablehlo quantizer directory. class CalibratorSingleton { public: // Clears the collected information. 
@@ -65,12 +68,20 @@ class CalibratorSingleton { static std::optional GetStatistics( absl::string_view id); + // Issues a new node ID that uniquely identifies a set of calibration + // statistics. + static int64_t IssueNewId(); + private: static CalibratorSingleton& GetInstance(); static absl::Mutex lock_; static void AssignIfNotExists(std::string id_str, const CalibrationOptions& calib_opts); + // Indicates the next id for a set of calibration statistics. For every new ID + // issued this will be incremented atomically. + std::atomic next_id_{0}; + absl::flat_hash_map> id_to_collector_; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc index d6e85c33da8c76..d58dbb838be792 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" +#include #include #include @@ -201,6 +202,12 @@ TEST(CalibratorSingletonTest, SimpleAverageMinMax) { EXPECT_EQ(statistics.value().average_min_max_statistics().num_samples(), 3); } +TEST(CalibratorSingletonTest, IssueNewIdGeneratesNewId) { + const int64_t id = CalibratorSingleton::IssueNewId(); + const int64_t next_id = CalibratorSingleton::IssueNewId(); + EXPECT_NE(id, next_id); +} + } // namespace } // namespace calibrator } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py index a9d1ccacbb7533..5818017a155b58 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py @@ -16,16 +16,15 @@ import tensorflow # pylint: disable=unused-import -# pylint: disable=invalid-import-order,g-bad-import-order -from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 as calib_stat_pb2 from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import custom_aggregator_op_wrapper -from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_quantize_model as quantize_model_wrapper +from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import pywrap_calibration +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2 -from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 as calib_stat_pb2 _CalibrationMethod = quant_opts_pb2.CalibrationOptions.CalibrationMethod @@ -37,8 +36,8 @@ def setUp(self): 
ops.disable_eager_execution() def testBypassAndMinMax(self): - with self.test_session(): - quantize_model_wrapper.clear_calibrator() + with self.session(): + pywrap_calibration.clear_calibrator() input_tensor = array_ops.constant( [1.0, 2.0, 3.0, 4.0, 5.0], dtypes.float32 ) @@ -51,7 +50,7 @@ def testBypassAndMinMax(self): self.assertAllEqual(self.evaluate(aggregator), [1.0, 2.0, 3.0, 4.0, 5.0]) statistics: calib_stat_pb2.CalibrationStatistics = ( - quantize_model_wrapper.get_statistics_from_calibrator('1') + pywrap_calibration.get_statistics_from_calibrator('1') ) min_val = statistics.min_max_statistics.global_min @@ -60,8 +59,8 @@ def testBypassAndMinMax(self): self.assertAllEqual((min_val, max_val), (1.0, 5.0)) def testTwoIdentities(self): - with self.test_session(): - quantize_model_wrapper.clear_calibrator() + with self.session(): + pywrap_calibration.clear_calibrator() input_tensor1 = array_ops.constant( [1.0, 2.0, 3.0, 4.0, 5.0], dtypes.float32 ) @@ -84,21 +83,21 @@ def testTwoIdentities(self): ) statistics: calib_stat_pb2 = ( - quantize_model_wrapper.get_statistics_from_calibrator('2') + pywrap_calibration.get_statistics_from_calibrator('2') ) min_val = statistics.min_max_statistics.global_min max_val = statistics.min_max_statistics.global_max self.assertAllEqual((min_val, max_val), (1.0, 5.0)) statistics: calib_stat_pb2 = ( - quantize_model_wrapper.get_statistics_from_calibrator('3') + pywrap_calibration.get_statistics_from_calibrator('3') ) min_val = statistics.min_max_statistics.global_min max_val = statistics.min_max_statistics.global_max self.assertAllEqual((min_val, max_val), (-5.0, -1.0)) def testClearData(self): - with self.test_session(): - quantize_model_wrapper.clear_calibrator() + with self.session(): + pywrap_calibration.clear_calibrator() input_tensor1 = array_ops.constant( [1.0, 2.0, 3.0, 4.0, 5.0], dtypes.float32 ) @@ -121,33 +120,33 @@ def testClearData(self): ) statistics: calib_stat_pb2 = ( - quantize_model_wrapper.get_statistics_from_calibrator('4') + pywrap_calibration.get_statistics_from_calibrator('4') ) min_val = statistics.min_max_statistics.global_min max_val = statistics.min_max_statistics.global_max self.assertAllEqual((min_val, max_val), (1.0, 5.0)) statistics: calib_stat_pb2 = ( - quantize_model_wrapper.get_statistics_from_calibrator('5') + pywrap_calibration.get_statistics_from_calibrator('5') ) min_val = statistics.min_max_statistics.global_min max_val = statistics.min_max_statistics.global_max self.assertAllEqual((min_val, max_val), (-5.0, -1.0)) - quantize_model_wrapper.clear_data_from_calibrator('4') + pywrap_calibration.clear_data_from_calibrator('4') with self.assertRaises(ValueError): - quantize_model_wrapper.get_statistics_from_calibrator('4') + pywrap_calibration.get_statistics_from_calibrator('4') statistics: calib_stat_pb2 = ( - quantize_model_wrapper.get_statistics_from_calibrator('5') + pywrap_calibration.get_statistics_from_calibrator('5') ) min_val = statistics.min_max_statistics.global_min max_val = statistics.min_max_statistics.global_max self.assertAllEqual((min_val, max_val), (-5.0, -1.0)) def testBypassAndAverageMinMax(self): - with self.test_session(): - quantize_model_wrapper.clear_calibrator() + with self.session(): + pywrap_calibration.clear_calibrator() input_tensor1 = array_ops.constant( [-50.0, -25.0, 0.0, 25.0, 50.0], dtypes.float32 ) @@ -173,7 +172,7 @@ def testBypassAndAverageMinMax(self): ) statistics: calib_stat_pb2 = ( - quantize_model_wrapper.get_statistics_from_calibrator('6') + 
pywrap_calibration.get_statistics_from_calibrator('6') ) min_sum = statistics.average_min_max_statistics.min_sum diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/pywrap_calibration.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/pywrap_calibration.cc new file mode 100644 index 00000000000000..8f7c4e30457a2e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/pywrap_calibration.cc @@ -0,0 +1,91 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "pybind11/pybind11.h" // from @pybind11 +#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" + +namespace py = ::pybind11; + +namespace { + +using ::tensorflow::calibrator::CalibrationStatistics; +using ::tensorflow::calibrator::CalibratorSingleton; + +// Retrieves collected statistics of a `CustomAggregator` node from the +// singleton. `id` is the identifier of the `CustomAggregator`. +CalibrationStatistics GetStatisticsFromCalibrator(const absl::string_view id) { + std::optional statistics = + CalibratorSingleton::GetStatistics(id); + + if (!statistics.has_value()) { + throw py::value_error(absl::StrFormat( + "Calibrated data does not exist. Cannot find statistics." + "value for id: '%s'", + id)); + } + + return *statistics; +} + +} // namespace + +PYBIND11_MODULE(pywrap_calibration, m) { + // Allows type casting protobuf objects. + pybind11_protobuf::ImportNativeProtoCasters(); + + m.doc() = "Defines functions for interacting with CalibratorSingleton."; + + m.def( + // If the function signature changes, likely its corresponding .pyi type + // hinting should also change. + // LINT.IfChange + "clear_calibrator", + []() -> void + // LINT.ThenChange(pywrap_calibration.pyi:clear_calibrator) + { CalibratorSingleton::ClearCollectedInformation(); }, + R"pbdoc( + Clears the collected metrics from the calibrator. + )pbdoc"); + m.def( + // If the function signature changes, likely its corresponding .pyi type + // hinting should also change. + // LINT.IfChange + "clear_data_from_calibrator", + [](const absl::string_view id) -> void + // LINT.ThenChange(pywrap_calibration.pyi:clear_data_from_calibrator) + { CalibratorSingleton::ClearData(id); }, + R"pbdoc( + Clears the collected data of the given id from calibrator. + )pbdoc", + py::arg("id")); + m.def( + // If the function signature changes, likely its corresponding .pyi type + // hinting should also change. 
+ // LINT.IfChange + "get_statistics_from_calibrator", + [](const absl::string_view id) -> CalibrationStatistics { + // LINT.ThenChange(pywrap_calibration.pyi:get_statistics_from_calibrator) + return GetStatisticsFromCalibrator(id); + }, + R"pbdoc( + Returns the proto CalibrationStatistics given id from calibrator. + )pbdoc", + py::arg("id")); +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/pywrap_calibration.pyi b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/pywrap_calibration.pyi new file mode 100644 index 00000000000000..5d859fee947364 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/pywrap_calibration.pyi @@ -0,0 +1,32 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 + +# LINT.IfChange(clear_calibrator) +def clear_calibrator() -> None: ... + +# LINT.ThenChange() + +# LINT.IfChange(clear_data_from_calibrator) +def clear_data_from_calibrator(id: bytes) -> None: ... + +# LINT.ThenChange() + +# LINT.IfChange(get_statistics_from_calibrator) +def get_statistics_from_calibrator( + id: bytes, +) -> calibration_statistics_pb2.CalibrationStatistics: ... + +# LINT.ThenChange() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD index 635f71ec59fb6e..574eb7be350d4e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD @@ -12,8 +12,7 @@ package( # By default, these targets should only be used within the quantization library. 
default_visibility = [ "//learning/brain/mlir/quantization:__subpackages__", - "//tensorflow/compiler/mlir/quantization/stablehlo:__subpackages__", - "//tensorflow/compiler/mlir/quantization/tensorflow:__subpackages__", + "//tensorflow/compiler/mlir/quantization:__subpackages__", ], licenses = ["notice"], ) @@ -126,27 +125,6 @@ tf_cc_test( ], ) -cc_library( - name = "status_macro", - hdrs = ["status_macro.h"], - compatible_with = get_compatible_with_portable(), - deps = [ - "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:macros", - ], -) - -tf_cc_test( - name = "status_macro_test", - srcs = ["status_macro_test.cc"], - deps = [ - ":status_macro", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/status", - ], -) - cc_library( name = "run_passes", srcs = ["run_passes.cc"], @@ -157,6 +135,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:error_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@local_tsl//tsl/platform:statusor", @@ -173,7 +152,7 @@ cc_library( ], compatible_with = get_compatible_with_portable(), deps = [ - "//tensorflow/compiler/mlir/quantization/tensorflow/utils:lift_as_function_call_utils", + "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:constant_fold_utils", "@com_google_absl//absl/container:flat_hash_set", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.cc index 64d89dad2e27f0..565adebfe52300 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.cc @@ -19,7 +19,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/status_macro.h b/tensorflow/compiler/mlir/quantization/tensorflow/cc/status_macro.h deleted file mode 100644 index 5dc784dc8a67c8..00000000000000 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/status_macro.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_STATUS_MACRO_H_ -#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_STATUS_MACRO_H_ - -#include "tsl/platform/macros.h" - -namespace tensorflow { -namespace quantization { - -// Similar to TF_RETURN_IF_ERROR but used for `absl::Status`. -#define TF_QUANT_RETURN_IF_ERROR(expr) \ - do { \ - ::absl::Status _status = (expr); \ - if (TF_PREDICT_FALSE(!_status.ok())) return _status; \ - } while (0) - -} // namespace quantization -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_STATUS_MACRO_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/status_macro_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/status_macro_test.cc deleted file mode 100644 index 1e9de6b43d74ed..00000000000000 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/status_macro_test.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/status_macro.h" - -#include "absl/status/status.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace quantization { -namespace { - -using ::testing::Eq; - -TEST(TfQuantReturnIfErrorTest, DoesNotReturnIfOk) { - const auto returned_status = []() -> absl::Status { - TF_QUANT_RETURN_IF_ERROR(absl::OkStatus()); - return absl::InternalError("Expected"); - }(); - - EXPECT_THAT(returned_status.message(), Eq("Expected")); -} - -TEST(TfQuantReturnIfErrorTest, ReturnsIfOk) { - const auto returned_status = []() -> absl::Status { - TF_QUANT_RETURN_IF_ERROR(absl::InternalError("Expected")); - return absl::OkStatus(); - }(); - - EXPECT_THAT(returned_status.message(), Eq("Expected")); -} - -} // namespace -} // namespace quantization -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD index d556e09ee9bba2..fa201ff6a716bc 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -46,7 +46,6 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", - "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow/utils:tf_quantize_op_utils", 
"//tensorflow/compiler/mlir/tensorflow", @@ -65,7 +64,7 @@ tf_cc_test( srcs = ["tf_quantize_op_test.cc"], deps = [ ":tf_quantize_op", - "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op_test.cc index 971237c5175eb7..6fea7f1cc4778a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op_test.cc @@ -27,7 +27,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc index 8ab909ba432231..4a205648a777e6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc @@ -22,7 +22,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.td index ace1a77e6f32ae..80c65560aa1421 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.td @@ -17,8 +17,8 @@ include "mlir/IR/OpBase.td" include "mlir/IR/PatternBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" //===----------------------------------------------------------------------===// // Pattern rules for converting bfloat16 operations to fp32 conversions. 
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.cc index 994ebea795d079..d23a0f8d3a7af2 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.cc @@ -39,7 +39,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.td index c2046a3fd70d47..2e6e92ba467fda 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.td @@ -17,9 +17,8 @@ include "mlir/IR/OpBase.td" include "mlir/IR/PatternBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/Dialect/Arith/IR/ArithOps.td" -include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" -include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td" // Only handles the case where precision config is default. def IsPrecisionEmpty : diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.td index 945f992188642f..9d39d89c42ae53 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.td @@ -15,7 +15,7 @@ limitations under the License. include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "mlir/IR/OpBase.td" include "mlir/IR/PatternBase.td" -include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" // Combines the two variadic arguments ($in_tensors and $captured_tensors). def GetBatchFunctionOpArgOperands: diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc index f4994cd9c4eaea..68014ebec46605 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc @@ -32,10 +32,10 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -196,37 +196,54 @@ class AddCustomAggregationOp : public RewritePattern { // Return early if the given operator is the custom aggregator op. if (dyn_cast_or_null(op)) return failure(); - // Return early if the given op is a non-quantizable op. - auto call_op = dyn_cast_or_null(op); - if (call_op && !op->hasAttr(kQuantTraitAttrName)) { - return failure(); - } - - bool mutated = false; - for (Value input : op->getOperands()) { - Type element_type = getElementTypeOrSelf(input.getType()); - // Non-float cases won't be calibrated. - if (!element_type.isF32()) { - continue; - } - - // Skip when the given operator is under the quantizable spot. - if (IsInLiftedFunc(op)) { - continue; - } - - // Skip when there is any already existing CustomAggregatorOp found. - Operation *defining_op = input.getDefiningOp(); - if (dyn_cast_or_null(defining_op)) { - continue; + // The CustomAggregatorOp is only added after quantizable values. + SmallVector quantizable_values; + if (isCallToLiftedFunction(op)) { + // Quantize inputs of quantizable composite functions. + for (Value input : op->getOperands()) { + Type element_type = getElementTypeOrSelf(input.getType()); + // Non-float cases won't be calibrated. + if (!element_type.isF32()) { + continue; + } + + // Skip when there is any already existing CustomAggregatorOp found. + Operation *defining_op = input.getDefiningOp(); + if (dyn_cast_or_null(defining_op)) { + continue; + } + + // Skip calibration when the given operand comes from a constant. + if (defining_op != nullptr && + defining_op->hasTrait()) { + continue; + } + + quantizable_values.push_back(input); } - - // Skip calibration when the given operand comes from a constant. - if (defining_op != nullptr && - defining_op->hasTrait()) { - continue; + } else { + // Quantize output of fully quantizable composite functions. + for (Value input : op->getOperands()) { + auto defining_op = input.getDefiningOp(); + if (!isCallToLiftedFunction(defining_op)) { + continue; + } + + // Do not add CustomAggregatorOp after Gather since it is a weight-only + // quantizable op. + if (auto call_op = + dyn_cast_or_null(defining_op)) { + StringRef function_name = + call_op.getFAttr().cast().getValue(); + if (function_name.contains("gather")) continue; + } + + quantizable_values.push_back(input); } + } + if (quantizable_values.empty()) return failure(); + for (Value value : quantizable_values) { // ID attribute will have empty value for now. SmallVector attributes{ rewriter.getNamedAttr("id", rewriter.getStringAttr("")), @@ -248,24 +265,32 @@ class AddCustomAggregationOp : public RewritePattern { }; // Insert custom aggregation op between operand and operator. 
- rewriter.setInsertionPointAfterValue(input); + rewriter.setInsertionPointAfterValue(value); Operation *aggregator_op = rewriter.create( - op->getLoc(), input.getType(), input, attributes); + op->getLoc(), value.getType(), value, attributes); Value aggregator_op_result = aggregator_op->getOpResult(0); - input.replaceAllUsesWith(aggregator_op_result); - aggregator_op->replaceUsesOfWith(aggregator_op_result, input); - - // Mark mutated. - mutated = true; + value.replaceAllUsesWith(aggregator_op_result); + aggregator_op->replaceUsesOfWith(aggregator_op_result, value); } - // Return failure when there is no matching operand. - return mutated ? success() : failure(); + return success(); } private: CalibrationOptions calib_opts_; + + // Whether the op is a call op to lifted composite function. + bool isCallToLiftedFunction(Operation *op) const { + if (!op) return false; + if (isa(op)) return true; + + TF::PartitionedCallOp call_op = dyn_cast_or_null(op); + return call_op && call_op->hasAttrOfType(kQuantTraitAttrName) && + call_op->getAttrOfType(kQuantTraitAttrName) + .getValue() + .equals(QuantTraitValues[QuantizationTrait::FullyQuantizable]); + } }; void InsertCustomAggregationOpsPass::runOnOperation() { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc index 0e6ce592ea0b8e..b471b7910d0eef 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc index 38eabad77a9052..1f94cdfff15754 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -38,12 +39,12 @@ limitations under the License. 
#include "re2/re2.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -51,14 +52,12 @@ namespace mlir { namespace quant { namespace { -using QuantizationOptions = tensorflow::quantization::QuantizationOptions; -using QuantizationMethod = tensorflow::quantization::QuantizationMethod; -using QuantizationComponentSpec = - tensorflow::quantization::QuantizationComponentSpec; -using UnitWiseQuantizationSpec = - tensorflow::quantization::UnitWiseQuantizationSpec; using QuantizationUnit = - tensorflow::quantization::UnitWiseQuantizationSpec::QuantizationUnit; + ::tensorflow::quantization::UnitWiseQuantizationSpec::QuantizationUnit; +using ::tensorflow::quantization::QuantizationComponentSpec; +using ::tensorflow::quantization::QuantizationMethod; +using ::tensorflow::quantization::QuantizationOptions; +using ::tensorflow::quantization::UnitWiseQuantizationSpec; class LiftQuantizableSpotsAsFunctionsPass : public PassWrapper().getNumElements(); + if (num_elements < quant_options_.min_num_elements_for_weights()) { + return absl::InternalError( + "The params of Gather have fewer number of elements than " + "the `min_num_elements_for_weights`."); + } } // Disable quantization if the quantization method is NO_QUANTIZE. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td index 1628cedf99e9cd..d56ee05dc071dc 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td @@ -17,8 +17,8 @@ include "mlir/IR/OpBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/Dialect/Arith/IR/ArithOps.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.td" -include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" +include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td" //===----------------------------------------------------------------------===// // Helper functions. 
@@ -62,7 +62,7 @@ def LiftDepthwiseConv : Pat< [(IsNotInLiftedFunc $res)], [], (addBenefit 1)>; def LiftMatMul : Pat< - (TF_MatMulOp:$res $a, $b, $transpose_a, $transpose_b), + (TF_MatMulOp:$res $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), (LiftAsTFPartitionedCall<"composite_matmul_fn"> (ArgumentList $a, $b), (ResultList $res), @@ -84,7 +84,7 @@ def LiftConv3D : Pat< [(IsNotInLiftedFunc $res)], [], (addBenefit 1)>; def LiftBatchMatMul : Pat< - (TF_BatchMatMulV2Op:$res $x, $y, $adj_x, $adj_y), + (TF_BatchMatMulV2Op:$res $x, $y, $adj_x, $adj_y, $grad_x, $grad_y), (LiftAsTFPartitionedCall<"composite_batch_matmul_fn"> (ArgumentList $x, $y), (ResultList $res), @@ -142,7 +142,7 @@ def LiftConv2dWithBias : Pat< def LiftMatmulWithBias : Pat< (TF_BiasAddOp:$res - (TF_MatMulOp $a, $b, $transpose_a, $transpose_b), + (TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), $bias, IsDataFormatNHWC:$bias_data_format), (LiftAsTFPartitionedCall<"composite_matmul_with_bias_fn"> (ArgumentList $a, $b, $bias), @@ -157,7 +157,7 @@ def LiftMatmulWithBias : Pat< def LiftMatmulWithReshapeAndBias : Pat< (TF_BiasAddOp:$res (TF_ReshapeOp:$out - (TF_MatMulOp $a, $b, $transpose_a, $transpose_b), + (TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), $shape), $bias, IsDataFormatNHWC:$bias_data_format), (LiftAsTFPartitionedCall<"composite_matmul_with_reshape_and_bias_fn"> @@ -184,7 +184,7 @@ def LiftConv3dWithBias : Pat< def LiftBatchMatMulWithBias : Pat< (TF_BiasAddOp:$res - (TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y), + (TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y, $grad_x, $grad_y), $bias, IsDataFormatNHWC:$bias_data_format), (LiftAsTFPartitionedCall<"composite_batch_matmul_with_bias_fn"> (ArgumentList $x, $y, $bias), @@ -276,7 +276,7 @@ multiclass LiftCompositeOpsWithActivation (ArgumentList $a, $b), (ResultList $res), @@ -288,7 +288,7 @@ multiclass LiftCompositeOpsWithActivation (ArgumentList $a, $b, $bias), @@ -328,7 +328,7 @@ multiclass LiftCompositeOpsWithActivation (ArgumentList $x, $y), (ResultList $res), @@ -340,7 +340,7 @@ multiclass LiftCompositeOpsWithActivation (ArgumentList $x, $y, $bias), diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc index f86dcf3b3287ed..3e631835cd0ee5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc @@ -26,10 +26,10 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -37,7 +37,8 @@ namespace mlir { namespace quant { namespace { -using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod; +using QuantMethod = + ::tensorflow::quantization::QuantizationMethod::PresetMethod; class LiftQuantizableSpotsAsFunctionsDRQPass : public PassWrapper; def LiftMatMul : Pat< - (TF_MatMulOp:$res $a, $b, $transpose_a, $transpose_b), + (TF_MatMulOp:$res $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), (LiftAsTFPartitionedCall<"composite_matmul_fn"> (ArgumentList $a, $b), (ResultList $res), @@ -83,7 +83,7 @@ def LiftConv3D : Pat< [(IsNotInLiftedFunc $res), (IsConstTensor $filter)], [], (addBenefit 1)>; def LiftBatchMatMul : Pat< - (TF_BatchMatMulV2Op:$res $x, $y, $adj_x, $adj_y), + (TF_BatchMatMulV2Op:$res $x, $y, $adj_x, $adj_y, $grad_x, $grad_y), (LiftAsTFPartitionedCall<"composite_batch_matmul_fn"> (ArgumentList $x, $y), (ResultList $res), diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.cc index 1ba9d68347e2ce..b459bbcd901125 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.cc @@ -21,7 +21,6 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir::quant { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.td index 2348ac80b845f1..c40902d283e8cc 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.td @@ -17,8 +17,8 @@ include "mlir/IR/OpBase.td" include "mlir/IR/PatternBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" // Remove redundant `CastOp` to int8 if the input is properly clipped. 
def RemoveRedundantCastOps : Pat< diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h index 7deed9306fcf88..7300cb3996b131 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h @@ -25,7 +25,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.td index 5d879adea90a50..7e00f588f9dc71 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.td @@ -18,9 +18,9 @@ include "mlir/IR/PatternBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/Dialect/Arith/IR/ArithOps.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" include "mlir/Dialect/Arith/IR/ArithOps.td" include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" // Re-orders the Identity op following a quantized composite function. This // allows the QuantizeCompositeFunctionsPass to merge the DequantizeCast with diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc index e0fb1224d5540a..886a27011b1825 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc @@ -37,9 +37,9 @@ limitations under the License. #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_identity_op_pattern.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" @@ -182,7 +182,7 @@ Value MakeOneDimValueBroadcastable(OpBuilder& builder, Location loc, return ConstantFoldOpIfPossible(reshape_op).front(); } -// Checks if a value can be symetrically quantized. +// Checks if a value can be symmetrically quantized. bool CanBeSymmetricallyQuantized(Value weight) { auto dq_op = weight.getDefiningOp(); if (!dq_op) return true; @@ -215,7 +215,7 @@ SmallVector MultiplyTwoArrays(ArrayRef a, ArrayRef b) { } // Multiplies the value followed by a FakeQuant op and adjusts the quantization -// params. 
This funtion only supports symetrically quantized values. +// params. This function only supports symmetrically quantized values. Value MultiplyFakeQuantValue(OpBuilder& builder, Location loc, Value value, Value multiplier) { auto dq_op = value.getDefiningOp(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td index f88644a378dd9a..30e298dd6e7048 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td @@ -17,8 +17,8 @@ include "mlir/IR/OpBase.td" include "mlir/IR/PatternBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" include "mlir/Dialect/Arith/IR/ArithOps.td" // Converts arith.constant ops from freezing passes back to tf.Const ops. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc index 209c173bcae701..b5fb96396f7ef9 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc @@ -41,7 +41,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.td index 328736da06c40d..4fa7ef333f67ee 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.td @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" include "mlir/IR/OpBase.td" include "mlir/IR/PatternBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" // Converts tf.Const to arith.constant for statically shaped, non-opaque constants. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc index 2d96d13091c62c..8f550a8e5633e4 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc @@ -31,9 +31,8 @@ limitations under the License. 
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" -#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" //===----------------------------------------------------------------------===// @@ -46,6 +45,7 @@ namespace { using QuantizationUnit = std::pair; using QuantizationUnits = llvm::SetVector; +using ::mlir::quant::OpSet; // Applies prepare quantization on the model in TF dialect for dynamic range // quantization case. @@ -127,7 +127,7 @@ class PrepareDRQQuantizableOp : public OpRewritePattern { return failure(); } - // 2. Quantize collected ops. It is immediatly quantized by inserting Q-DQ + // 2. Quantize collected ops. It is immediately quantized by inserting Q-DQ // pair for int8. if (!(quantizeOps(rewriter, op, quantizable_ops))) { return failure(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc index 3c23c4edf0bb11..3f6960dd861fb6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc @@ -31,8 +31,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" //===----------------------------------------------------------------------===// @@ -43,7 +43,8 @@ namespace quant { namespace { -using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod; +using QuantMethod = + ::tensorflow::quantization::QuantizationMethod::PresetMethod; using QuantizationUnit = std::pair; using QuantizationUnits = llvm::SetVector; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.td index 328736da06c40d..4fa7ef333f67ee 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.td @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" include "mlir/IR/OpBase.td" include "mlir/IR/PatternBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" // Converts tf.Const to arith.constant for statically shaped, non-opaque constants. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc index 4e69d48eed69c1..8570652b4019e7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc @@ -32,8 +32,8 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc index bf93774e67f73d..56c43988e42e4a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc @@ -46,8 +46,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/framework/types.pb.h" @@ -315,7 +315,7 @@ class QuantizeSameScaleOpsPattern } private: - // Checks whether the operation is connnected with a composite function. + // Checks whether the operation is connected with a composite function. // If not, the same-scale op will not be quantized. This decision is based // on the current assumption that the performance gain of the same-scale // op itself could not beat the overhead of the quantize and dequantize diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc index 1945b69f36f71c..e4eecf204e85c9 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc @@ -50,7 +50,6 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.td index 1d2bee74d9b4a4..23722a510ac987 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.td @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" include "mlir/IR/OpBase.td" include "mlir/IR/PatternBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" // Converts reamaining arith.constant ops from quantization passes back to diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_weights.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_weights.cc index e666fae001024b..2cd7949be7f60c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_weights.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_weights.cc @@ -44,8 +44,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc index a206a719c26599..374d687428ee3e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc @@ -36,8 +36,8 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "xla/xla_data.pb.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.td index b6810be4d846d9..ccd477c310e27c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.td @@ -16,8 +16,8 @@ include "mlir/IR/OpBase.td" include "mlir/IR/PatternBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" def CreateXLAConvOpFromTFConv2DOp : NativeCodeCall< "CreateXlaConvOpFromTfConv2dOp($_builder, $_loc, $0...)">; @@ -216,7 +216,7 @@ def ConvertTFMatMulToXLADotV2Op : Pat< (TF_MatMulOp:$matmul (TF_SubOp (TF_CastOp $input, $truncate), $input_zp), (TF_CastOp (TF_IdentityOp $weight), $truncate1), - $transpose_a, $transpose_b), + $transpose_a, $transpose_b, $grad_a, $grad_b), (CreateXlaDotV2OpFromTfMatMulOp $input, $weight, $input_zp, /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), $matmul, @@ -235,7 +235,7 @@ def ConvertTFMatMulToXLADotV2OpDynamicRange : Pat< (TF_MatMulOp:$matmul (TF_SubOp:$input (TF_CastOp $input_i8, $truncate0), $input_zp), (TF_CastOp (TF_IdentityOp $weight), $truncate1), - $transpose_a, $transpose_b), + $transpose_a, $transpose_b, $grad_a, $grad_b), (CreateXlaDotV2OpFromTfMatMulOp $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), @@ -252,7 +252,7 @@ def ConvertTFMatMulToXLADotV2OpWeightOnly : Pat< (TF_MatMulOp:$matmul $input, (TF_MulOp (TF_CastOp (TF_IdentityOp $weight), $truncate1), $scale), - $transpose_a, $transpose_b), + $transpose_a, $transpose_b, $grad_a, $grad_b), (TF_MulOp (CreateXlaDotV2OpFromTfMatMulOp $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), @@ -270,7 +270,7 @@ def ConvertTFMatMulWithNoZeroPointToXLADotV2Op : Pat< (TF_MatMulOp:$matmul (TF_CastOp $input, $truncate), (TF_CastOp (TF_IdentityOp $weight), $truncate1), - $transpose_a, $transpose_b), + $transpose_a, $transpose_b, $grad_a, $grad_b), (CreateXlaDotV2OpFromTfMatMulOp $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), @@ -288,7 +288,7 @@ def ConvertTFMatMulWithTwoInputTensorsToXLADotV2Op : Pat< (TF_MatMulOp:$matmul (TF_SubOp (TF_CastOp $input, $truncate1), $input_zp), (TF_SubOp (TF_CastOp $weight, $truncate2), $weight_zp), - $transpose_a, $transpose_b), + $transpose_a, $transpose_b, $grad_a, $grad_b), 
(CreateXlaDotV2OpFromTfMatMulOp $input, $weight, $input_zp, $weight_zp, $matmul, $transpose_a, $transpose_b), [(IsInt8ElementType $input), @@ -306,7 +306,7 @@ def ConvertTFMatMulWithTwoInputTensorsAndNoInputZeroPointToXLADotV2Op : Pat< (TF_MatMulOp:$matmul (TF_CastOp $input, $truncate), (TF_SubOp (TF_CastOp $weight, $truncate2), $weight_zp), - $transpose_a, $transpose_b), + $transpose_a, $transpose_b, $grad_a, $grad_b), (CreateXlaDotV2OpFromTfMatMulOp $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), $weight_zp, $matmul, $transpose_a, $transpose_b), @@ -324,7 +324,7 @@ def ConvertTFMatMulWithTwoInputTensorsAndNoWeightZeroPointToXLADotV2Op : Pat< (TF_MatMulOp:$matmul (TF_SubOp (TF_CastOp $input, $truncate), $input_zp), (TF_CastOp $weight, $truncate1), - $transpose_a, $transpose_b), + $transpose_a, $transpose_b, $grad_a, $grad_b), (CreateXlaDotV2OpFromTfMatMulOp $input, $weight, $input_zp, /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), @@ -343,7 +343,7 @@ def ConvertTFMatMulWithTwoInputTensorsAndNoBothZeroPointsToXLADotV2Op : Pat< (TF_MatMulOp:$matmul (TF_CastOp $input, $truncate), (TF_CastOp $weight, $truncate1), - $transpose_a, $transpose_b), + $transpose_a, $transpose_b, $grad_a, $grad_b), (CreateXlaDotV2OpFromTfMatMulOp $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), @@ -407,7 +407,7 @@ def ConvertTFBatchMatMulToXLADotV2Op : Pat< (TF_BatchMatMulV2Op:$batch_matmul (TF_SubOp (TF_CastOp $input, $truncate), $input_zp), (TF_CastOp (TF_IdentityOp $weight), $truncate1), - $adj_x, $adj_y), + $adj_x, $adj_y, $grad_x, $grad_y), (CreateXlaDotV2OpFromTfBatchMatMulOp $input, $weight, $input_zp, /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), @@ -426,7 +426,7 @@ def ConvertTFBatchMatMulWithNoZeroPointToXLADotV2Op : Pat< (TF_BatchMatMulV2Op:$batch_matmul (TF_CastOp $input, $truncate), (TF_CastOp (TF_IdentityOp $weight), $truncate1), - $adj_x, $adj_y), + $adj_x, $adj_y, $grad_x, $grad_y), (CreateXlaDotV2OpFromTfBatchMatMulOp $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), @@ -447,7 +447,7 @@ def ConvertTFBatchMatMulWithTwoInputTensorsToXLADotV2Op : Pat< (TF_BatchMatMulV2Op:$batch_matmul (TF_SubOp (TF_CastOp $input, $truncate), $input_zp), (TF_SubOp (TF_CastOp (TF_IdentityOp $weight), $truncate1), $weight_zp), - $adj_x, $adj_y), + $adj_x, $adj_y, $grad_x, $grad_y), (CreateXlaDotV2OpFromTfBatchMatMulOp $input, $weight, $input_zp, $weight_zp, $batch_matmul, $adj_x, $adj_y), [(IsInt8ElementType $input), @@ -465,7 +465,7 @@ def ConvertTFBatchMatMulWithTwoInputTensorsAndNoInputZeroPointToXLADotV2Op : Pat (TF_BatchMatMulV2Op:$batch_matmul (TF_CastOp $input, $truncate), (TF_SubOp (TF_CastOp (TF_IdentityOp $weight), $truncate1), $weight_zp), - $adj_x, $adj_y), + $adj_x, $adj_y, $grad_x, $grad_y), (CreateXlaDotV2OpFromTfBatchMatMulOp $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), $weight_zp, $batch_matmul, $adj_x, $adj_y), @@ -483,7 +483,7 @@ def ConvertTFBatchMatMulWithTwoInputTensorsAndNoWeightZeroPointToXLADotV2Op : Pa (TF_BatchMatMulV2Op:$batch_matmul (TF_SubOp (TF_CastOp $input, $truncate1), $input_zp), (TF_CastOp $weight, $truncate2), - $adj_x, $adj_y), + $adj_x, $adj_y, $grad_x, $grad_y), (CreateXlaDotV2OpFromTfBatchMatMulOp $input, $weight, $input_zp, /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), @@ -502,7 +502,7 @@ def 
ConvertTFBatchMatMulWithTwoInputTensorsAndNoBothZeroPointsToXLADotV2Op : Pat (TF_BatchMatMulV2Op:$batch_matmul (TF_CastOp $input, $truncate1), (TF_CastOp $weight, $truncate2), - $adj_x, $adj_y), + $adj_x, $adj_y, $grad_x, $grad_y), (CreateXlaDotV2OpFromTfBatchMatMulOp $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc index 5020550ca65a7c..87d230fb16bbde 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/InitAllPasses.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" @@ -40,7 +41,8 @@ int main(int argc, char **argv) { mlir::arith::ArithDialect, mlir::tf_type::TFTypeDialect, mlir::quant::QuantizationDialect, mlir::quantfork::QuantizationForkDialect, - mlir::tf_executor::TensorFlowExecutorDialect>(); + mlir::tf_executor::TensorFlowExecutorDialect, + mlir::stablehlo::StablehloDialect>(); mlir::func::registerAllExtensions(registry); return failed( mlir::MlirOptMain(argc, argv, "TF quant Pass Driver\n", registry)); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index 31a016b374675c..808e0b36af2d24 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -13,6 +13,7 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:__subpackages__", "//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist_package", "//tensorflow/python:__subpackages__", ], @@ -33,7 +34,12 @@ cc_library( "//tensorflow/python:__pkg__", ], deps = [ + ":unfreeze_constants", "//tensorflow/cc/saved_model:loader", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:export", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:precalibration", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -42,15 +48,12 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op", # Required for CustomAggregator op registration. 
"//tensorflow/compiler/mlir/quantization/tensorflow/cc:convert_asset_args", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", - "//tensorflow/compiler/mlir/quantization/tensorflow/cc:save_variables", - "//tensorflow/compiler/mlir/quantization/tensorflow/cc:status_macro", "//tensorflow/compiler/mlir/quantization/tensorflow/debugging:dump_tensor_op", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:export_graphdef", "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", @@ -59,6 +62,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -71,11 +75,10 @@ cc_library( "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:ShapeDialect", - "@local_tsl//tsl/platform:env", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@stablehlo//:stablehlo_ops", ], ) @@ -100,11 +103,28 @@ cc_library( pytype_strict_library( name = "py_function_lib_py", srcs = ["py_function_lib.py"], - visibility = ["//visibility:private"], deps = [ ":pywrap_function_lib", + ":representative_dataset", + ":save_model", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_py", - "@pypi_typing_extensions//:pkg", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_algorithm", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_py", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:pywrap_calibration", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/client:session", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:wrap_function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor_conversion", + "//tensorflow/python/lib/io:file_io", + "//tensorflow/python/saved_model:load", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/trackable:autotrackable", + "//tensorflow/python/types:core", + "//third_party/py/numpy", + "@absl_py//absl/logging", ], ) @@ -130,13 +150,14 @@ cc_library( "-use_header_modules", "-parse_headers", ], - visibility = ["//visibility:private"], deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/core/protobuf:for_core_protos_cc", "//tensorflow/python/lib/core:pybind11_lib", "//third_party/python_runtime:headers", # build_cleaner: keep; Required for pybind11. 
"@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:protobuf", "@pybind11", "@pybind11_abseil//pybind11_abseil:absl_casters", ], @@ -145,9 +166,38 @@ cc_library( cc_library( name = "py_function_lib", hdrs = ["py_function_lib.h"], - visibility = ["//visibility:private"], + compatible_with = get_compatible_with_portable(), deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:min_max_value", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + "@pybind11", + ], +) + +cc_library( + name = "unfreeze_constants", + srcs = ["unfreeze_constants.cc"], + hdrs = ["unfreeze_constants.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/tensorflow:passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:save_variables", + "//tensorflow/core:lib", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", ], ) @@ -159,8 +209,14 @@ tf_python_pybind_extension( deps = [ ":py_function_lib", ":type_casters", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:min_max_value", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", - "//tensorflow/python/lib/core:pybind11_lib", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", "@pybind11", ], ) @@ -175,15 +231,19 @@ tf_python_pybind_extension( deps = [ ":py_function_lib", ":type_casters", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:debugger", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:assign_ids", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:statistics", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", - "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc", - "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton", - "//tensorflow/python/lib/core:pybind11_lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@pybind11", "@pybind11_abseil//pybind11_abseil:absl_casters", "@pybind11_abseil//pybind11_abseil:import_status_module", @@ -198,6 
+258,7 @@ tf_py_strict_test( "pywrap_quantize_model_test.py", ], deps = [ + ":py_function_lib_py", ":pywrap_quantize_model", "//tensorflow:tensorflow_py", "//tensorflow/python/platform:client_testlib", @@ -242,6 +303,7 @@ pytype_strict_library( "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_algorithm", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_py", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:pywrap_calibration", "//tensorflow/core:protos_all_py", "//tensorflow/python/client:session", "//tensorflow/python/eager:context", @@ -256,7 +318,6 @@ pytype_strict_library( "//tensorflow/python/trackable:autotrackable", "//tensorflow/python/types:core", "//tensorflow/python/util:tf_export", - "//third_party/py/numpy", "@absl_py//absl/logging", ], ) @@ -378,6 +439,7 @@ pytype_strict_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", + "//tensorflow/core:protos_all_py", "//tensorflow/python/client:session", "//tensorflow/python/data/ops:readers", "//tensorflow/python/eager:context", @@ -386,6 +448,7 @@ pytype_strict_library( "//tensorflow/python/platform:tf_logging", "//tensorflow/python/types:core", "//tensorflow/python/util:tf_export", + "//third_party/py/numpy", ], ) @@ -394,7 +457,9 @@ tf_py_strict_test( srcs = ["representative_dataset_test.py"], deps = [ ":representative_dataset", + "//tensorflow/core:protos_all_py", "//tensorflow/python/client:session", + "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:test_lib", "//tensorflow/python/platform:client_testlib", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py index 4a1fd7148b0c15..02a3b703b9db18 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py @@ -2518,68 +2518,6 @@ def data_gen() -> repr_dataset.RepresentativeDataset: else: self.assertAllClose(new_outputs, expected_outputs, atol=0.13) - @test_util.run_in_graph_and_eager_modes - def test_matmul_ptq_model_stablehlo(self): - activation_fn = None - has_bias = False - batch_sizes = ([], []) - target_opset = quant_opts_pb2.STABLEHLO - - lhs_batch_size, rhs_batch_size = batch_sizes - input_shape = (*lhs_batch_size, 1, 1024) - filter_shape = (*rhs_batch_size, 1024, 3) - static_input_shape = [dim if dim is not None else 2 for dim in input_shape] - model = self._create_matmul_model( - input_shape, - filter_shape, - self._input_saved_model_path, - has_bias, - activation_fn, - ) - rng = np.random.default_rng(seed=1234) - - input_data = ops.convert_to_tensor( - rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( - np.float32 - ) - ) - expected_outputs = model.matmul(input_data) - - def data_gen() -> repr_dataset.RepresentativeDataset: - for _ in range(100): - yield { - 'input_tensor': rng.uniform( - low=0.0, high=1.0, size=static_input_shape - ).astype(np.float32) - } - - quantization_options = quant_opts_pb2.QuantizationOptions( - quantization_method=quant_opts_pb2.QuantizationMethod( - preset_method=_PresetMethod.METHOD_STATIC_RANGE_INT8 - ), - 
tags={tag_constants.SERVING}, - signature_keys=['serving_default'], - op_set=target_opset, - ) - converted_model = quantize_model.quantize( - self._input_saved_model_path, - self._output_saved_model_path, - quantization_options, - representative_dataset=data_gen(), - ) - - self.assertIsNotNone(converted_model) - self.assertCountEqual( - converted_model.signatures._signatures.keys(), {'serving_default'} - ) - - new_outputs = converted_model.signatures['serving_default']( - input_tensor=ops.convert_to_tensor(input_data) - ) - # Tests that the quantized graph outputs similar values. The rtol value is - # arbitrary. - self.assertAllClose(new_outputs, expected_outputs, rtol=0.02) - @parameterized.named_parameters( { 'testcase_name': 'with_biasadd', @@ -2972,13 +2910,17 @@ def test_matmul_ptq_model_with_unfreeze_constants(self): ) @parameterized.named_parameters( - ('use_constant_with_int32_input', dtypes.int32, False), - ('use_variable_with_int32_input', dtypes.int32, True), - ('use_constant_with_int64_input', dtypes.int64, False), - ('use_variable_with_int64_input', dtypes.int64, True), + ('use_constant_with_int32_input', dtypes.int32, False, True), + ('use_variable_with_int32_input', dtypes.int32, True, True), + ('use_constant_with_int64_input', dtypes.int64, False, True), + ('use_variable_with_int64_input', dtypes.int64, True, True), + ('small_gather_use_constant', dtypes.int32, False, False), + ('small_gather_use_variable', dtypes.int32, True, False), ) @test_util.run_v2_only - def test_gather_model(self, input_type, use_variable): + def test_gather_model( + self, input_type, use_variable, expect_quantized_gather + ): model = self._create_gather_model(input_type, use_variable) saved_model_save.save(model, self._input_saved_model_path) @@ -2991,7 +2933,9 @@ def test_gather_model(self, input_type, use_variable): ), tags=tags, signature_keys=['serving_default'], - op_set=quant_opts_pb2.TF, + op_set=quant_opts_pb2.XLA, + # Gather op is opt-outed if the size is smaller than the threshold. + min_num_elements_for_weights=1024 if expect_quantized_gather else 8192, ) data_gen = self._create_data_generator( @@ -3014,11 +2958,14 @@ def test_gather_model(self, input_type, use_variable): converted_model.signatures._signatures.keys(), {'serving_default'} ) - output_loader = saved_model_loader.SavedModelLoader( - self._output_saved_model_path - ) - output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def - self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + if expect_quantized_gather: + self.assertSizeRatioLessThan( + self._output_saved_model_path, self._input_saved_model_path, 1 / 3 + ) + else: + self.assertSizeRatioGreaterThan( + self._output_saved_model_path, self._input_saved_model_path, 2 / 3 + ) @test_util.run_in_graph_and_eager_modes def test_model_ptq_use_representative_samples_list(self): @@ -3366,7 +3313,7 @@ def test_model_ptq_use_tf_dataset_for_representative_dataset(self): self.assertTrue(self._contains_quantized_function_call(output_graphdef)) @test_util.run_in_graph_and_eager_modes - def test_model_ptq_no_representative_sample_shows_warnings(self): + def test_model_ptq_no_representative_sample_not_quantized(self): self._create_matmul_model( input_shape=(1, 1024), weight_shape=(1024, 3), @@ -3382,30 +3329,14 @@ def test_model_ptq_no_representative_sample_shows_warnings(self): signature_keys=['serving_default'], ) - with self.assertLogs(level='WARN') as warning_logs: - # Save the logger verbosity. 
- prev_log_level = logging.get_verbosity() - logging.set_verbosity(logging.WARN) - - try: - converted_model = quantize_model.quantize( - self._input_saved_model_path, - self._output_saved_model_path, - quantization_options, - # Put no sample into the representative dataset to make calibration - # impossible. - representative_dataset=[], - ) - finally: - # Restore the logger verbosity. - logging.set_verbosity(prev_log_level) - - self.assertNotEmpty(warning_logs.records) - self.assertTrue( - self._any_log_contains( - 'does not have min or max values', warning_logs.records - ) - ) + converted_model = quantize_model.quantize( + self._input_saved_model_path, + self._output_saved_model_path, + quantization_options, + # Put no sample into the representative dataset to make calibration + # impossible. + representative_dataset=[], + ) self.assertIsNotNone(converted_model) self.assertCountEqual( @@ -3486,36 +3417,12 @@ def data_gen() -> repr_dataset.RepresentativeDataset: op_set=quant_opts_pb2.TF, ) - with self.assertLogs(level='WARN') as warning_logs: - # Save the logger verbosity. - log_level = logging.get_verbosity() - logging.set_verbosity(logging.WARN) - - try: - converted_model = quantize_model.quantize( - self._input_saved_model_path, - self._output_saved_model_path, - quantization_options, - representative_dataset=data_gen(), - ) - finally: - # Restore the logger verbosity. - logging.set_verbosity(log_level) - - self.assertNotEmpty(warning_logs.records) - - # Warning message should contain the function name. The uncalibrated path - # is when the condition is true, so 'cond_true' function must be part of - # the warning message. - self.assertTrue(self._any_log_contains('cond_true', warning_logs.records)) - self.assertFalse( - self._any_log_contains('cond_false', warning_logs.records) - ) - self.assertTrue( - self._any_log_contains( - 'does not have min or max values', warning_logs.records - ) - ) + converted_model = quantize_model.quantize( + self._input_saved_model_path, + self._output_saved_model_path, + quantization_options, + representative_dataset=data_gen(), + ) self.assertIsNotNone(converted_model) self.assertCountEqual( @@ -3527,6 +3434,25 @@ def data_gen() -> repr_dataset.RepresentativeDataset: output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + # Tests that the false branch contains a quantized function call whereas the + # true branch doesn't. + def _is_quantized_function_call_node( + node_def: node_def_pb2.NodeDef, + ) -> bool: + return node_def.op == 'PartitionedCall' and node_def.attr[ + 'f' + ].func.name.startswith('quantized_') + + for func in output_graphdef.library.function: + if func.signature.name.startswith('cond_false'): + self.assertTrue( + any(map(_is_quantized_function_call_node, func.node_def)) + ) + elif func.signature.name.startswith('cond_true'): + self.assertFalse( + any(map(_is_quantized_function_call_node, func.node_def)) + ) + # Run this test only with the eager mode. @test_util.run_v2_only def test_ptq_model_with_multiple_signatures(self): diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h index 4d120f29491293..dbb557f2b5b033 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h @@ -15,7 +15,18 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_PY_FUNCTION_LIB_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_PY_FUNCTION_LIB_H_ +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "pybind11/pytypes.h" // from @pybind11 +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/min_max_value.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" namespace tensorflow::quantization { @@ -27,12 +38,71 @@ class PyFunctionLibrary { public: virtual ~PyFunctionLibrary() = default; - // Assigns UUIDs to each CustomAggregator op found in each GraphDef in - // `exported_model`. The UUIDs are set to the `id` attributes. The UUIDs will - // be used during calibration step to identify the collected quantization - // statistics for each CustsomAggregator op. - virtual ExportedModel AssignIdsToCustomAggregatorOps( - const ExportedModel& exported_model) const = 0; + // Saves `exported_model` to `dst_saved_model_path` as SavedModel. + // `src_saved_model_path` is the path to the source SavedModel from which the + // exported model is produced. It is used to copy the asset files to + // `dst_saved_model_path`. `tags` will be attached to the saved + // `MetaGraphDef`. `signature_def_map` will be passed to the + // `add_meta_graph_and_variables` function, which is internally used to add a + // `MetaGraphDef` to save to the SavedModel. + // + // If the function signature changes, likely its corresponding .pyi type + // hinting and definition should also change. + // LINT.IfChange + virtual void SaveExportedModel( + absl::string_view dst_saved_model_path, + const ExportedModel& exported_model, + absl::string_view src_saved_model_path, + const std::unordered_set& tags, + const absl::flat_hash_map& + signature_def_map) const = 0; + // LINT.ThenChange( + // pywrap_function_lib.pyi:save_exported_model, + // py_function_lib.py:save_exported_model, + // ) + + // Runs calibration on a model saved at `saved_model_path`. `exported_model` + // should be the corresponding exported model resulting from the + // pre-calibration step. `signature_keys` is a set of keys that identify a + // SignatureDef to run the calibration on. `tags` is a set of strings that + // identify the `MetaGraphDef`. `calibration_options` provides configurations + // for the calibration behavior. `representative_dataset` is a python object + // of type `RepresentativeDatasetOrMapping`, which is used to run the + // calibration. + // + // Returns the updated exported model where the collected calibration + // statistics are added to `CustomAggregator` nodes at the `min` and `max` + // attributes. + // + // If the function signature changes, likely its corresponding .pyi type + // hinting and definition should also change. 
+ // LINT.IfChange(run_calibration) + virtual void RunCalibration( + absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const CalibrationOptions& calibration_options, + bool force_graph_mode_calibration, + pybind11::object representative_dataset) const = 0; + // LINT.ThenChange( + // pywrap_function_lib.pyi:run_calibration, + // py_function_lib.py:run_calibration, + // ) + + // Retrieves min and max value from `calibration_statistics`, based on the + // calibration method specified by `calibration_options`. + // + // If the function signature changes, likely its corresponding .pyi type + // hinting and definition should also change. + // LINT.IfChange(get_calibration_min_max_value) + virtual stablehlo::quantization::MinMaxValue GetCalibrationMinMaxValue( + const tensorflow::calibrator::CalibrationStatistics& + calibration_statistics, + const CalibrationOptions& calibration_options) const = 0; + // LINT.ThenChange( + // pywrap_function_lib.pyi:get_calibration_min_max_value, + // py_function_lib.py:get_calibration_min_max_value, + // ) }; } // namespace tensorflow::quantization diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py index 145149e5341042..22c3be3d6034e7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py @@ -13,10 +13,516 @@ # limitations under the License. # ============================================================================== """Defines a wrapper class for overridden python method definitions.""" -import uuid +from collections.abc import Callable, Collection, Mapping, Sequence +from typing import Optional + +from absl import logging from tensorflow.compiler.mlir.quantization.tensorflow import exported_model_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_algorithm +from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import pywrap_calibration from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_function_lib +from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd +from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.client import session +from tensorflow.python.eager import context +from tensorflow.python.eager import wrap_function +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_conversion +from tensorflow.python.lib.io import file_io +from tensorflow.python.saved_model import load +from tensorflow.python.saved_model import loader_impl +from tensorflow.python.trackable import autotrackable +from tensorflow.python.types import core + +# Name of the saved model assets directory. +_ASSETS_DIR = 'assets' +_ASSETS_EXTRA_DIR = 'assets.extra' + + +def _get_saver_def_or_none( + exported_model: exported_model_pb2.ExportedModel, +) -> Optional[saver_pb2.SaverDef]: + """Returns the SaverDef from ExportedModel, None otherwise. + + Args: + exported_model: ExportedModel to take the SaverDef from. 
+ + Returns: + SaverDef instance if the field `saver_def` is set. None otherwise. + """ + if exported_model.HasField('saver_def'): + return exported_model.saver_def + return None + + +def _copy_assets(src_path: str, dst_path: str) -> None: + """Copies the assets directory of the saved model. + + Clones the contents of the assets/ directory from the source saved model + directory to the destination saved model directory. Nothing will be copied if + there are no assets directory in the source directory. + + Args: + src_path: Source saved model directory. + dst_path: Destination saved model directory. This directory must exist. + """ + for assets_dir_name in [_ASSETS_DIR, _ASSETS_EXTRA_DIR]: + src_assets_path = file_io.join(src_path, assets_dir_name) + if not file_io.file_exists_v2(src_assets_path): + # Do nothing if the source assets path does not exist. + continue + + dst_assets_path = file_io.join(dst_path, assets_dir_name) + file_io.create_dir_v2(dst_assets_path) + + for curr_dir, _, files in file_io.walk_v2(src_assets_path): + for asset_file_name in files: + src_asset_file = file_io.join(curr_dir, asset_file_name) + + # Construct the destination assets file path. + curr_dst_dir = curr_dir.replace(src_assets_path, dst_assets_path) + dst_asset_file = file_io.join(curr_dst_dir, asset_file_name) + + file_io.copy_v2(src_asset_file, dst_asset_file) + logging.info( + 'Copied asset file: %s -> %s', src_asset_file, dst_asset_file + ) + + +def _validate_representative_dataset( + representative_dataset: rd.RepresentativeDatasetOrMapping, + signature_keys: Collection[str], +) -> None: + """Validates the representative dataset, based on the signature keys. + + Representative dataset can be provided in two different forms: a single + instance of `RepresentativeDataset` or a map of signature key to the + corresponding `RepresentativeDataset`. These have a relationship with + `signature_keys`. + + This function validates the following conditions: + * If `len(signature_keys) > 1`, then `representative_dataset` should be a + mapping where the keys exactly match the elements in `signature_keys`. + * If `len(signature_keys) == 1`, then both a mapping and a single instance of + `RepresentativeDataset` are allowed. + * This function also assumes `len(signature_keys) > 0`. + + Args: + representative_dataset: A `RepresentativeDataset` or a map of string to + `RepresentativeDataset` to be validated. + signature_keys: A collection of strings that contains the signature keys, + each identifying a `SignatureDef`. + + Raises: + ValueError: Iff `representative_dataset` does not satisfy the conditions + above. + """ + if isinstance(representative_dataset, Mapping): + if set(signature_keys) != set(representative_dataset.keys()): + raise ValueError( + 'The signature keys and the keys of representative dataset map ' + f'do not match. Signature keys: {set(signature_keys)}, ' + f'representative dataset map: {set(representative_dataset.keys())}.' + ) + else: + if len(signature_keys) > 1: + raise ValueError( + 'Representative dataset is not a mapping ' + f'(got: {type(representative_dataset)}), ' + 'but there is more than one signature key provided. ' + 'Please provide a map of {signature_key -> dataset} ' + 'with more than one signature key.' + ) + + +def _replace_tensors_by_numpy_ndarrays( + repr_ds_map: rd.RepresentativeDatasetMapping, +) -> None: + """Replaces tf.Tensors by their evaluated numpy arrays. + + This assumes that tf.Tensors in representative samples are created in the + default Graph. 
It will raise an error if tensors are created in a different + graph. + + Args: + repr_ds_map: SignatureDef key -> RepresentativeDataset mapping. + """ + with session.Session() as sess: + for signature_def_key in repr_ds_map: + # Replaces the dataset with a new dataset where tf.Tensors are replaced + # by their evaluated values. + ds = repr_ds_map[signature_def_key] + repr_ds_map[signature_def_key] = rd.replace_tensors_by_numpy_ndarrays( + ds, sess + ) + + +def _create_sample_validator( + expected_input_keys: Collection[str], +) -> Callable[[rd.RepresentativeSample], rd.RepresentativeSample]: + """Creates a validator function for a representative sample. + + Args: + expected_input_keys: Input keys (keyword argument names) that the function + the sample will be used for is expecting to receive. + + Returns: + A callable that validates a `RepresentativeSample`. + """ + + def validator( + sample: rd.RepresentativeSample, + ) -> rd.RepresentativeSample: + """Validates a single instance of representative sample. + + This provides a simple check for `sample` that this is a mapping of + {input_key: input_value}. + + Args: + sample: A `RepresentativeSample` to validate. + + Returns: + `sample` iff it is valid. + + Raises: + ValueError: iff the sample isn't an instance of `Mapping`. + KeyError: iff the sample does not have the set of input keys that match + the input keys of the function. + """ + if not isinstance(sample, Mapping): + raise ValueError( + 'Invalid representative sample type. Provide a mapping ' + '(usually a dict) of {input_key: input_value}. ' + f'Got type: {type(sample)} instead.' + ) + + if set(sample.keys()) != expected_input_keys: + raise KeyError( + 'Invalid input keys for representative sample. The function expects ' + f'input keys of: {set(expected_input_keys)}. ' + f'Got: {set(sample.keys())}. Please provide correct input keys for ' + 'representative samples.' + ) + + return sample + + return validator + + +# TODO(b/249918070): Implement a progress bar. +def _log_sample_num_for_calibration( + representative_dataset: rd.RepresentativeDataset, +) -> rd.RepresentativeDataset: + """Logs the sample number for calibration. + + If in debug logging level, the "sample number / total num samples" is logged + for every 5 iterations. + + This is often useful when tracking the progress of the calibration step which + is often slow and may look stale if there's no logs being printed. + + Args: + representative_dataset: The representative dataset. + + Yields: + The representative samples from `representative_dataset` without any + modification. + """ + num_samples: Optional[int] = rd.get_num_samples(representative_dataset) + if num_samples is None: + total_num_samples = '?' + logging.info('Representative dataset size unknown.') + else: + total_num_samples = str(num_samples) + logging.info('Using representative dataset of size: %s', total_num_samples) + + sample_num = 0 + for sample in representative_dataset: + sample_num += 1 + + # Log the sample number for every 5 iterations. 
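# Note: absl's `logging.log_every_n(level, msg, n, *args)` emits `msg % args` only once per `n` calls, so the progress message below shows up for roughly every fifth sample rather than on every iteration (and only when DEBUG-level logging is enabled).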
+ logging.log_every_n( + logging.DEBUG, + 'Running representative sample for calibration: %d / %s', + 5, + sample_num, + total_num_samples, + ) + yield sample + + logging.info( + 'Running representative samples complete: %d / %s', + sample_num, + total_num_samples, + ) + + +def _run_function_for_calibration_graph_mode( + sess: session.Session, + signature_def: meta_graph_pb2.SignatureDef, + representative_dataset: rd.RepresentativeDataset, +) -> None: + """Runs the representative dataset through a function for calibration. + + NOTE: This is intended to be run in graph mode (TF1). + + The function is identified by the SignatureDef. + + Args: + sess: The Session object to run the function in. + signature_def: A SignatureDef that identifies a function by specifying the + inputs and outputs. + representative_dataset: The representative dataset to run through the + function. + """ + output_tensor_names = [ + output_tensor_info.name + for output_tensor_info in signature_def.outputs.values() + ] + + sample_validator = _create_sample_validator( + expected_input_keys=signature_def.inputs.keys() + ) + + for sample in map( + sample_validator, _log_sample_num_for_calibration(representative_dataset) + ): + # Create a mapping from input tensor name to the input tensor value. + # ex) "Placeholder:0" -> [0, 1, 2] + feed_dict = rd.create_feed_dict_from_input_data(sample, signature_def) + sess.run(output_tensor_names, feed_dict=feed_dict) + + +def _run_graph_for_calibration_graph_mode( + model_dir: str, + tags: Collection[str], + representative_dataset_map: rd.RepresentativeDatasetMapping, +) -> None: + """Runs the graph for calibration in graph mode. + + This function assumes _graph mode_ (used when legacy TF1 is used or when eager + mode is explicitly disabled) when running the graph. This step is used in + order to collect the statistics in CustomAggregatorOp for quantization using + the representative dataset for the actual data provided for inference. + + Args: + model_dir: Path to SavedModel directory. + tags: Collection of tags identifying the MetaGraphDef within the SavedModel. + representative_dataset_map: A map where signature keys are mapped to + corresponding representative datasets. + + Raises: + ValueError: When running the function with the representative dataset fails. + """ + # Replace tf.Tensors by numpy ndarrays in order to reuse the samples in a + # different graph when running the calibration. + _replace_tensors_by_numpy_ndarrays(representative_dataset_map) + + # Run the calibration in a new graph to avoid name collision, which could + # happen when the same model is loaded multiple times in the default graph. + with ops.Graph().as_default(), session.Session() as sess: + meta_graph: meta_graph_pb2.MetaGraphDef = loader_impl.load( + sess, tags, export_dir=model_dir + ) + + for signature_key, repr_ds in representative_dataset_map.items(): + sig_def = meta_graph.signature_def[signature_key] + + try: + _run_function_for_calibration_graph_mode( + sess, signature_def=sig_def, representative_dataset=repr_ds + ) + except Exception as ex: + raise ValueError( + 'Failed to run representative dataset through the ' + f'function with the signature key: {signature_key}.' + ) from ex + + +def _convert_values_to_tf_tensors( + sample: rd.RepresentativeSample, +) -> Mapping[str, core.Tensor]: + """Converts TensorLike values of `sample` to Tensors. + + Creates a copy of `sample`, where each value is converted to Tensors + unless it is already a Tensor. + The values are not converted in-place (i.e. 
`sample` is not mutated). + + Args: + sample: A representative sample, which is a map of {name -> tensorlike + value}. + + Returns: + Converted map of {name -> tensor}. + """ + tensor_mapping = {} + for name, tensorlike_value in sample.items(): + if isinstance(tensorlike_value, core.Tensor): + tensor_value = tensorlike_value + else: + tensor_value = tensor_conversion.convert_to_tensor_v2_with_dispatch( + tensorlike_value + ) + + tensor_mapping[name] = tensor_value + + return tensor_mapping + + +def _run_function_for_calibration_eager_mode( + func: wrap_function.WrappedFunction, + representative_dataset: rd.RepresentativeDataset, +) -> None: + """Runs the representative dataset through a function for calibration. + + NOTE: This is intended to be run in eager mode (TF2). + + Args: + func: The function to run the representative samples through. + representative_dataset: Representative dataset used for calibration. The + input keys and input values of the representative samples should match the + keyword arguments of `func`. + """ + _, keyword_args = func.structured_input_signature + sample_validator = _create_sample_validator( + expected_input_keys=keyword_args.keys() + ) + + for sample in map( + sample_validator, _log_sample_num_for_calibration(representative_dataset) + ): + # Convert any non-Tensor values from the sample to Tensors. + # This conversion is required because the model saved in `model_dir` is + # saved using TF1 SavedModelBuilder, which doesn't save the + # SavedObjectGraph. + # TODO(b/236795224): Remove the need for this conversion by keeping the + # FunctionSpec (object graph) in the SavedModel. Related: b/213406917. + func_kwargs = _convert_values_to_tf_tensors(sample) + func(**func_kwargs) + + +def _run_graph_for_calibration_eager_mode( + model_dir: str, + tags: Collection[str], + representative_dataset_map: rd.RepresentativeDatasetMapping, +) -> None: + """Runs the graph for calibration in eager mode. + + This function assumes _eager mode_ (enabled in TF2 by default) when running + the graph. This step is used in order to collect the statistics in + CustomAggregatorOp for quantization using the representative dataset for the + actual data provided for inference. + + Args: + model_dir: Path to SavedModel directory. + tags: Collection of tags identifying the MetaGraphDef within the SavedModel. + representative_dataset_map: A map where signature keys are mapped to + corresponding representative datasets. + + Raises: + ValueError: When running the function with the representative dataset fails. + """ + root: autotrackable.AutoTrackable = load.load(model_dir, tags) + for signature_key, repr_ds in representative_dataset_map.items(): + try: + _run_function_for_calibration_eager_mode( + func=root.signatures[signature_key], representative_dataset=repr_ds + ) + except Exception as ex: + raise ValueError( + 'Failed to run representative dataset through the ' + f'function with the signature key: {signature_key}.' + ) from ex + + +def _run_graph_for_calibration( + float_model_dir: str, + signature_keys: Sequence[str], + tags: Collection[str], + representative_dataset: rd.RepresentativeDatasetOrMapping, + force_graph_mode_calibration: bool, +) -> None: + """Runs the graph for calibration using representative datasets. + + Args: + float_model_dir: Path to the model to calibrate. + signature_keys: Sequence of keys identifying SignatureDef containing inputs + and outputs. + tags: Collection of tags identifying the MetaGraphDef within the SavedModel + to analyze. 
+ representative_dataset: An iterator that returns a dictionary of {input_key: + input_value} or a mapping from signature keys to such iterators. When + `signature_keys` contains more than one signature key, + `representative_dataset` should be a mapping that maps each signature key + to the corresponding representative dataset. + force_graph_mode_calibration: If set to true, it forces calibration in graph + mode instead of eager mode when the context is in eager mode. + + Raises: + ValueError iff: + * The representative dataset format is invalid. + * It fails to run the functions using the representative datasets. + """ + try: + _validate_representative_dataset(representative_dataset, signature_keys) + except Exception as ex: + raise ValueError('Invalid representative dataset.') from ex + + # If `representative_dataset` is not a mapping, convert to a mapping for the + # following functions to handle representative datasets more conveniently. + representative_dataset_map = representative_dataset + if not isinstance(representative_dataset, Mapping): + # `signature_keys` is guaranteed to have only one element after the + # validation. + representative_dataset_map = {signature_keys[0]: representative_dataset} + + try: + if context.executing_eagerly() and not force_graph_mode_calibration: + logging.info('Calibration step is executed in eager mode.') + _run_graph_for_calibration_eager_mode( + float_model_dir, tags, representative_dataset_map + ) + else: + logging.info('Calibration step is executed in graph mode.') + _run_graph_for_calibration_graph_mode( + float_model_dir, tags, representative_dataset_map + ) + except Exception as ex: + raise ValueError( + 'Failed to run graph for post-training quantization calibration.' + ) from ex + + logging.info('Calibration step complete.') + + +def _get_min_max_from_calibrator( + node_id: bytes, + calib_opts: quantization_options_pb2.CalibrationOptions, +) -> tuple[float, float]: + """Calculate min and max from statistics using calibration options. + + Args: + node_id: bytes of node id. + calib_opts: Calibration options used for calculating min and max. + + Returns: + (min_value, max_value): Min and max calculated using calib_opts. + + Raises: + ValueError: Unsupported calibration method is given. + """ + statistics: calibration_statistics_pb2.CalibrationStatistics = ( + pywrap_calibration.get_statistics_from_calibrator(node_id) + ) + min_value, max_value = calibration_algorithm.get_min_max_value( + statistics, calib_opts + ) + return min_value, max_value class PyFunctionLibrary(pywrap_function_lib.PyFunctionLibrary): @@ -26,27 +532,117 @@ class PyFunctionLibrary(pywrap_function_lib.PyFunctionLibrary): declared in `pywrap_function_lib.PyFunctionLibrary`. """ - def assign_ids_to_custom_aggregator_ops( + # LINT.IfChange(save_exported_model) + def save_exported_model( self, + dst_saved_model_path: str, exported_model_serialized: bytes, - ) -> bytes: - """Assigns UUIDs to each CustomAggregator op find in the graph def. + src_saved_model_path: str, + tags: set[str], + serialized_signature_def_map: dict[str, bytes], + ) -> None: + # LINT.ThenChange(py_function_lib.h:save_exported_model) + """Saves `ExportedModel` to `dst_saved_model_path` as a SavedModel. Args: - exported_model_serialized: Serialized `ExportedModel` instance. - - Returns: - Serialized `ExportedModel` whose CustomAggregator ops are assigned UUIDs - to their `id` attributes. + dst_saved_model_path: Destination path to save the exported model.
+ exported_model_serialized: Exported model to export as SavedModel. + src_saved_model_path: Path to the source SavedModel. This will be used to + copy the asset files to `dst_saved_model_path`. + tags: Tags to attach to the saved MetaGraphDef. + serialized_signature_def_map: Signature key -> serialized SignatureDef. """ exported_model = exported_model_pb2.ExportedModel.FromString( exported_model_serialized ) - graph_def = exported_model.graph_def - for function_def in graph_def.library.function: - for node_def in function_def.node_def: - if node_def.op == 'CustomAggregator': - node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii') + # Deserialize values in signature_def_map. + signature_def_map = {} + for key, serialized_signature_def in serialized_signature_def_map.items(): + signature_def_map[key] = meta_graph_pb2.SignatureDef.FromString( + serialized_signature_def + ) + + save_model.save_model_v1( + exported_model.graph_def, + dst_saved_model_path, + signature_def_map, + tags, + init_op_name=exported_model.init_node_name, + saver_def=_get_saver_def_or_none(exported_model), + checkpoint_dir=exported_model.checkpoint_dir, + function_aliases=exported_model.function_aliases, + asset_file_defs=exported_model.asset_file_defs, + ) + + _copy_assets(src_saved_model_path, dst_saved_model_path) - return exported_model.SerializeToString() + # TODO: b/311097139 - Extract calibration related functions into a separate + # file. + # LINT.IfChange(run_calibration) + def run_calibration( + self, + saved_model_path: str, + signature_keys: list[str], + tags: set[str], + calibration_options_serialized: bytes, + force_graph_mode_calibration: bool, + representative_dataset: rd.RepresentativeDatasetOrMapping, + ) -> None: + # LINT.ThenChange(py_function_lib.h:run_calibration) + """Runs calibration and collects calibration statistics for the model. + + Args: + saved_model_path: Path to the SavedModel to run calibration. + signature_keys: List of signature keys corresponding to SignatureDefs to + run calibration on. + tags: A set of tags that identify the MetaGraphDef. + calibration_options_serialized: Serialized `CalibrationOptions`. + force_graph_mode_calibration: If True, runs the calibration in graph mode. + representative_dataset: Representative dataset to run calibration. + + The collected min and max values are kept in a global + `CalibratorSingleton` instance and are attached to the `CustomAggregator` + nodes of the exported model in a later step; this method returns `None`. + """ + # Uses the representative dataset to collect statistics for calibration. + # After this operation, min & max values are stored separately in a global + # CalibratorSingleton instance. + _run_graph_for_calibration( + saved_model_path, + signature_keys, + tags, + representative_dataset, + force_graph_mode_calibration, + ) + + # LINT.IfChange(get_calibration_min_max_value) + def get_calibration_min_max_value( + self, + calibration_statistics_serialized: bytes, + calibration_options_serialized: bytes, + ) -> tuple[float, float]: + """Calculates min and max values from statistics. + + Args: + calibration_statistics_serialized: Serialized `CalibrationStatistics`. + This will be the source to calculate min and max values from. + calibration_options_serialized: Serialized `CalibrationOptions`. Specifies + how the min / max should be calculated. + + Returns: + (min_value, max_value): Min and max calculated using calib_opts. + + Raises: + ValueError: Unsupported calibration method is given.
+ """ + # LINT.ThenChange(py_function_lib.h:get_calibration_min_max_value) + return calibration_algorithm.get_min_max_value( + calibration_statistics_pb2.CalibrationStatistics.FromString( + calibration_statistics_serialized + ), + quantization_options_pb2.CalibrationOptions.FromString( + calibration_options_serialized + ), + ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib_test.py index fbac4dad0454de..b170daca109e98 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib_test.py @@ -13,45 +13,13 @@ # limitations under the License. # ============================================================================== """Tests for py_function_lib.""" -from tensorflow.compiler.mlir.quantization.tensorflow import exported_model_pb2 -from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib -from tensorflow.core.framework import function_pb2 -from tensorflow.core.framework import node_def_pb2 from tensorflow.python.platform import test class PyFunctionLibTest(test.TestCase): - - def test_assign_ids_to_custom_aggregator_ops(self): - func_lib = py_function_lib.PyFunctionLibrary() - exported_model = exported_model_pb2.ExportedModel() - function_def: function_pb2.FunctionDef = ( - exported_model.graph_def.library.function.add() - ) - - node_def_1: node_def_pb2.NodeDef = function_def.node_def.add() - node_def_1.op = 'CustomAggregator' - - node_def_2: node_def_pb2.NodeDef = function_def.node_def.add() - node_def_2.op = 'Identity' - - result_exported_model = exported_model_pb2.ExportedModel.FromString( - func_lib.assign_ids_to_custom_aggregator_ops( - exported_model.SerializeToString() - ) - ) - result_function_def = result_exported_model.graph_def.library.function[0] - - # Check that a 'CustomAggregatorOp' has an 'id' attribute whereas other ops - # don't. - result_node_def_1 = result_function_def.node_def[0] - self.assertEqual(result_node_def_1.op, 'CustomAggregator') - self.assertIn('id', result_node_def_1.attr) - self.assertLen(result_node_def_1.attr, 1) - - result_node_def_2 = result_function_def.node_def[1] - self.assertEqual(result_node_def_2.op, 'Identity') - self.assertNotIn('id', result_node_def_2.attr) + # Functions in PyFunctionLib is in the process of migration to c++ + # implementations. + pass if __name__ == '__main__': diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc index 4b84bca54b71b9..3e14a9bd1e8b73 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc @@ -12,15 +12,33 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "pybind11/cast.h" // from @pybind11 #include "pybind11/detail/common.h" // from @pybind11 #include "pybind11/pybind11.h" // from @pybind11 +#include "pybind11/pytypes.h" // from @pybind11 +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/min_max_value.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" -#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace py = ::pybind11; namespace { +using ::stablehlo::quantization::MinMaxValue; +using ::tensorflow::SignatureDef; +using ::tensorflow::calibrator::CalibrationStatistics; +using ::tensorflow::quantization::CalibrationOptions; using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PyFunctionLibrary; @@ -33,10 +51,35 @@ class PyFunctionLibraryTrampoline : public PyFunctionLibrary { public: using PyFunctionLibrary::PyFunctionLibrary; - ExportedModel AssignIdsToCustomAggregatorOps( - const ExportedModel& exported_model) const override { - PYBIND11_OVERRIDE_PURE(ExportedModel, PyFunctionLibrary, - assign_ids_to_custom_aggregator_ops, exported_model); + void SaveExportedModel(const absl::string_view dst_saved_model_path, + const ExportedModel& exported_model, + const absl::string_view src_saved_model_path, + const std::unordered_set& tags, + const absl::flat_hash_map& + signature_def_map) const override { + PYBIND11_OVERRIDE_PURE(void, PyFunctionLibrary, save_exported_model, + dst_saved_model_path, exported_model, + src_saved_model_path, tags, signature_def_map); + } + + void RunCalibration(const absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const CalibrationOptions& calibration_options, + const bool force_graph_mode_calibration, + const py::object representative_dataset) const override { + PYBIND11_OVERRIDE_PURE(void, PyFunctionLibrary, run_calibration, + saved_model_path, signature_keys, tags, + calibration_options, force_graph_mode_calibration, + representative_dataset); + } + + MinMaxValue GetCalibrationMinMaxValue( + const CalibrationStatistics& calibration_statistics, + const CalibrationOptions& calibration_options) const override { + PYBIND11_OVERRIDE_PURE(MinMaxValue, PyFunctionLibrary, + get_calibration_min_max_value, + calibration_statistics, calibration_options); } }; @@ -46,6 +89,18 @@ PYBIND11_MODULE(pywrap_function_lib, m) { py::class_( m, "PyFunctionLibrary") .def(py::init<>()) - .def("assign_ids_to_custom_aggregator_ops", - &PyFunctionLibrary::AssignIdsToCustomAggregatorOps); + .def("save_exported_model", &PyFunctionLibrary::SaveExportedModel, + py::arg("dst_saved_model_path"), + py::arg("exported_model_serialized"), + py::arg("src_saved_model_path"), py::arg("tags"), + py::arg("serialized_signature_def_map")) + 
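      // The methods bound in this module are implemented by the Python
      // subclass in py_function_lib.py. An illustrative Python-side call of
      // the binding above, mirroring pywrap_function_lib.pyi (the paths and
      // values here are hypothetical):
      //   lib = py_function_lib.PyFunctionLibrary()
      //   lib.save_exported_model('/tmp/dst_model',
      //                           exported_model.SerializeToString(),
      //                           '/tmp/src_model', {'serve'},
      //                           {'serving_default': sig_def.SerializeToString()})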
.def("run_calibration", &PyFunctionLibrary::RunCalibration, + py::arg("saved_model_path"), py::arg("signature_keys"), + py::arg("tags"), py::arg("calibration_options_serialized"), + py::arg("force_graph_mode_calibration"), + py::arg("representative_dataset")) + .def("get_calibration_min_max_value", + &PyFunctionLibrary::GetCalibrationMinMaxValue, + py::arg("calibration_statistics_serialized"), + py::arg("calibration_options_serialized")); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi index 4c1c8937e8d38b..55c7a4fb346a70 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi @@ -12,7 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from typing import Any + class PyFunctionLibrary: - def assign_ids_to_custom_aggregator_ops( - self, exported_model_serialized: bytes - ) -> bytes: ... + + # LINT.IfChange(save_exported_model) + def save_exported_model( + self, + dst_saved_model_path: str, + exported_model_serialized: bytes, + src_saved_model_path: str, + tags: set[str], + serialized_signature_def_map: dict[str, bytes], + ) -> None: ... + # LINT.ThenChange() + + # LINT.IfChange(run_calibration) + def run_calibration( + self, + saved_model_path: str, + signature_keys: list[str], + tags: set[str], + calibration_options_serialized: bytes, + force_graph_mode_calibration: bool, + representative_dataset: Any, + ) -> None: ... + # LINT.ThenChange() + + # LINT.IfChange(get_calibration_min_max_value) + def get_calibration_min_max_value( + self, + calibration_statistics_serialized: bytes, + calibration_options_serialized: bytes, + ) -> tuple[float, float]: ... + # LINT.ThenChange() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index 05eb0123589c0a..43eebfd53468d8 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "pybind11/cast.h" // from @pybind11 #include "pybind11/detail/common.h" // from @pybind11 @@ -30,19 +30,27 @@ limitations under the License. 
#include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep #include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" -#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace py = pybind11; namespace { -using ::tensorflow::calibrator::CalibrationStatistics; -using ::tensorflow::calibrator::CalibratorSingleton; +using ::stablehlo::quantization::AddCalibrationStatistics; +using ::stablehlo::quantization::AssignIdsToCustomAggregatorOps; +using ::stablehlo::quantization::EnableDebugging; +using ::stablehlo::quantization::io::CreateTmpDir; +using ::tensorflow::SignatureDef; using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PyFunctionLibrary; using ::tensorflow::quantization::QuantizationOptions; @@ -52,157 +60,283 @@ using ::tensorflow::quantization::QuantizePtqModelPreCalibration; using ::tensorflow::quantization::QuantizeQatModel; using ::tensorflow::quantization::QuantizeWeightOnly; -// Retrieves collected statistics of a `CustomAggregator` node from the -// singleton. `id` is the identifier of the `CustomAggregator`. -CalibrationStatistics GetStatisticsFromCalibrator(const absl::string_view id) { - std::optional statistics = - CalibratorSingleton::GetStatistics(id); - - if (!statistics.has_value()) { - throw py::value_error(absl::StrFormat( - "Calibrated data does not exist. Cannot find statistics." - "value for id: '%s'", - id)); - } - - return *statistics; -} - } // namespace PYBIND11_MODULE(pywrap_quantize_model, m) { // Supports absl::StatusOr type conversions. pybind11::google::ImportStatusModule(); - // TODO - b/308532051: Make protobuf objects work without serialization - // overhead. pybind11_protobuf::ImportNativeProtoCasters(); - // Calibrator related functions. - m.def( - "clear_calibrator", - [] { CalibratorSingleton::ClearCollectedInformation(); }, - R"pbdoc( - Clears the collected metrics from the calibrator. - )pbdoc"); - m.def( - "clear_data_from_calibrator", - [](const absl::string_view id) { CalibratorSingleton::ClearData(id); }, - R"pbdoc( - Clears the collected data of the given id from calibrator. - )pbdoc"); - m.def( - "get_statistics_from_calibrator", - [](const absl::string_view id) -> CalibrationStatistics { - return GetStatisticsFromCalibrator(id); - }, - R"pbdoc( - Returns the proto CalibrationStatistics given id from calibrator. 
- )pbdoc"); - - // Quantization functions. m.def( + // If the function signature changes, likely its corresponding .pyi type + // hinting should also change. + // LINT.IfChange "quantize_qat_model", - [](const absl::string_view saved_model_path, + [](const absl::string_view src_saved_model_path, + const absl::string_view dst_saved_model_path, + const QuantizationOptions& quantization_options, const std::vector& signature_keys, - const std::unordered_set& tags, - const QuantizationOptions& quant_opts, - const absl::flat_hash_map& function_aliases) - -> absl::StatusOr { - return QuantizeQatModel(saved_model_path, signature_keys, tags, - quant_opts, function_aliases); + const absl::flat_hash_map& + signature_def_map, + const absl::flat_hash_map& function_aliases, + const PyFunctionLibrary& py_function_library) -> absl::Status { + // LINT.ThenChange(pywrap_quantize_model.pyi:quantize_qat_model) + std::unordered_set tags; + tags.insert(quantization_options.tags().begin(), + quantization_options.tags().end()); + const absl::StatusOr exported_model = + QuantizeQatModel(src_saved_model_path, signature_keys, tags, + quantization_options, function_aliases); + if (!exported_model.ok()) return exported_model.status(); + + // Remove the `tpu` tag from the debug quantized saved model as it is + // for CPU. Note the 'tpu' value should be the same as `TPU` defined in + // tensorflow/python/saved_model/tag_constants.py. + if (quantization_options.has_debugger_options()) { + tags.erase("tpu"); + } + py_function_library.SaveExportedModel( + dst_saved_model_path, *exported_model, src_saved_model_path, tags, + signature_def_map); + + return absl::OkStatus(); }, R"pbdoc( - Returns serialized ExportedModel that contains the quantized model's - GraphDef and metadata. The user should pass a serialized - `QuantizationOptions` for the `quant_opts` argument. + Quantizes a model that went through quantization-aware training (QAT) + saved at `src_saved_model_path`. The resulting model will be saved to + `dst_saved_model_path`. Returns an OK sataus when successful, otherwise + raises `StatusNotOk` exception. - Raises `StatusNotOk` exception if when the run was unsuccessful. - )pbdoc"); + The user should pass a serialized `QuantizationOptions` for the + `quantization_options_serialized` argument, and a signature key -> + serialized `SignatureDef` mapping for the `signature_def_map_serialized` + argument. + + `function_aliases` maps actual function names to the function aliases, as + defined by the `MetaGraphDef::MetaInfoDef::function_aliases` from the + input SavedModel. + )pbdoc", + py::arg("src_saved_model_path"), py::arg("dst_saved_model_path"), + py::arg("quantization_options_serialized"), py::kw_only(), + py::arg("signature_keys"), py::arg("signature_def_map_serialized"), + py::arg("function_aliases"), py::arg("py_function_library")); m.def( + // If the function signature changes, likely its corresponding .pyi type + // hinting should also change. 
+ // LINT.IfChange "quantize_ptq_dynamic_range", - [](const absl::string_view saved_model_path, + [](const absl::string_view src_saved_model_path, + const absl::string_view dst_saved_model_path, + const QuantizationOptions& quantization_options, const std::vector& signature_keys, - const std::unordered_set& tags, - const QuantizationOptions& quant_opts, - const absl::flat_hash_map& function_aliases) - -> absl::StatusOr { - return QuantizePtqDynamicRange(saved_model_path, signature_keys, tags, - quant_opts, function_aliases); + const absl::flat_hash_map& + signature_def_map, + const absl::flat_hash_map& function_aliases, + const PyFunctionLibrary& py_function_library) -> absl::Status { + // LINT.ThenChange(pywrap_quantize_model.pyi:quantize_ptq_dynamic_range) + std::unordered_set tags; + tags.insert(quantization_options.tags().begin(), + quantization_options.tags().end()); + + const absl::StatusOr exported_model = + QuantizePtqDynamicRange(src_saved_model_path, signature_keys, tags, + quantization_options, function_aliases); + + // Remove the `tpu` tag from the debug quantized saved model as it is + // for CPU. Note the 'tpu' value should be the same as `TPU` defined in + // tensorflow/python/saved_model/tag_constants.py. + if (quantization_options.has_debugger_options()) { + tags.erase("tpu"); + } + py_function_library.SaveExportedModel( + dst_saved_model_path, *exported_model, src_saved_model_path, tags, + signature_def_map); + + return absl::OkStatus(); }, R"pbdoc( - Returns serialized ExportedModel that contains the quantized model's - GraphDef and metadata. The user should pass a serialized - `QuantizationOptions` for the `quant_opts` argument. + Quantizes a model saved at `src_saved_model_path` using dynamic-range + quantization algorithm. The resulting model will be saved to + `dst_saved_model_path`. Returns an OK sataus when successful, otherwise + raises `StatusNotOk` exception. - Raises `StatusNotOk` exception if when the run was unsuccessful. - )pbdoc"); + The user should pass a serialized `QuantizationOptions` for the + `quantization_options_serialized` argument, and a signature key -> + serialized `SignatureDef` mapping for the `signature_def_map_serialized` + argument. + + `function_aliases` maps actual function names to the function aliases, as + defined by the `MetaGraphDef::MetaInfoDef::function_aliases` from the + input SavedModel. + )pbdoc", + py::arg("src_saved_model_path"), py::arg("dst_saved_model_path"), + py::arg("quantization_options_serialized"), py::kw_only(), + py::arg("signature_keys"), py::arg("signature_def_map_serialized"), + py::arg("function_aliases"), py::arg("py_function_library")); m.def( + // If the function signature changes, likely its corresponding .pyi type + // hinting should also change. 
+ // LINT.IfChange "quantize_weight_only", - [](const absl::string_view saved_model_path, - const QuantizationOptions& quant_opts, - const absl::flat_hash_map& function_aliases) - -> absl::StatusOr { - return QuantizeWeightOnly(saved_model_path, quant_opts, - function_aliases); + [](const absl::string_view src_saved_model_path, + const absl::string_view dst_saved_model_path, + const QuantizationOptions& quantization_options, + const absl::flat_hash_map& + signature_def_map, + const absl::flat_hash_map& function_aliases, + const PyFunctionLibrary& py_function_library) -> absl::Status { + // LINT.ThenChange(pywrap_quantize_model.pyi:quantize_weight_only) + const absl::StatusOr exported_model = QuantizeWeightOnly( + src_saved_model_path, quantization_options, function_aliases); + if (!exported_model.ok()) return exported_model.status(); + + std::unordered_set tags; + tags.insert(quantization_options.tags().begin(), + quantization_options.tags().end()); + + py_function_library.SaveExportedModel( + dst_saved_model_path, *exported_model, src_saved_model_path, tags, + signature_def_map); + + return absl::OkStatus(); }, R"pbdoc( - Returns serialized ExportedModel that contains the quantized model's - GraphDef and metadata. The user should pass a serialized - `QuantizationOptions` for the `quant_opts` argument. + Quantizes a model saved at `src_saved_model_path` using weight-only + quantization algorithm. The resulting model will be saved to + `dst_saved_model_path`. Returns an OK sataus when successful, otherwise + raises `StatusNotOk` exception. - Raises `StatusNotOk` exception if when the run was unsuccessful. - )pbdoc"); + The user should pass a serialized `QuantizationOptions` for the + `quantization_options_serialized` argument, and a signature key -> + serialized `SignatureDef` mapping for the `signature_def_map_serialized` + argument. + + `function_aliases` maps actual function names to the function aliases, as + defined by the `MetaGraphDef::MetaInfoDef::function_aliases` from the + input SavedModel. + )pbdoc", + py::arg("src_saved_model_path"), py::arg("dst_saved_model_path"), + py::arg("quantization_options_serialized"), py::kw_only(), + py::arg("signature_def_map_serialized"), py::arg("function_aliases"), + py::arg("py_function_library")); m.def( - "quantize_ptq_model_pre_calibration", - [](const absl::string_view saved_model_path, + // If the function signature changes, likely its corresponding .pyi type + // hinting should also change. 
+ // LINT.IfChange + "quantize_ptq_static_range", + [](const absl::string_view src_saved_model_path, + const absl::string_view dst_saved_model_path, + const QuantizationOptions& quantization_options, const std::vector& signature_keys, - const std::unordered_set& tags, - const QuantizationOptions& quant_opts, + const absl::flat_hash_map& + signature_def_map, const absl::flat_hash_map& function_aliases, - const PyFunctionLibrary& py_function_lib) - -> absl::StatusOr { - const absl::StatusOr exported_model = - QuantizePtqModelPreCalibration(saved_model_path, signature_keys, - tags, quant_opts, function_aliases); - if (!exported_model.ok()) { - return exported_model.status(); + const PyFunctionLibrary& py_function_library, + py::object representative_dataset) -> absl::Status { + // LINT.ThenChange(pywrap_quantize_model.pyi:quantize_ptq_model_static_range) + std::unordered_set tags; + tags.insert(quantization_options.tags().begin(), + quantization_options.tags().end()); + + absl::StatusOr exported_model = + QuantizePtqModelPreCalibration(src_saved_model_path, signature_keys, + tags, quantization_options, + function_aliases); + if (!exported_model.ok()) return exported_model.status(); + + AssignIdsToCustomAggregatorOps(*exported_model->mutable_graph_def()); + + const absl::StatusOr precalibrated_saved_model_dir = + CreateTmpDir(); + if (!precalibrated_saved_model_dir.ok()) { + throw py::value_error( + precalibrated_saved_model_dir.status().ToString()); } - return py_function_lib.AssignIdsToCustomAggregatorOps(*exported_model); - }, - R"pbdoc( - Returns serialized ExportedModel that contains the model's GraphDef and - metadata. The GraphDef contains extra ops required for calibration. The - user should pass a serialized `QuantizationOptions` for the `quant_opts` - argument. + py_function_library.SaveExportedModel( + *precalibrated_saved_model_dir, *exported_model, + src_saved_model_path, tags, signature_def_map); - The argument `custom_aggregator_id_assigner` is an instance of - `CustomAggregatorIdAssigner` whose virtual function `assign_ids` is - implemented in python. + py_function_library.RunCalibration( + *precalibrated_saved_model_dir, signature_keys, tags, + quantization_options.calibration_options(), + quantization_options.force_graph_mode_calibration(), + representative_dataset); - Raises `StatusNotOk` exception if when the run was unsuccessful. - )pbdoc"); + if (absl::Status status = AddCalibrationStatistics( + *exported_model->mutable_graph_def(), + quantization_options.calibration_options(), + py_function_library); + !status.ok()) { + LOG(WARNING) << "Some CustomAggregator ops do not have min or max " + "values. Parts of the graph are not quantized. 
" + << status; + } - m.def( - "quantize_ptq_model_post_calibration", - [](const absl::string_view saved_model_path, - const std::vector& signature_keys, - const std::unordered_set& tags, - const QuantizationOptions& quant_opts, - const absl::flat_hash_map& function_aliases) - -> absl::StatusOr { - return QuantizePtqModelPostCalibration(saved_model_path, signature_keys, - tags, quant_opts, - function_aliases); + if (quantization_options.has_debugger_options()) { + EnableDebugging(*exported_model, + quantization_options.debugger_options(), + py_function_library, src_saved_model_path, tags, + signature_def_map); + } + + const absl::StatusOr calibrated_saved_model_path = + CreateTmpDir(); + if (!calibrated_saved_model_path.ok()) { + throw py::value_error( + calibrated_saved_model_path.status().ToString()); + } + + py_function_library.SaveExportedModel( + *calibrated_saved_model_path, *exported_model, src_saved_model_path, + tags, signature_def_map); + + const absl::flat_hash_map + function_aliases_after_calibration( + exported_model->function_aliases().begin(), + exported_model->function_aliases().end()); + + const absl::StatusOr post_calibrated_exported_model = + QuantizePtqModelPostCalibration( + *calibrated_saved_model_path, signature_keys, tags, + quantization_options, function_aliases_after_calibration); + if (!post_calibrated_exported_model.ok()) + return post_calibrated_exported_model.status(); + + // Remove the `tpu` tag from the debug quantized saved model as it is + // for CPU. Note the 'tpu' value should be the same as `TPU` defined in + // tensorflow/python/saved_model/tag_constants.py. + if (quantization_options.has_debugger_options()) { + tags.erase("tpu"); + } + py_function_library.SaveExportedModel( + dst_saved_model_path, *post_calibrated_exported_model, + *calibrated_saved_model_path, tags, signature_def_map); + + return absl::OkStatus(); }, R"pbdoc( - Returns serialized ExportedModel that contains the quantized model's - GraphDef and metadata. The user should pass a serialized - `QuantizationOptions` for the `quant_opts` argument. + Runs static-range post-training quantization (PTQ) on a SavedModel at + `src_saved_model_path` and saves the resulting model to + `dst_saved_model_path`. + + The user should pass a serialized `QuantizationOptions` for the + `quantization_options_serialized` argument, and a signature key -> + serialized `SignatureDef` mapping for the `signature_def_map_serialized` + argument. + + `function_aliases` maps actual function names to the function aliases, as + defined by the `MetaGraphDef::MetaInfoDef::function_aliases` from the + input SavedModel. Raises `StatusNotOk` exception if when the run was unsuccessful. - )pbdoc"); + )pbdoc", + py::arg("saved_model_path"), py::arg("dst_saved_model_path"), + py::arg("quantization_options_serialized"), py::kw_only(), + py::arg("signature_keys"), py::arg("signature_def_map_serialized"), + py::arg("function_aliases"), py::arg("py_function_library"), + py::arg("representative_dataset")); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.pyi b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.pyi index 6e47f029f5e4d9..afe61d54854e71 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.pyi +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.pyi @@ -12,45 +12,64 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +from typing import Any + from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib +from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd -def clear_calibrator() -> None: ... -def clear_data_from_calibrator(id: bytes) -> None: ... -def get_statistics_from_calibrator( - id: bytes, -) -> calibration_statistics_pb2.CalibrationStatistics: ... +# LINT.IfChange(quantize_qat_model) def quantize_qat_model( - saved_model_path: str, - signature_keys: list[str], - tags: set[str], + src_saved_model_path: str, + dst_saved_model_path: str, quantization_options_serialized: bytes, + *, + signature_keys: list[str], + signature_def_map_serialized: dict[str, bytes], function_aliases: dict[str, str], -) -> bytes: ... + py_function_library: py_function_lib.PyFunctionLibrary, +) -> Any: ... # Status + +# LINT.ThenChange() + +# LINT.IfChange(quantize_ptq_dynamic_range) def quantize_ptq_dynamic_range( - saved_model_path: str, - signature_keys: list[str], - tags: set[str], + src_saved_model_path: str, + dst_saved_model_path: str, quantization_options_serialized: bytes, + *, + signature_keys: list[str], + signature_def_map_serialized: dict[str, bytes], function_aliases: dict[str, str], -) -> bytes: ... + py_function_library: py_function_lib.PyFunctionLibrary, +) -> Any: ... # Status + +# LINT.ThenChange() + +# LINT.IfChange(quantize_weight_only) def quantize_weight_only( - saved_model_path: str, - quantization_options_serialized: bytes, - function_aliases: dict[str, str], -) -> bytes: ... -def quantize_ptq_model_pre_calibration( - saved_model_path: str, - signature_keys: list[str], - tags: set[str], + src_saved_model_path: str, + dst_saved_model_path: str, quantization_options_serialized: bytes, + *, + signature_def_map_serialized: dict[str, bytes], function_aliases: dict[str, str], py_function_library: py_function_lib.PyFunctionLibrary, -) -> bytes: ... -def quantize_ptq_model_post_calibration( - saved_model_path: str, - signature_keys: list[str], - tags: set[str], +) -> Any: ... # Status + +# LINT.ThenChange() + +# LINT.IfChange(quantize_ptq_static_range) +def quantize_ptq_static_range( + src_saved_model_path: str, + dst_saved_model_path: str, quantization_options_serialized: bytes, + *, + signature_keys: list[str], + signature_def_map_serialized: dict[str, bytes], function_aliases: dict[str, str], -) -> bytes: ... + py_function_library: py_function_lib.PyFunctionLibrary, + representative_dataset: rd.RepresentativeDatasetOrMapping, +) -> Any: ... # Status + +# LINT.ThenChange() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model_test.py index ed531218290c7b..b29edcfaed4c9c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model_test.py @@ -17,6 +17,7 @@ These test cases are mostly for validation checks. Tests for functionalities are at `quantize_model_test.py`. 
""" +from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_quantize_model from tensorflow.python.platform import test @@ -25,25 +26,39 @@ class PywrapQuantizeModelTest(test.TestCase): """Test cases for quantize_model python wrappers.""" def test_quantize_model_fails_when_invalid_quant_options_serialization(self): - saved_model_path = self.create_tempdir('saved_model').full_path + src_saved_model_path = self.create_tempdir().full_path + dst_saved_model_path = self.create_tempdir().full_path signature_def_keys = ['serving_default'] - tags = {'serve'} - quant_opts_serialized = 'invalid protobuf serialization string' + quant_opts_serialized = 'invalid proto serialization string'.encode('utf-8') with self.assertRaisesRegex(TypeError, 'incompatible function arguments'): - pywrap_quantize_model.quantize_ptq_model_pre_calibration( - saved_model_path, signature_def_keys, tags, quant_opts_serialized + pywrap_quantize_model.quantize_ptq_static_range( + src_saved_model_path, + dst_saved_model_path, + quant_opts_serialized, + signature_keys=signature_def_keys, + signature_def_map_serialized={}, + function_aliases={}, + py_function_library=py_function_lib.PyFunctionLibrary(), + representative_dataset=None, ) def test_quantize_model_fails_when_invalid_quant_options_type(self): - saved_model_path = self.create_tempdir('saved_model').full_path + src_saved_model_path = self.create_tempdir().full_path + dst_saved_model_path = self.create_tempdir().full_path signature_def_keys = ['serving_default'] - tags = {'serve'} invalid_quant_opts_object = ('a', 'b', 'c') with self.assertRaisesRegex(TypeError, 'incompatible function arguments'): - pywrap_quantize_model.quantize_ptq_model_pre_calibration( - saved_model_path, signature_def_keys, tags, invalid_quant_opts_object + pywrap_quantize_model.quantize_ptq_static_range( + src_saved_model_path, + dst_saved_model_path, + invalid_quant_opts_object, + signature_keys=signature_def_keys, + signature_def_map_serialized={}, + function_aliases={}, + py_function_library=py_function_lib.PyFunctionLibrary(), + representative_dataset=None, ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index 4054ce5ab6f354..ab4f3327956cd0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -22,10 +22,14 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" @@ -39,20 +43,22 @@ limitations under the License. 
#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/export.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/precalibration.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/status_macro.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" @@ -63,8 +69,7 @@ limitations under the License. #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/status.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace tensorflow { @@ -73,66 +78,16 @@ namespace { using ::mlir::quant::kTfFilePrefix; using ::mlir::quant::kTfQuantSaveOpName; +using ::mlir::quant::stablehlo::PreCalibrationComponent; using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; using ::mlir::tf_saved_model::kTfSavedModelInitializerInitType; using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; - -// Suffix string for the module export step. Used for debugging. -constexpr absl::string_view kExportStepSuffix = "_export"; - -// Options when running passes for exporting an MLIR ModuleOp. -struct ExportOptions { - // If set to `true`, it runs `DuplicateShapeDeterminingConstantsPass` before - // lowering to tf_executor dialect. - bool duplicate_shape_determining_constants = true; - - // If set to `true`, unfreezes constants into variables and saves them to a - // checkpoint file. Setting this to `true` is an experimental feature that has - // no stability guarantees. - bool unfreeze_constants = false; - - // Path to the directory where checkpoint files are saved. - std::string checkpoint_dir = ""; - - // Name used to identify the ModuleOp this is exporting. Only used for - // debugging and does not modify the behavior of the export. 
- std::string debug_name = "tf_quant"; -}; - -// Add passes for transforming the MLIR module op so that it can be exported -// back to GraphDef. Roughly, this consists of: -// 1) Inserting the @main function, which will become the main Graph. -// 2) Duplicating shape-determining constants. -// 3) Converting TF dialect -> tf_executor dialect. -// 4) Adding initializer function's ops into @main function for correct -// resource initialization when loading the exported model. -// -// Duplicating shape-determining constants is required to place constants that -// affect the shape of a tensor to be placed in the TPU graph instead of in the -// CPU graph, when the graph gets converted for TPU inference. This allows these -// constants to be known at XLA compilation time. -void AddExportPasses(const bool duplicate_shape_determining_constants, - mlir::PassManager &pm) { - if (duplicate_shape_determining_constants) { - pm.addNestedPass( - mlir::quant::CreateDuplicateShapeDeterminingConstantsPass()); - } - - pm.addPass(mlir::quant::CreateInsertMainFunctionPass()); - pm.addPass(mlir::quant::CreateLiftHashTableOpsAsArgsPass()); - pm.addNestedPass( - mlir::CreateFunctionalToExecutorDialectConversionPass()); - pm.addPass(mlir::CreateBreakUpIslandsPass()); - pm.addPass(mlir::quant::CreateMergeInitializerFunctionOpsToMainPass()); - pm.addPass(mlir::quant::CreateMergeSaveFunctionOpsToMainPass()); - pm.addNestedPass( - mlir::quant::CreateMergeDuplicateResourceOpsPass()); - - // Used to clean up the "tf._noinliner" attribute that is previously used to - // prevent certain functions from being inlined (see - // `MarkFunctionsNoinlinePass`). InlinerPass must not come after this pass. - pm.addPass(mlir::TF::CreateStripNoinlineAttributePass()); -} +using ::stablehlo::quantization::AddExportPasses; +using ::stablehlo::quantization::CreateExportedModel; +using ::stablehlo::quantization::ExportOptions; +using ::stablehlo::quantization::kExportStepSuffix; +using ::stablehlo::quantization::QuantizationConfig; +using ::stablehlo::quantization::io::GetLocalTmpFileName; // Finds and returns the name of the node from a set of control output nodes. // The name should contain the string `contains`. Returns an empty string if no @@ -151,32 +106,6 @@ std::string GetNodeName(const absl::flat_hash_set &control_ret_nodes, return ""; } -// Factory function for `ExportedModel`. -[[nodiscard]] ExportedModel CreateExportedModel( - GraphDef &&graph_def, const absl::string_view init_node_name, - const absl::string_view checkpoint_dir, - const std::optional saver_def, - const absl::flat_hash_map &function_aliases, - const std::vector &asset_file_defs) { - ExportedModel exported_model{}; - *exported_model.mutable_graph_def() = graph_def; - exported_model.set_init_node_name(std::string(init_node_name)); - exported_model.set_checkpoint_dir(std::string(checkpoint_dir)); - - exported_model.mutable_function_aliases()->insert(function_aliases.begin(), - function_aliases.end()); - - for (const auto &asset_file_def : asset_file_defs) { - *exported_model.mutable_asset_file_defs()->Add() = asset_file_def; - } - - if (saver_def != std::nullopt) { - *exported_model.mutable_saver_def() = *std::move(saver_def); - } - - return exported_model; -} - // Returns the file prefix tensor name. An empty string is returned if no such a // tensor is found (when there are no variables to restore, it is expected that // the file prefix tensor does not exist). 
The file prefix tensor is found among @@ -197,7 +126,7 @@ std::string FindFilePrefixTensorName(const GraphDef &graph_def) { if (const auto file_prefix_itr = absl::c_find(index_paths, kTfFilePrefix.str()); file_prefix_itr != index_paths.end()) { - // ":0" appended to inidicate that it is a tensor, not an Operation. + // ":0" appended to indicate that it is a tensor, not an Operation. return absl::StrCat(node_def.name(), ":0"); } } @@ -322,60 +251,6 @@ absl::flat_hash_map UpdateFunctionAliases( return updated_function_aliases; } -// Create a unique local temporary filename. It only creates the name, not the -// actual file. -absl::StatusOr GetLocalTempFilename() { - auto *env = Env::Default(); - std::string tmp_fname{}; - if (!env->LocalTempFilename(&tmp_fname)) { - return absl::InternalError("Failed to create a local temp file name."); - } - - return tmp_fname; -} - -// Unfreezes constants into variables and saves them to a checkpoint files under -// `checkpoint_dir`. `checkpoint_dir` will be created within this function. It -// will return a non-OK status if it already exists or permission is denied. -// TODO(b/261652258): Make sure this works for when there are non-frozen -// variables in the model. -// TODO(b/262189534): Move this to a separate file for better testing. -absl::Status UnfreezeConstantsAndSaveVariables( - const absl::string_view checkpoint_dir, mlir::MLIRContext &ctx, - mlir::ModuleOp module_op) { - TF_QUANT_RETURN_IF_ERROR(RunPasses( - /*name=*/kTfQuantConstantUnfreezingStepName, - /*add_passes_func=*/ - [](mlir::PassManager &pm) { - pm.addPass(mlir::quant::CreateUnfreezeConstantsPass()); - }, - ctx, module_op)); - - if (const tsl::Status create_dir_status = - Env::Default()->CreateDir(std::string(checkpoint_dir)); - !create_dir_status.ok()) { - LOG(ERROR) << "Failed to create checkpoint directory at: " - << checkpoint_dir; - return create_dir_status; - } - - TF_ASSIGN_OR_RETURN(const auto _, - SaveVariablesToCheckpoint(checkpoint_dir, module_op)); - - return RunPasses( - /*name=*/kTfQuantInsertRestoreOpStepName, - /*add_passes_func=*/ - [](mlir::PassManager &pm) { - pm.addPass(mlir::quant::CreateInsertRestoreOpPass()); - pm.addPass(mlir::quant::CreateInsertSaveOpPass()); - // Initialization by `tf.ConstOp` is no longer required as there is - // a `tf.RestoreV2Op` now. - pm.addPass( - mlir::quant::CreateRemoveVariableInitializationByConstPass()); - }, - ctx, module_op); -} - // Sets up and runs the passes for exporting `module_op`. The behavior of the // exporting passes is controlled by `export_opts`. Returns `AssetFileDef`s that // associate the input arguments of @main and the asset file names. 
Asset file @@ -385,17 +260,17 @@ absl::StatusOr> RunExportPasses( const ExportOptions &export_opts, mlir::MLIRContext &ctx, mlir::ModuleOp module_op) { if (export_opts.unfreeze_constants) { - TF_QUANT_RETURN_IF_ERROR(UnfreezeConstantsAndSaveVariables( + TF_RETURN_IF_ERROR(UnfreezeConstantsAndSaveVariables( export_opts.checkpoint_dir, ctx, module_op)); LOG(INFO) << "Unfrozen constants and saved variables to checkpoint file: " << export_opts.checkpoint_dir; } - if (const absl::Status pass_run_status = RunPasses( + if (absl::Status pass_run_status = RunPasses( /*name=*/export_opts.debug_name, /*add_passes_func=*/ [dup_constants = export_opts.duplicate_shape_determining_constants]( - mlir::PassManager &pm) { AddExportPasses(dup_constants, pm); }, + mlir::PassManager &pm) { AddExportPasses(pm, dup_constants); }, ctx, module_op); !pass_run_status.ok()) { return pass_run_status; @@ -462,15 +337,14 @@ absl::StatusOr QuantizeQatModel( return aliased_function_names.insert(aliases.first); }); - TF_QUANT_RETURN_IF_ERROR(PreprocessAndFreezeGraph( + TF_RETURN_IF_ERROR(PreprocessAndFreezeGraph( /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix, /*is_inliner_run=*/true, /*noinline_functions=*/aliased_function_names, module_ref.get(), &context, bundle ? bundle->GetSession() : nullptr, /*run_tf_to_stablehlo=*/false)); - TF_QUANT_RETURN_IF_ERROR(RunPasses( - /*name=*/kTfQuantQatStepName, - /*add_passes_func=*/ + TF_RETURN_IF_ERROR(RunPasses( + /*name=*/kTfQuantQatStepName, /*add_passes_func=*/ [&quantization_options](mlir::PassManager &pm) { AddQuantizeQatPasses(pm, quantization_options, kTfQuantQatStepName); }, @@ -478,7 +352,7 @@ absl::StatusOr QuantizeQatModel( const bool unfreeze_constants = !quantization_options.freeze_all_variables(); - TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTempFilename()); + TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName()); const auto export_opts = ExportOptions{ /*duplicate_shape_determining_constants=*/true, unfreeze_constants, @@ -533,25 +407,21 @@ absl::StatusOr QuantizePtqModelPreCalibration( const bool run_tf_to_stablehlo = (quantization_options.op_set() == tensorflow::quantization::OpSet::STABLEHLO); - TF_QUANT_RETURN_IF_ERROR(PreprocessAndFreezeGraph( + TF_RETURN_IF_ERROR(PreprocessAndFreezeGraph( /*mlir_dump_file_prefix=*/kTfQuantPtqPreCalibrationStepName, - /*is_inliner_run=*/true, - /*noinline_functions=*/aliased_function_names, module_ref.get(), &context, - bundle ? bundle->GetSession() : nullptr, run_tf_to_stablehlo)); + /*is_inliner_run=*/true, /*noinline_functions=*/aliased_function_names, + module_ref.get(), &context, bundle ? bundle->GetSession() : nullptr, + run_tf_to_stablehlo)); // Use StableHLO Quantizer option if opset is specified. 
if (run_tf_to_stablehlo) { - TF_QUANT_RETURN_IF_ERROR(RunPasses( - /*name=*/kTfQuantPtqPreCalibrationStepStableHloName, - /*add_passes_func=*/ - [&quantization_options](mlir::PassManager &pm) { - AddQuantizePtqPreCalibrationStablehloPasses(pm, quantization_options); - }, - context, *module_ref)); + PreCalibrationComponent pre_calibration_component( + &context, quantization_options.calibration_options()); + TF_ASSIGN_OR_RETURN(*module_ref, pre_calibration_component.Run( + *module_ref, QuantizationConfig())); } else { - TF_QUANT_RETURN_IF_ERROR(RunPasses( - /*name=*/kTfQuantPtqPreCalibrationStepName, - /*add_passes_func=*/ + TF_RETURN_IF_ERROR(RunPasses( + /*name=*/kTfQuantPtqPreCalibrationStepName, /*add_passes_func=*/ [&quantization_options](mlir::PassManager &pm) { AddQuantizePtqPreCalibrationPasses(pm, quantization_options); }, @@ -559,7 +429,7 @@ absl::StatusOr QuantizePtqModelPreCalibration( } const bool unfreeze_constants = !quantization_options.freeze_all_variables(); - TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTempFilename()); + TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName()); // `duplicate_shape_determining_constants = false` because the // resulting graph of this step is not expected to be loaded on TPU. @@ -619,28 +489,26 @@ absl::StatusOr QuantizePtqModelPostCalibration( // Freezing is required again since variables might have been produced during // the pre-calibration step. `is_inliner_run = false` to prevent the functions // lifted for quantization from being inlined. - TF_QUANT_RETURN_IF_ERROR(PreprocessAndFreezeGraph( + TF_RETURN_IF_ERROR(PreprocessAndFreezeGraph( /*mlir_dump_file_prefix=*/kTfQuantPtqPostCalibrationStepName, - /*is_inliner_run=*/false, - /*noinline_functions=*/aliased_function_names, module_ref.get(), &context, - bundle ? bundle->GetSession() : nullptr, /*run_tf_to_stablehlo=*/false)); + /*is_inliner_run=*/false, /*noinline_functions=*/aliased_function_names, + module_ref.get(), &context, bundle ? bundle->GetSession() : nullptr, + /*run_tf_to_stablehlo=*/false)); // Use StableHLO Quantizer option if opset is specified. 
if (quantization_options.op_set() == tensorflow::quantization::OpSet::STABLEHLO) { - TF_QUANT_RETURN_IF_ERROR(RunPasses( - /*name=*/kTfQuantPtqPostCalibrationStepStableHloName, - /*add_passes_func=*/ - [&quantization_options](mlir::PassManager &pm) { - AddQuantizePtqPostCalibrationStablehloPasses( - pm, quantization_options, - kTfQuantPtqPostCalibrationStepStableHloName); - }, - context, *module_ref)); + TF_RETURN_IF_ERROR( + RunPasses(/*name=*/kTfQuantPtqPostCalibrationStepStableHloName, + /*add_passes_func=*/ + [](mlir::PassManager &pm) { + AddQuantizePtqPostCalibrationStablehloPasses( + pm, kTfQuantPtqPostCalibrationStepStableHloName); + }, + context, *module_ref)); } else { - TF_QUANT_RETURN_IF_ERROR(RunPasses( - /*name=*/kTfQuantPtqPostCalibrationStepName, - /*add_passes_func=*/ + TF_RETURN_IF_ERROR(RunPasses( + /*name=*/kTfQuantPtqPostCalibrationStepName, /*add_passes_func=*/ [&quantization_options](mlir::PassManager &pm) { AddQuantizePtqPostCalibrationPasses( pm, quantization_options, kTfQuantPtqPostCalibrationStepName); @@ -649,7 +517,7 @@ absl::StatusOr QuantizePtqModelPostCalibration( } const bool unfreeze_constants = !quantization_options.freeze_all_variables(); - TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTempFilename()); + TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName()); const auto export_opts = ExportOptions{ /*duplicate_shape_determining_constants=*/true, unfreeze_constants, @@ -705,15 +573,14 @@ absl::StatusOr QuantizePtqDynamicRange( return aliased_function_names.insert(aliases.first); }); - TF_QUANT_RETURN_IF_ERROR(PreprocessAndFreezeGraph( + TF_RETURN_IF_ERROR(PreprocessAndFreezeGraph( /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix, - /*is_inliner_run=*/true, - /*noinline_functions=*/aliased_function_names, module_ref.get(), &context, - bundle ? bundle->GetSession() : nullptr, /*run_tf_to_stablehlo=*/false)); + /*is_inliner_run=*/true, /*noinline_functions=*/aliased_function_names, + module_ref.get(), &context, bundle ? bundle->GetSession() : nullptr, + /*run_tf_to_stablehlo=*/false)); - TF_QUANT_RETURN_IF_ERROR(RunPasses( - /*name=*/kTfQuantPtqDynamicRangeStepName, - /*add_passes_func=*/ + TF_RETURN_IF_ERROR(RunPasses( + /*name=*/kTfQuantPtqDynamicRangeStepName, /*add_passes_func=*/ [&quantization_options](mlir::PassManager &pm) { AddQuantizePtqDynamicRangePasses(pm, quantization_options, kTfQuantPtqDynamicRangeStepName); @@ -721,7 +588,7 @@ absl::StatusOr QuantizePtqDynamicRange( context, *module_ref)); const bool unfreeze_constants = !quantization_options.freeze_all_variables(); - TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTempFilename()); + TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName()); const auto export_opts = ExportOptions{ /*duplicate_shape_determining_constants=*/true, unfreeze_constants, @@ -780,23 +647,22 @@ absl::StatusOr QuantizeWeightOnly( return aliased_function_names.insert(aliases.first); }); - TF_QUANT_RETURN_IF_ERROR(PreprocessAndFreezeGraph( + TF_RETURN_IF_ERROR(PreprocessAndFreezeGraph( /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix, /*is_inliner_run=*/true, /*noinline_functions=*/aliased_function_names, module_ref.get(), &context, bundle ? 
bundle->GetSession() : nullptr, /*run_tf_to_stablehlo=*/false)); - TF_QUANT_RETURN_IF_ERROR(RunPasses( - /*name=*/kTfQuantWeightOnlyStepName, - /*add_passes_func=*/ - [&quantization_options](mlir::PassManager &pm) { - AddQuantizeWeightOnlyPasses(pm, quantization_options, - kTfQuantWeightOnlyStepName); - }, - context, *module_ref)); + TF_RETURN_IF_ERROR( + RunPasses(/*name=*/kTfQuantWeightOnlyStepName, /*add_passes_func=*/ + [&quantization_options](mlir::PassManager &pm) { + AddQuantizeWeightOnlyPasses(pm, quantization_options, + kTfQuantWeightOnlyStepName); + }, + context, *module_ref)); const bool unfreeze_constants = !quantization_options.freeze_all_variables(); - TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTempFilename()); + TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName()); const auto export_opts = ExportOptions{ /*duplicate_shape_determining_constants=*/true, unfreeze_constants, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h index 4db0f667a619d5..81e5b6167fc0e3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h @@ -39,14 +39,8 @@ inline constexpr absl::string_view kTfQuantPtqDynamicRangeStepName = "tf_quant_ptq_dynamic_range"; inline constexpr absl::string_view kTfQuantWeightOnlyStepName = "tf_quant_weight_only"; -inline constexpr absl::string_view kTfQuantConstantUnfreezingStepName = - "tf_quant_constant_unfreezing"; -inline constexpr absl::string_view kTfQuantInsertRestoreOpStepName = - "tf_quant_insert_restore_op"; // StableHLO Quantization passes that are ran if StableHLO opset is selected. 
-inline constexpr absl::string_view kTfQuantPtqPreCalibrationStepStableHloName = - "tf_quant_ptq_pre_calibration_stablehlo"; inline constexpr absl::string_view kTfQuantPtqPostCalibrationStepStableHloName = "tf_quant_ptq_post_calibration_stablehlo"; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index 3746afa13b8dbe..affc19a9250890 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -14,36 +14,23 @@ # ============================================================================== """Defines TF Quantization API from SavedModel to SavedModel.""" -import collections.abc import tempfile -from typing import Callable, Collection, Dict, Mapping, Optional, Sequence +from typing import Mapping, Optional from absl import logging -import numpy as np -from tensorflow.compiler.mlir.quantization.tensorflow import exported_model_pb2 from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2 -from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_algorithm -from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 as calib_stats_pb2 from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_quantize_model from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model -from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 -from tensorflow.core.protobuf import saver_pb2 -from tensorflow.python.client import session -from tensorflow.python.eager import context -from tensorflow.python.eager import wrap_function -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_conversion from tensorflow.python.lib.io import file_io from tensorflow.python.saved_model import load as saved_model_load from tensorflow.python.saved_model import loader_impl as saved_model_loader from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants from tensorflow.python.trackable import autotrackable -from tensorflow.python.types import core from tensorflow.python.util import tf_export # Type aliases for quant_opts_pb2 messages. @@ -76,10 +63,6 @@ # during dynamic range quantization (DRQ) and weight-only quantization. _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS = 1024 -# Name of the saved model assets directory. -_ASSETS_DIR = 'assets' -_ASSETS_EXTRA_DIR = 'assets.extra' - def _is_qat_saved_model(saved_model_path: str): """Checks if the SavedModel is QAT-enabled by looking for 'FakeQuant' ops.""" @@ -95,485 +78,22 @@ def _is_qat_saved_model(saved_model_path: str): return False -def _create_sample_validator( - expected_input_keys: Collection[str], -) -> Callable[ - [repr_dataset.RepresentativeSample], repr_dataset.RepresentativeSample -]: - """Creates a validator function for a representative sample. - - Args: - expected_input_keys: Input keys (keyword argument names) that the function - the sample will be used for is expecting to receive. - - Returns: - A callable that validates a `RepresentativeSample`. 
- """ - - def validator( - sample: repr_dataset.RepresentativeSample, - ) -> repr_dataset.RepresentativeSample: - """Validates a single instance of representative sample. - - This provides a simple check for `sample` that this is a mapping of - {input_key: input_value}. - - Args: - sample: A `RepresentativeSample` to validate. - - Returns: - `sample` iff it is valid. - - Raises: - ValueError: iff the sample isn't an instance of `Mapping`. - KeyError: iff the sample does not have the set of input keys that match - the input keys of the function. - """ - if not isinstance(sample, collections.abc.Mapping): - raise ValueError( - 'Invalid representative sample type. Provide a mapping ' - '(usually a dict) of {input_key: input_value}. ' - f'Got type: {type(sample)} instead.' - ) - - if set(sample.keys()) != expected_input_keys: - raise KeyError( - 'Invalid input keys for representative sample. The function expects ' - f'input keys of: {set(expected_input_keys)}. ' - f'Got: {set(sample.keys())}. Please provide correct input keys for ' - 'representative samples.' - ) - - return sample - - return validator - - -def _validate_representative_dataset( - representative_dataset: repr_dataset.RepresentativeDatasetOrMapping, - signature_keys: Collection[str], -) -> None: - """Validates the representative dataset, based on the signature keys. - - Representative dataset can be provided in two different forms: a single - instance of `RepresentativeDataset` or a map of signature key to the - corresponding `RepresentativeDataset`. These have a relationship with - `signature_keys`. - - This function validates the following conditions: - * If `len(signature_keys) > 1`, then `representative_dataset` should be a - mapping where the keys exactly match the elements in `signature_keys`. - * If `len(signature_keys) == 1`, then both a mapping and a single instance of - `RepresentativeDataset` are allowed. - * This function also assumes `len(signature_keys) > 0`. - - Args: - representative_dataset: A `RepresentativeDataset` or a map of string to - `RepresentativeDataset` to be validated. - signature_keys: A collection of strings that contains the signature keys, - each identifying a `SignatureDef`. - - Raises: - ValueError: Iff `representative_dataset` does not satisfy the conditions - above. - """ - if isinstance(representative_dataset, collections.abc.Mapping): - if set(signature_keys) != set(representative_dataset.keys()): - raise ValueError( - 'The signature keys and the keys of representative dataset map ' - f'do not match. Signature keys: {set(signature_keys)}, ' - f'representative dataset map: {set(representative_dataset.keys())}.' - ) - else: - if len(signature_keys) > 1: - raise ValueError( - 'Representative dataset is not a mapping ' - f'(got: {type(representative_dataset)}), ' - 'but there is more than one signature key provided. ' - 'Please provide a map of {signature_key -> dataset} ' - 'with more than one signature key.' - ) - - -def _convert_values_to_tf_tensors( - sample: repr_dataset.RepresentativeSample, -) -> Mapping[str, core.Tensor]: - """Converts TensorLike values of `sample` to Tensors. - - Creates a copy of `sample`, where each value is converted to Tensors - unless it is already a Tensor. - The values are not converted in-place (i.e. `sample` is not mutated). - - Args: - sample: A representative sample, which is a map of {name -> tensorlike - value}. - - Returns: - Converted map of {name -> tensor}. 
- """ - tensor_mapping = {} - for name, tensorlike_value in sample.items(): - if isinstance(tensorlike_value, core.Tensor): - tensor_value = tensorlike_value - else: - tensor_value = tensor_conversion.convert_to_tensor_v2_with_dispatch( - tensorlike_value - ) - - tensor_mapping[name] = tensor_value - - return tensor_mapping - - -def _create_feed_dict_from_input_data( - input_data: repr_dataset.RepresentativeSample, - signature_def: meta_graph_pb2.SignatureDef, -) -> Dict[str, np.ndarray]: - """Constructs a feed_dict from input data. - - Note: This function should only be used in graph mode. - - This is a helper function that converts an 'input key -> input value' mapping - to a feed dict. A feed dict is an 'input tensor name -> input value' mapping - and can be directly passed to the `feed_dict` argument of `sess.run()`. +def _serialize_signature_def_map( + signature_def_map: _SignatureDefMap, +) -> dict[str, bytes]: + """Serializes SignatureDef values in `signature_def_map`. Args: - input_data: Input key -> input value mapping. The input keys should match - the input keys of `signature_def`. - signature_def: A SignatureDef representing the function that `input_data` is - an input to. + signature_def_map: Signature key -> SignatureDef mapping. Returns: - Feed dict, which is intended to be used as input for `sess.run`. It is - essentially a mapping: input tensor name -> input value. Note that the input - value in the feed dict is not a `Tensor`. + Signature def map where the values (`SignatureDef`) are serialized. """ - feed_dict = {} - for input_key, input_value in input_data.items(): - input_tensor_name = signature_def.inputs[input_key].name - - value = input_value - if isinstance(input_value, core.Tensor): - # Take the data out of the tensor. - value = input_value.eval() - - feed_dict[input_tensor_name] = value - - return feed_dict - - -# TODO(b/249918070): Implement a progress bar. -def _log_sample_num_for_calibration( - representative_dataset: repr_dataset.RepresentativeDataset, -) -> repr_dataset.RepresentativeDataset: - """Logs the sample number for calibration. + signature_def_map_serialized = {} + for key, signature_def in signature_def_map.items(): + signature_def_map_serialized[key] = signature_def.SerializeToString() - If in debug logging level, the "sample number / total num samples" is logged - for every 5 iterations. - - This is often useful when tracking the progress of the calibration step which - is often slow and may look stale if there's no logs being printed. - - Args: - representative_dataset: The representative dataset. - - Yields: - The representative samples from `representative_dataset` without any - modification. - """ - num_samples: Optional[int] = repr_dataset.get_num_samples( - representative_dataset - ) - if num_samples is None: - total_num_samples = '?' - logging.info('Representative dataset size unknown.') - else: - total_num_samples = str(num_samples) - logging.info('Using representative dataset of size: %s', total_num_samples) - - sample_num = 0 - for sample in representative_dataset: - sample_num += 1 - - # Log the sample number for every 5 iterations. 
- logging.log_every_n( - logging.DEBUG, - 'Running representative sample for calibration: %d / %s', - 5, - sample_num, - total_num_samples, - ) - yield sample - - logging.info( - 'Running representative samples complete: %d / %s', - sample_num, - total_num_samples, - ) - - -def _run_function_for_calibration_graph_mode( - sess: session.Session, - signature_def: meta_graph_pb2.SignatureDef, - representative_dataset: repr_dataset.RepresentativeDataset, -) -> None: - """Runs the representative dataset through a function for calibration. - - NOTE: This is intended to be run in graph mode (TF1). - - The function is identified by the SignatureDef. - - Args: - sess: The Session object to run the function in. - signature_def: A SignatureDef that identifies a function by specifying the - inputs and outputs. - representative_dataset: The representative dataset to run through the - function. - """ - output_tensor_names = [ - output_tensor_info.name - for output_tensor_info in signature_def.outputs.values() - ] - - sample_validator = _create_sample_validator( - expected_input_keys=signature_def.inputs.keys() - ) - - for sample in map( - sample_validator, _log_sample_num_for_calibration(representative_dataset) - ): - # Create a mapping from input tensor name to the input tensor value. - # ex) "Placeholder:0" -> [0, 1, 2] - feed_dict = _create_feed_dict_from_input_data(sample, signature_def) - sess.run(output_tensor_names, feed_dict=feed_dict) - - -def _replace_tensors_by_numpy_ndarrays( - repr_ds_map: repr_dataset.RepresentativeDatasetMapping, -) -> None: - """Replaces tf.Tensors by their evaluated numpy arrays. - - This assumes that tf.Tensors in representative samples are created in the - default Graph. It will raise an error if tensors are created in a different - graph. - - Args: - repr_ds_map: SignatureDef key -> RepresentativeDataset mapping. - """ - with session.Session() as sess: - for signature_def_key in repr_ds_map: - # Replaces the dataset with a new dataset where tf.Tensors are replaced - # by their evaluated values. - ds = repr_ds_map[signature_def_key] - repr_ds_map[signature_def_key] = ( - repr_dataset.replace_tensors_by_numpy_ndarrays(ds, sess) - ) - - -def _run_graph_for_calibration_graph_mode( - model_dir: str, - tags: Collection[str], - representative_dataset_map: repr_dataset.RepresentativeDatasetMapping, -) -> None: - """Runs the graph for calibration in graph mode. - - This function assumes _graph mode_ (used when legacy TF1 is used or when eager - mode is explicitly disabled) when running the graph. This step is used in - order to collect the statistics in CustomAggregatorOp for quantization using - the representative dataset for the actual data provided for inference. - - Args: - model_dir: Path to SavedModel directory. - tags: Collection of tags identifying the MetaGraphDef within the SavedModel. - representative_dataset_map: A map where signature keys are mapped to - corresponding representative datasets. - - Raises: - ValueError: When running the function with the representative dataset fails. - """ - # Replace tf.Tensors by numpy ndarrays in order to reuse the samples in a - # different graph when running the calibration. - _replace_tensors_by_numpy_ndarrays(representative_dataset_map) - - # Run the calibration in a new graph to avoid name collision, which could - # happen when the same model is loaded multiple times in the default graph. 
- with ops.Graph().as_default(), session.Session() as sess: - meta_graph: meta_graph_pb2.MetaGraphDef = saved_model_loader.load( - sess, tags, export_dir=model_dir - ) - - for signature_key, repr_ds in representative_dataset_map.items(): - sig_def = meta_graph.signature_def[signature_key] - - try: - _run_function_for_calibration_graph_mode( - sess, signature_def=sig_def, representative_dataset=repr_ds - ) - except Exception as ex: - raise ValueError( - 'Failed to run representative dataset through the ' - f'function with the signature key: {signature_key}.' - ) from ex - - -def _run_function_for_calibration_eager_mode( - func: wrap_function.WrappedFunction, - representative_dataset: repr_dataset.RepresentativeDataset, -) -> None: - """Runs the representative dataset through a function for calibration. - - NOTE: This is intended to be run in eager mode (TF2). - - Args: - func: The function to run the representative samples through. - representative_dataset: Representative dataset used for calibration. The - input keys and input values of the representative samples should match the - keyword arguments of `func`. - """ - _, keyword_args = func.structured_input_signature - sample_validator = _create_sample_validator( - expected_input_keys=keyword_args.keys() - ) - - for sample in map( - sample_validator, _log_sample_num_for_calibration(representative_dataset) - ): - # Convert any non-Tensor values from the sample to Tensors. - # This conversion is required because the model saved in `model_dir` is - # saved using TF1 SavedModelBuilder, which doesn't save the - # SavedObjectGraph. - # TODO(b/236795224): Remove the need for this conversion by keeping the - # FunctionSpec (object graph) in the SavedModel. Related: b/213406917. - func_kwargs = _convert_values_to_tf_tensors(sample) - func(**func_kwargs) - - -def _run_graph_for_calibration_eager_mode( - model_dir: str, - tags: Collection[str], - representative_dataset_map: repr_dataset.RepresentativeDatasetMapping, -) -> None: - """Runs the graph for calibration in eager mode. - - This function assumes _eager mode_ (enabled in TF2 by default) when running - the graph. This step is used in order to collect the statistics in - CustomAggregatorOp for quantization using the representative dataset for the - actual data provided for inference. - - Args: - model_dir: Path to SavedModel directory. - tags: Collection of tags identifying the MetaGraphDef within the SavedModel. - representative_dataset_map: A map where signature keys are mapped to - corresponding representative datasets. - - Raises: - ValueError: When running the function with the representative dataset fails. - """ - root: autotrackable.AutoTrackable = saved_model_load.load(model_dir, tags) - for signature_key, repr_ds in representative_dataset_map.items(): - try: - _run_function_for_calibration_eager_mode( - func=root.signatures[signature_key], representative_dataset=repr_ds - ) - except Exception as ex: - raise ValueError( - 'Failed to run representative dataset through the ' - f'function with the signature key: {signature_key}.' - ) from ex - - -def _run_graph_for_calibration( - float_model_dir: str, - signature_keys: Sequence[str], - tags: Collection[str], - representative_dataset: repr_dataset.RepresentativeDatasetOrMapping, - force_graph_mode_calibration: bool, -) -> None: - """Runs the graph for calibration using representative datasets. - - Args: - float_model_dir: Path to the model to calibrate. 
- signature_keys: Sequence of keys identifying SignatureDef containing inputs - and outputs. - tags: Collection of tags identifying the MetaGraphDef within the SavedModel - to analyze. - representative_dataset: An iterator that returns a dictionary of {input_key: - input_value} or a mapping from signature keys to such iterators. When - `signature_keys` contains more than one signature key, - `representative_datsaet` should be a mapping that maps each signature keys - to the corresponding representative dataset. - force_graph_mode_calibration: If set to true, it forces calibration in graph - model instead of eager mode when the context is in eager mode. - - Raises: - ValueError iff: - * The representative dataset format is invalid. - * It fails to run the functions using the representative datasets. - """ - try: - _validate_representative_dataset(representative_dataset, signature_keys) - except Exception as ex: - raise ValueError('Invalid representative dataset.') from ex - - # If `representative_dataset` is not a mapping, convert to a mapping for the - # following functions to handle representative datasets more conveniently. - representative_dataset_map = representative_dataset - if not isinstance(representative_dataset, collections.abc.Mapping): - # `signature_keys` is guaranteed to have only one element after the - # validation. - representative_dataset_map = {signature_keys[0]: representative_dataset} - - try: - if context.executing_eagerly() and not force_graph_mode_calibration: - logging.info('Calibration step is executed in eager mode.') - _run_graph_for_calibration_eager_mode( - float_model_dir, tags, representative_dataset_map - ) - else: - logging.info('Calibration step is executed in graph mode.') - _run_graph_for_calibration_graph_mode( - float_model_dir, tags, representative_dataset_map - ) - except Exception as ex: - raise ValueError( - 'Failed to run graph for post-training quantization calibration.' - ) from ex - - logging.info('Calibration step complete.') - - -def _copy_assets(src_path: str, dst_path: str) -> None: - """Copies the assets directory of the saved model. - - Clones the contents of the assets/ directory from the source saved model - directory to the destination saved model directory. Nothing will be copied if - there are no assets directory in the source directory. - - Args: - src_path: Source saved model directory. - dst_path: Destination saved model directory. This directory must exist. - """ - for assets_dir_name in [_ASSETS_DIR, _ASSETS_EXTRA_DIR]: - src_assets_path = file_io.join(src_path, assets_dir_name) - if not file_io.file_exists_v2(src_assets_path): - # Do nothing if the source assets path does not exist. - continue - - dst_assets_path = file_io.join(dst_path, assets_dir_name) - file_io.create_dir_v2(dst_assets_path) - - for curr_dir, _, files in file_io.walk_v2(src_assets_path): - for asset_file_name in files: - src_asset_file = file_io.join(curr_dir, asset_file_name) - - # Construct the destination assets file path. 
- curr_dst_dir = curr_dir.replace(src_assets_path, dst_assets_path) - dst_asset_file = file_io.join(curr_dst_dir, asset_file_name) - - file_io.copy_v2(src_asset_file, dst_asset_file) - logging.info( - 'Copied asset file: %s -> %s', src_asset_file, dst_asset_file - ) + return signature_def_map_serialized def _run_static_range_qat( @@ -599,145 +119,17 @@ def _run_static_range_qat( quant_opts.tags ).meta_info_def.function_aliases - exported_model_serialized = pywrap_quantize_model.quantize_qat_model( + pywrap_quantize_model.quantize_qat_model( src_saved_model_path, - list(quant_opts.signature_keys), - set(quant_opts.tags), - quant_opts.SerializeToString(), - dict(function_aliases), - ) - - exported_model = exported_model_pb2.ExportedModel.FromString( - exported_model_serialized - ) - - save_model.save_model_v1( - exported_model.graph_def, dst_saved_model_path, - signature_def_map, - quant_opts.tags, - init_op_name=exported_model.init_node_name, - saver_def=_get_saver_def_or_none(exported_model), - checkpoint_dir=exported_model.checkpoint_dir, - function_aliases=exported_model.function_aliases, - asset_file_defs=exported_model.asset_file_defs, - ) - - _copy_assets(src_saved_model_path, dst_saved_model_path) - - -def _get_min_max_from_calibrator( - node_id: bytes, - calib_opts: quant_opts_pb2.CalibrationOptions, -) -> tuple[float, float]: - """Calculate min and max from statistics using calibration options. - - Args: - node_id: bytes of node id. - calib_opts: Calibration options used for calculating min and max. - - Returns: - (min_value, max_value): Min and max calculated using calib_opts. - - Raises: - ValueError: Unsupported calibration method is given. - """ - statistics: calib_stats_pb2.CalibrationStatistics = ( - pywrap_quantize_model.get_statistics_from_calibrator(node_id) + quantization_options_serialized=quant_opts.SerializeToString(), + signature_keys=list(quant_opts.signature_keys), + signature_def_map_serialized=_serialize_signature_def_map( + signature_def_map + ), + function_aliases=dict(function_aliases), + py_function_library=py_function_lib.PyFunctionLibrary(), ) - min_value, max_value = calibration_algorithm.get_min_max_value( - statistics, calib_opts - ) - return min_value, max_value - - -def _add_calibration_statistics( - graph_def: graph_pb2.GraphDef, - calib_opts: quant_opts_pb2.CalibrationOptions, -) -> None: - """Adds calibration statistics to the graph def. - - This function must be run after running the graph with a representative - dataset. Retrieves calibration statistics from the global calibrator and adds - them to the corresponding nodes as attributes. - - Args: - graph_def: GraphDef to add calibration statistics to. - calib_opts: Calibration options to calculate min and max. - """ - for function_def in graph_def.library.function: - for node_def in function_def.node_def: - if node_def.op != 'CustomAggregator': - continue - - node_id = node_def.attr['id'].s - try: - min_value, max_value = _get_min_max_from_calibrator(node_id, calib_opts) - pywrap_quantize_model.clear_data_from_calibrator(node_id) - - node_def.attr['min'].f = min_value - node_def.attr['max'].f = max_value - except ValueError: - logging.warning( - ( - 'CustomAggregator id "%s" from FunctionDef "%s" does not have ' - 'min or max values. Parts of this function are not quantized.' - ), - node_id.decode('utf-8'), - function_def.signature.name, - ) - - -def _enable_dump_tensor(graph_def: graph_pb2.GraphDef) -> None: - """Enable DumpTensor in the graph def. 
- - DumpTensor is disabled by default to avoid logging data during calibration. - This function is called after calibration to enable DumpTensor. - - Args: - graph_def: GraphDef to enable DumpTensor - """ - for function_def in graph_def.library.function: - for node_def in function_def.node_def: - if node_def.op != 'DumpTensor': - continue - - node_def.attr['enabled'].b = True - - -def _change_dump_tensor_file_name(graph_def: graph_pb2.GraphDef) -> None: - """Change file_name used by DumpTensor to quantized_tensor_data.pb. - - In whole model verify, DumpTensor in unquantized model uses file_name - unquantized_tensor_data.pb. - After unquantized dump model is created, this function allows quantized dump - model to use quantized_tensor_data.pb as file_name. - - Args: - graph_def: GraphDef to change file_name of DumpTensor - """ - for function_def in graph_def.library.function: - for node_def in function_def.node_def: - if node_def.op != 'DumpTensor': - continue - - node_def.attr['file_name'].s = 'quantized_tensor_data.pb'.encode('utf-8') - - -def _get_saver_def_or_none( - exported_model: exported_model_pb2.ExportedModel, -) -> Optional[saver_pb2.SaverDef]: - """Returns the SaverDef from ExportedModel, None otherwise. - - Args: - exported_model: ExportedModel to take the SaverDef from. - - Returns: - SaverDef instance if the field `saver_def` is set. None otherwise. - """ - if exported_model.HasField('saver_def'): - return exported_model.saver_def - return None def _run_static_range_ptq( @@ -766,134 +158,29 @@ def _run_static_range_ptq( Raises: ValueError if the graph doesn't contain a valid signature. """ - logging.info('Running post-training quantization pre-calibration step.') + logging.info('Running static-range post-training quantization.') loader = saved_model_loader.SavedModelLoader(src_saved_model_path) function_aliases = loader.get_meta_graph_def_from_tags( quant_opts.tags ).meta_info_def.function_aliases - exported_model_serialized = ( - pywrap_quantize_model.quantize_ptq_model_pre_calibration( - src_saved_model_path, - list(quant_opts.signature_keys), - set(quant_opts.tags), - quant_opts.SerializeToString(), - dict(function_aliases), - py_function_lib.PyFunctionLibrary(), - ) - ) - exported_model = exported_model_pb2.ExportedModel.FromString( - exported_model_serialized - ) - - graph_def = exported_model.graph_def - pre_calib_output_model_path = tempfile.mkdtemp() - save_model.save_model_v1( - graph_def, - pre_calib_output_model_path, - signature_def_map, - quant_opts.tags, - exported_model.init_node_name, - _get_saver_def_or_none(exported_model), - exported_model.checkpoint_dir, - exported_model.function_aliases, - asset_file_defs=exported_model.asset_file_defs, - ) - - _copy_assets(src_saved_model_path, pre_calib_output_model_path) - - # Uses the representative dataset to collect statistics for calibration. - # Handles the graph mode execution separately in case TF2 is disabled or - # eager execution is disabled. The min & max values are stored separately - # in a global CalibratorSingleton instance. - _run_graph_for_calibration( - pre_calib_output_model_path, - quant_opts.signature_keys, - quant_opts.tags, - representative_dataset, - quant_opts.force_graph_mode_calibration, - ) - - _add_calibration_statistics(graph_def, quant_opts.calibration_options) - - if quant_opts.HasField('debugger_options'): - # Since DumpTensor was disabled by default, we need to enable them. 
- _enable_dump_tensor(graph_def) - - if ( - quant_opts.debugger_options.debugger_type - == quant_opts_pb2.DebuggerOptions.DebuggerType.DEBUGGER_TYPE_WHOLE_MODEL - ): - # TODO: b/295139417 - Remove CustomAggregator op in unquantized dump model - # TODO: b/296916287 - Create a separate function for saving unquantized - # dump model - save_model.save_model_v1( - graph_def, - quant_opts.debugger_options.unquantized_dump_model_path, - signature_def_map, - quant_opts.tags, - exported_model.init_node_name, - _get_saver_def_or_none(exported_model), - exported_model.checkpoint_dir, - exported_model.function_aliases, - asset_file_defs=exported_model.asset_file_defs, - ) - - _copy_assets( - src_saved_model_path, - quant_opts.debugger_options.unquantized_dump_model_path, - ) - - _change_dump_tensor_file_name(graph_def) - - calibrated_model_path = tempfile.mkdtemp() - save_model.save_model_v1( - graph_def, - calibrated_model_path, - signature_def_map, - quant_opts.tags, - exported_model.init_node_name, - _get_saver_def_or_none(exported_model), - exported_model.checkpoint_dir, - asset_file_defs=exported_model.asset_file_defs, - ) - - _copy_assets(pre_calib_output_model_path, calibrated_model_path) - - logging.info('Running post-training quantization post-calibration step.') - exported_model_serialized = ( - pywrap_quantize_model.quantize_ptq_model_post_calibration( - calibrated_model_path, - list(quant_opts.signature_keys), - set(quant_opts.tags), - quant_opts.SerializeToString(), - dict(exported_model.function_aliases), - ) - ) - - exported_model = exported_model_pb2.ExportedModel.FromString( - exported_model_serialized - ) - - save_model.save_model_v1( - exported_model.graph_def, + signature_def_map_serialized = _serialize_signature_def_map(signature_def_map) + pywrap_quantize_model.quantize_ptq_static_range( + src_saved_model_path, dst_saved_model_path, - signature_def_map, - quant_opts.tags, - init_op_name=exported_model.init_node_name, - saver_def=_get_saver_def_or_none(exported_model), - checkpoint_dir=exported_model.checkpoint_dir, - function_aliases=exported_model.function_aliases, - asset_file_defs=exported_model.asset_file_defs, + quantization_options_serialized=quant_opts.SerializeToString(), + signature_keys=list(quant_opts.signature_keys), + signature_def_map_serialized=signature_def_map_serialized, + function_aliases=dict(function_aliases), + py_function_library=py_function_lib.PyFunctionLibrary(), + representative_dataset=representative_dataset, ) - _copy_assets(calibrated_model_path, dst_saved_model_path) - def _static_range_quantize( - saved_model_path: str, - output_directory: str, + src_saved_model_path: str, + dst_saved_model_path: str, quantization_options: _QuantizationOptions, representative_dataset: Optional[ repr_dataset.RepresentativeDatasetOrMapping @@ -907,10 +194,10 @@ def _static_range_quantize( model input, `representative_dataset` will be ignored. Args: - saved_model_path: Path to the saved model. When representative_dataset is - not provided, this should be a model trained with QAT. - output_directory: The path to save the output SavedModel. The directory will - be overwritten if not empty. + src_saved_model_path: Path to the saved model. When representative_dataset + is not provided, this should be a model trained with QAT. + dst_saved_model_path: The path to save the output SavedModel. The directory + will be overwritten if not empty. quantization_options: QuantizationOptions proto describing quantization related config. 
representative_dataset: a generator that returns a dictionary in {input_key: @@ -927,18 +214,18 @@ def _static_range_quantize( in the SavedModel. """ logging.info( - 'Running static range quantization on model: %s', saved_model_path + 'Running static range quantization on model: %s', src_saved_model_path ) logging.info('QuantizationOptions: \n%s', quantization_options) is_qat_saved_model_or_method_no_quantize = _is_qat_saved_model( - saved_model_path + src_saved_model_path ) or ( quantization_options.quantization_method.preset_method == _QuantizationMethod.METHOD_NO_QUANTIZE ) signature_def_map = save_model.get_signatures_from_saved_model( - saved_model_path, + src_saved_model_path, quantization_options.signature_keys, set(quantization_options.tags), ) @@ -961,34 +248,34 @@ def _static_range_quantize( if is_qat_saved_model_or_method_no_quantize: _run_static_range_qat( - saved_model_path, - output_directory, + src_saved_model_path, + dst_saved_model_path, quantization_options, signature_def_map, ) else: _run_static_range_ptq( - saved_model_path, - output_directory, + src_saved_model_path, + dst_saved_model_path, quantization_options, representative_dataset, signature_def_map, ) - return saved_model_load.load(output_directory) + return saved_model_load.load(dst_saved_model_path) def _dynamic_range_quantize( - saved_model_path: str, - output_directory: str, + src_saved_model_path: str, + dst_saved_model_path: str, quantization_options: _QuantizationOptions, ) -> autotrackable.AutoTrackable: """Quantizes the given SavedModel via post-training dynamic range quantization. Args: - saved_model_path: Path to the saved model. - output_directory: The path to save the output SavedModel. The directory will - be overwritten if not empty. + src_saved_model_path: Path to the saved model. + dst_saved_model_path: The path to save the output SavedModel. The directory + will be overwritten if not empty. quantization_options: QuantizationOptions proto describing quantization related config. @@ -999,68 +286,56 @@ def _dynamic_range_quantize( ValueError: when the model is QAT model. """ mode_str = 'dynamic-range quantization' - if _is_qat_saved_model(saved_model_path): + if _is_qat_saved_model(src_saved_model_path): raise ValueError( 'The models trained with quantization-aware training (QAT) is not ' 'supported for %s.' % mode_str ) logging.info( - 'Running post-training %s on model: %s', mode_str, saved_model_path + 'Running post-training %s on model: %s', mode_str, src_saved_model_path ) logging.info('QuantizationOptions: \n%s', quantization_options) - loader = saved_model_loader.SavedModelLoader(saved_model_path) + loader = saved_model_loader.SavedModelLoader(src_saved_model_path) function_aliases = loader.get_meta_graph_def_from_tags( quantization_options.tags ).meta_info_def.function_aliases - # Apply post-training dynamic range quantization to the model. 
- exported_model_serialized = pywrap_quantize_model.quantize_ptq_dynamic_range( - saved_model_path, - list(quantization_options.signature_keys), - set(quantization_options.tags), - quantization_options.SerializeToString(), - dict(function_aliases), - ) - - exported_model = exported_model_pb2.ExportedModel.FromString( - exported_model_serialized - ) signature_def_map = save_model.get_signatures_from_saved_model( - saved_model_path, + src_saved_model_path, quantization_options.signature_keys, quantization_options.tags, ) - save_model.save_model_v1( - exported_model.graph_def, - output_directory, - signature_def_map, - quantization_options.tags, - init_op_name=exported_model.init_node_name, - saver_def=_get_saver_def_or_none(exported_model), - checkpoint_dir=exported_model.checkpoint_dir, - function_aliases=exported_model.function_aliases, - asset_file_defs=exported_model.asset_file_defs, + # Apply post-training dynamic range quantization to the model. + pywrap_quantize_model.quantize_ptq_dynamic_range( + src_saved_model_path, + dst_saved_model_path, + quantization_options_serialized=quantization_options.SerializeToString(), + signature_keys=list(quantization_options.signature_keys), + signature_def_map_serialized=_serialize_signature_def_map( + signature_def_map + ), + function_aliases=dict(function_aliases), + py_function_library=py_function_lib.PyFunctionLibrary(), ) - _copy_assets(saved_model_path, output_directory) - return saved_model_load.load(output_directory) + return saved_model_load.load(dst_saved_model_path) def _weight_only_quantize( - saved_model_path: str, - output_directory: str, + src_saved_model_path: str, + dst_saved_model_path: str, quantization_options: quant_opts_pb2.QuantizationOptions, ) -> autotrackable.AutoTrackable: """Quantizes the given SavedModel via weight-only quantization. Args: - saved_model_path: Path to the saved model. - output_directory: The path to save the output SavedModel. The directory will - be overwritten if not empty. + src_saved_model_path: Path to the saved model. + dst_saved_model_path: The path to save the output SavedModel. The directory + will be overwritten if not empty. quantization_options: QuantizationOptions proto describing quantization related config. @@ -1073,52 +348,41 @@ def _weight_only_quantize( mode_str = 'weight-only quantization' # QAT weight-only is not supported yet. - if _is_qat_saved_model(saved_model_path): + if _is_qat_saved_model(src_saved_model_path): raise ValueError( 'The models trained with quantization-aware training (QAT) is not ' 'supported for %s.' 
% mode_str ) logging.info( - 'Running post-training %s on model: %s', mode_str, saved_model_path + 'Running post-training %s on model: %s', mode_str, src_saved_model_path ) logging.info('QuantizationOptions: \n%s', quantization_options) - loader = saved_model_loader.SavedModelLoader(saved_model_path) + loader = saved_model_loader.SavedModelLoader(src_saved_model_path) function_aliases = loader.get_meta_graph_def_from_tags( quantization_options.tags ).meta_info_def.function_aliases - exported_model_serialized = pywrap_quantize_model.quantize_weight_only( - saved_model_path, - quantization_options.SerializeToString(), - dict(function_aliases), - ) - - exported_model = exported_model_pb2.ExportedModel.FromString( - exported_model_serialized - ) signature_def_map = save_model.get_signatures_from_saved_model( - saved_model_path, + src_saved_model_path, list(quantization_options.signature_keys), set(quantization_options.tags), ) - save_model.save_model_v1( - exported_model.graph_def, - output_directory, - signature_def_map, - quantization_options.tags, - init_op_name=exported_model.init_node_name, - saver_def=_get_saver_def_or_none(exported_model), - checkpoint_dir=exported_model.checkpoint_dir, - function_aliases=exported_model.function_aliases, - asset_file_defs=exported_model.asset_file_defs, + pywrap_quantize_model.quantize_weight_only( + src_saved_model_path, + dst_saved_model_path, + quantization_options_serialized=quantization_options.SerializeToString(), + signature_def_map_serialized=_serialize_signature_def_map( + signature_def_map + ), + function_aliases=dict(function_aliases), + py_function_library=py_function_lib.PyFunctionLibrary(), ) - _copy_assets(saved_model_path, output_directory) - return saved_model_load.load(output_directory) + return saved_model_load.load(dst_saved_model_path) def _verify_output_dir(output_dir: Optional[str], overwrite: bool) -> None: @@ -1356,28 +620,34 @@ def _populate_quantization_options_default_values( 'Legacy weight-only is deprecated. Use weight-only quantization method.' ) + # Converter assumes options are specified. So set SRQ explicitly. + if ( + quantization_options.quantization_method.preset_method + == _PresetMethod.METHOD_UNSPECIFIED + ): + logging.debug( + '"preset_method" for QuantizationMethod is not specified.' + 'Static range quantization is used by default.' + ) + quantization_options.quantization_method.preset_method = ( + _PresetMethod.METHOD_STATIC_RANGE_INT8 + ) + # Check default quantization option values for weight-only quantization. # TODO(b/242805842): Find good minimum_elements_for_weights number for server. # please also update default value in tflite converter: # tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc;l=201 - if ( - quantization_options.quantization_method.preset_method - == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8 - ) or ( - quantization_options.quantization_method.preset_method - == _PresetMethod.METHOD_DYNAMIC_RANGE_INT8 - ): - if quantization_options.min_num_elements_for_weights == 0: - quantization_options.min_num_elements_for_weights = ( - _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS - ) - logging.warning( - ( - 'QuantizationOptions.min_num_elements_for_weights is not set (0).' - ' Setting to the default value: %d.' 
- ), - _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS, - ) + if quantization_options.min_num_elements_for_weights == 0: + quantization_options.min_num_elements_for_weights = ( + _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS + ) + logging.warning( + ( + 'QuantizationOptions.min_num_elements_for_weights is not set (0).' + ' Setting to the default value: %d.' + ), + _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS, + ) # TODO: b/307900054 - Set the per-channel quantization by default. if quantization_options.enable_per_channel_quantization and not ( @@ -1417,19 +687,6 @@ def _populate_quantization_options_default_values( ' quantization via TF Quantizer.' ) - # Converter assumes options are specified. So set SRQ explicitly. - if ( - quantization_options.quantization_method.preset_method - == _PresetMethod.METHOD_UNSPECIFIED - ): - logging.debug( - '"preset_method" for QuantizationMethod is not specified.' - 'Static range quantization is used by default.' - ) - quantization_options.quantization_method.preset_method = ( - _PresetMethod.METHOD_STATIC_RANGE_INT8 - ) - if quantization_options.HasField('debugger_options'): # Set `force_graph_mode_calibration` to True to avoid skipping op execution, # which are not connected to return ops, during calibration execution. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset.py index f3e8cc9d6bcb50..6fc618b5f92646 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset.py @@ -18,7 +18,10 @@ import os from typing import Iterable, Mapping, Optional, Union +import numpy as np + from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import readers from tensorflow.python.eager import context @@ -302,3 +305,40 @@ def get_num_samples(repr_ds: RepresentativeDataset) -> Optional[int]: return None else: return None + + +def create_feed_dict_from_input_data( + input_data: RepresentativeSample, + signature_def: meta_graph_pb2.SignatureDef, +) -> Mapping[str, np.ndarray]: + """Constructs a feed_dict from input data. + + Note: This function should only be used in graph mode. + + This is a helper function that converts an 'input key -> input value' mapping + to a feed dict. A feed dict is an 'input tensor name -> input value' mapping + and can be directly passed to the `feed_dict` argument of `sess.run()`. + + Args: + input_data: Input key -> input value mapping. The input keys should match + the input keys of `signature_def`. + signature_def: A SignatureDef representing the function that `input_data` is + an input to. + + Returns: + Feed dict, which is intended to be used as input for `sess.run`. It is + essentially a mapping: input tensor name -> input value. Note that the input + value in the feed dict is not a `Tensor`. + """ + feed_dict = {} + for input_key, input_value in input_data.items(): + input_tensor_name = signature_def.inputs[input_key].name + + value = input_value + if isinstance(input_value, core.Tensor): + # Take the data out of the tensor. 
+ value = input_value.eval() + + feed_dict[input_tensor_name] = value + + return feed_dict diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset_test.py index b5fa11c43bbc5b..f9e05be36eb5af 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset_test.py @@ -18,7 +18,9 @@ import numpy as np from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset +from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.client import session +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.platform import test @@ -224,6 +226,57 @@ def __len__(self): self.assertIsNone(repr_dataset.get_num_samples(LenRaisingError())) + @test_util.deprecated_graph_mode_only + def test_create_feed_dict_from_input_data(self): + signature_def = meta_graph_pb2.SignatureDef( + inputs={'input_tensor': meta_graph_pb2.TensorInfo(name='input:0')} + ) + rng = np.random.default_rng(seed=14) + + input_tensor_value = rng.random(size=(2, 2)) + sample = {'input_tensor': input_tensor_value} + + feed_dict = repr_dataset.create_feed_dict_from_input_data( + sample, signature_def + ) + + self.assertLen(feed_dict, 1) + self.assertIn('input:0', feed_dict) + self.assertAllEqual(feed_dict['input:0'], input_tensor_value) + + @test_util.deprecated_graph_mode_only + def test_create_feed_dict_from_input_data_core_tensors(self): + signature_def = meta_graph_pb2.SignatureDef( + inputs={'input_tensor': meta_graph_pb2.TensorInfo(name='input:0')} + ) + + with self.session(): + input_tensor = constant_op.constant([1, 2, 3, 4, 5, 6]) + sample = {'input_tensor': input_tensor} + + feed_dict = repr_dataset.create_feed_dict_from_input_data( + sample, signature_def + ) + input_tensor_data = input_tensor.eval() + + self.assertLen(feed_dict, 1) + self.assertIn('input:0', feed_dict) + self.assertIsInstance(feed_dict['input:0'], np.ndarray) + self.assertAllEqual(feed_dict['input:0'], input_tensor_data) + + @test_util.deprecated_graph_mode_only + def test_create_feed_dict_from_input_data_empty(self): + signature_def = meta_graph_pb2.SignatureDef( + inputs={'input_tensor': meta_graph_pb2.TensorInfo(name='input:0')} + ) + + sample = {} + feed_dict = repr_dataset.create_feed_dict_from_input_data( + sample, signature_def + ) + + self.assertEmpty(feed_dict) + class RepresentativeDatasetSaverTest(test.TestCase): """Test cases for RepresentativeDatasetSaver.""" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h index 669415a1aac078..a7beffd826a083 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h @@ -16,102 +16,131 @@ limitations under the License. 
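
(Note on the representative-dataset helper added above.) create_feed_dict_from_input_data resolves each input key through the SignatureDef to the corresponding tensor name and unwraps core.Tensor values with eval(), so it is graph-mode only. A minimal usage sketch in Python, mirroring the unit tests above; the signature key 'input_tensor' and tensor name 'input:0' are placeholders:

import numpy as np

from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset
from tensorflow.core.protobuf import meta_graph_pb2

# Hypothetical single-input signature whose input maps to the tensor 'input:0'.
signature_def = meta_graph_pb2.SignatureDef(
    inputs={'input_tensor': meta_graph_pb2.TensorInfo(name='input:0')}
)
sample = {'input_tensor': np.random.default_rng(0).random(size=(2, 2))}

# Yields {'input:0': <np.ndarray>}, ready to be passed as sess.run(..., feed_dict=...).
feed_dict = repr_dataset.create_feed_dict_from_input_data(sample, signature_def)
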
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_TYPE_CASTERS_H_ #include +#include #include +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "pybind11/cast.h" // from @pybind11 #include "pybind11/detail/common.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep namespace pybind11::detail { namespace internal { -// Serializes an ExportedModel. Raises python ValueError if serialization fails. -std::string Serialize( - const tensorflow::quantization::ExportedModel& exported_model) { - const std::string exported_model_serialized = - exported_model.SerializeAsString(); +// Serializes a protobuf object. Raises python ValueError if serialization +// fails. +inline std::string Serialize(const tsl::protobuf::Message& protobuf_object) { + const std::string serialized = protobuf_object.SerializeAsString(); // Empty string means it failed to serialize the protobuf with an error. See // the docstring for SerializeAsString for details. - if (exported_model_serialized.empty()) { - throw py::value_error("Failed to serialize ExportedModel."); + if (serialized.empty()) { + // Show the name of the protobuf message type to provide more information + // and easier debugging. + const std::string descriptor_name = + protobuf_object.GetDescriptor() == nullptr + ? "unknown" + : protobuf_object.GetDescriptor()->full_name(); + throw py::value_error(absl::StrFormat( + "Failed to serialize protobuf object: %s.", descriptor_name)); } - return exported_model_serialized; + return serialized; } -} // namespace internal - -// Handles `ExportedModel` (c++) <-> `bytes` (python) conversion. The `bytes` -// object in the python layer is a serialization of `ExportedModel`. +// Handles `ProtoT` (c++) <-> `bytes` (python) conversion. The `bytes` +// object in the python layer is a serialization of `ProtoT`. // -// See https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html for -// further details on how custom type conversions work for pybind11. -template <> -struct type_caster { +// The caller of c++ interfaces should make sure to pass valid serialized +// `ProtoT` objects as arguments. Failing to do so results in raising a +// `ValueError`. Similarly, the python implementation of a c++ virtual member +// function that return an `ProtoT` should return a valid serialized `ProtoT`. +// +// See https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html +template >> +struct SerializedProtobufCaster { public: - PYBIND11_TYPE_CASTER(tensorflow::quantization::ExportedModel, - const_name("ExportedModel")); + PYBIND11_TYPE_CASTER(ProtoT, const_name()); - // Loads an `ExportedModel` instance from a python `bytes` object (`src`). + // Loads an `ProtoT` instance from a python `bytes` object (`src`). bool load(handle src, const bool convert) { auto caster = make_caster(); // Make sure the user passed a valid python string. 
- if (!caster.load(src, convert)) { - return false; - } + if (!caster.load(src, convert)) return false; - const absl::string_view exported_model_serialized = + const absl::string_view serialized_proto = cast_op(std::move(caster)); // NOLINTNEXTLINE: Explicit std::string conversion required for OSS. - return value.ParseFromString(std::string(exported_model_serialized)); + return value.ParseFromString(std::string(serialized_proto)); } - // Constructs a `bytes` object after serializing `src`. - static handle cast(tensorflow::quantization::ExportedModel&& src, - return_value_policy policy, handle parent) { + // Constructs a `bytes` object by serializing `src`. + static handle cast(ProtoT&& src, return_value_policy policy, handle parent) { // release() prevents the reference count from decreasing upon the // destruction of py::bytes and returns a raw python object handle. - return py::bytes(internal::Serialize(src)).release(); + return py::bytes(Serialize(src)).release(); } - // Constructs a `bytes` object after serializing `src`. - static handle cast(const tensorflow::quantization::ExportedModel& src, - return_value_policy policy, handle parent) { + // Constructs a `bytes` object by serializing `src`. + static handle cast(const ProtoT& src, return_value_policy policy, + handle parent) { // release() prevents the reference count from decreasing upon the // destruction of py::bytes and returns a raw python object handle. - return py::bytes(internal::Serialize(src)).release(); + return py::bytes(Serialize(src)).release(); } }; -// Python -> cpp conversion for `QuantizationOptions`. Accepts a serialized -// protobuf string and deserializes into an instance of `QuantizationOptions`. +} // namespace internal + +// The following explicit specializations of protobuf `type_caster`s for +// specific protobuf message types are there to have higher priority over those +// defined in `native_proto_caster.h` during the resolution process. This is +// because the type casters in `native_proto_caster.h`, which allow seamlessly +// exchanging protobuf messages across c++-python boundaries, potentially +// without serialization, fail in the open-source environment. +// Explicitly-specialized type casters for serialized protobufs are added on an +// on-demand basis for quantization library. +// TODO: b/308532051 - Make `native_proto_caster.h` work in the open-source +// environment. + template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(tensorflow::quantization::QuantizationOptions, - const_name("QuantizationOptions")); +struct type_caster + : public internal::SerializedProtobufCaster< + tensorflow::quantization::ExportedModel> {}; - bool load(handle src, const bool convert) { - auto caster = make_caster(); - // The user should have passed a valid python string. - if (!caster.load(src, convert)) { - return false; - } +template <> +struct type_caster + : public internal::SerializedProtobufCaster< + tensorflow::quantization::QuantizationOptions> {}; - const absl::string_view quantization_opts_serialized = - cast_op(std::move(caster)); +template <> +struct type_caster + : public internal::SerializedProtobufCaster< + tensorflow::quantization::CalibrationOptions> {}; - // NOLINTNEXTLINE: Explicit std::string conversion required for OSS. 
- return value.ParseFromString(std::string(quantization_opts_serialized)); - } -}; +template <> +struct type_caster + : public internal::SerializedProtobufCaster {}; + +template <> +struct type_caster + : public internal::SerializedProtobufCaster {}; + +template <> +struct type_caster + : public internal::SerializedProtobufCaster< + tensorflow::calibrator::CalibrationStatistics> {}; } // namespace pybind11::detail diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc new file mode 100644 index 00000000000000..b957ffe469a004 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc @@ -0,0 +1,75 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h" + +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" +#include "tensorflow/core/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace quantization { + +// Unfreezes constants into variables and saves them to a checkpoint files under +// `checkpoint_dir`. `checkpoint_dir` will be created within this function. It +// will return a non-OK status if it already exists or permission is denied. +// TODO(b/261652258): Make sure this works for when there are non-frozen +// variables in the model. 
+absl::Status UnfreezeConstantsAndSaveVariables( + const absl::string_view checkpoint_dir, mlir::MLIRContext &ctx, + mlir::ModuleOp module_op) { + TF_RETURN_IF_ERROR(RunPasses( + /*name=*/kTfQuantConstantUnfreezingStepName, /*add_passes_func=*/ + [](mlir::PassManager &pm) { + pm.addPass(mlir::quant::CreateUnfreezeConstantsPass()); + }, + ctx, module_op)); + + if (const tsl::Status create_dir_status = + Env::Default()->CreateDir(std::string(checkpoint_dir)); + !create_dir_status.ok()) { + LOG(ERROR) << "Failed to create checkpoint directory at: " + << checkpoint_dir; + return create_dir_status; + } + + TF_ASSIGN_OR_RETURN(const auto unused_variable_names, + SaveVariablesToCheckpoint(checkpoint_dir, module_op)); + + return RunPasses( + /*name=*/kTfQuantInsertRestoreOpStepName, + /*add_passes_func=*/ + [](mlir::PassManager &pm) { + pm.addPass(mlir::quant::CreateInsertRestoreOpPass()); + pm.addPass(mlir::quant::CreateInsertSaveOpPass()); + // Initialization by `tf.ConstOp` is no longer required as there is + // a `tf.RestoreV2Op` now. + pm.addPass( + mlir::quant::CreateRemoveVariableInitializationByConstPass()); + }, + ctx, module_op); +} +} // namespace quantization +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h new file mode 100644 index 00000000000000..3086d705f315b7 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_UNFREEZE_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_UNFREEZE_CONSTANTS_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace tensorflow { +namespace quantization { + +inline constexpr absl::string_view kTfQuantConstantUnfreezingStepName = + "tf_quant_constant_unfreezing"; +inline constexpr absl::string_view kTfQuantInsertRestoreOpStepName = + "tf_quant_insert_restore_op"; + +absl::Status UnfreezeConstantsAndSaveVariables(absl::string_view checkpoint_dir, + mlir::MLIRContext &ctx, + mlir::ModuleOp module_op); + +} // namespace quantization +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_UNFREEZE_CONSTANTS_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc index 0b9cdc09ca5b93..4825d316f6e691 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc @@ -17,6 +17,7 @@ limitations under the License. 
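
(Stepping back to the Python entry points changed earlier in this diff.) quantize_ptq_dynamic_range and quantize_weight_only no longer hand a serialized ExportedModel back for Python to re-save; they take a source and a destination SavedModel path and write the quantized model themselves, with every protobuf argument crossing as serialized bytes and Python callbacks supplied via py_function_lib.PyFunctionLibrary(). Relatedly, _populate_quantization_options_default_values now sets an unspecified preset_method to static-range INT8 up front and applies the min_num_elements_for_weights default whenever the field is unset, not only for the weight-only and dynamic-range presets. A sketch of the new call shape follows; _serialize_signature_def_map is not shown in this hunk, so the bytes-per-SignatureDef mapping below is an assumption, and the paths are placeholders:

signature_def_map = save_model.get_signatures_from_saved_model(
    src_saved_model_path,
    quantization_options.signature_keys,
    quantization_options.tags,
)

# Assumed behavior of _serialize_signature_def_map: {signature_key: serialized SignatureDef}.
serialized_map = {
    key: sig_def.SerializeToString() for key, sig_def in signature_def_map.items()
}

pywrap_quantize_model.quantize_ptq_dynamic_range(
    src_saved_model_path,   # e.g. '/tmp/model' (placeholder)
    dst_saved_model_path,   # e.g. '/tmp/model_quantized' (placeholder)
    quantization_options_serialized=quantization_options.SerializeToString(),
    signature_keys=list(quantization_options.signature_keys),
    signature_def_map_serialized=serialized_map,
    function_aliases=dict(function_aliases),
    py_function_library=py_function_lib.PyFunctionLibrary(),
)
quantized_model = saved_model_load.load(dst_saved_model_path)
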
#include #include "absl/strings/string_view.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project @@ -36,7 +37,7 @@ using ::tensorflow::quantization::QuantizationOptions; // Currently server cannot handle UniformQuantizedTypes. Instead, unpack // quantized ops to primitive StableHLO ops. We currently go through a // StableHLO <-> MHLO roundtrip to utilize the MHLOQuantToInt pass. -void AddStablehloQuantToIntPasses(mlir::PassManager &pm) { +void AddStablehloQuantToIntPasses(mlir::OpPassManager &pm) { pm.addPass(mlir::createInlinerPass()); // StableHLO -> MHLO legalization. pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); @@ -50,32 +51,50 @@ void AddStablehloQuantToIntPasses(mlir::PassManager &pm) { } void AddStaticRangeQuantizationPass( - mlir::PassManager &pm, const QuantizationOptions &quantization_options, + mlir::OpPassManager &pm, std::optional mlir_dump_file_prefix) { pm.addPass(mlir::quant::stablehlo::createQuantizeCompositeFunctionsPass()); } -void AddConvertTpuToCpuModelPasses(mlir::PassManager &pm) { +void AddConvertTpuToCpuModelPasses(mlir::OpPassManager &pm) { pm.addPass(mlir::quant::CreateConvertTpuModelToCpuPass()); pm.addPass(mlir::createInlinerPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::quant::CreateCastBf16OpsToF32Pass()); } +// Legalizes shape/tensor/arith dialect ops to StableHLO for handling dynamic +// shapes, by going through a round-trip to MHLO. +void AddShapeLegalizationPasses(mlir::OpPassManager &pm) { + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + pm.addNestedPass( + mlir::mhlo::createShapeLegalizeToHloPass(/*legalizeConstraints=*/true)); + // The following 2 passes are used to clean up the spurious UnrealizedCast ops + // and shape.assuming regions leftover from the ShapeLegalizeToHlo pass. See + // pass definition for details. + pm.addPass(mlir::createReconcileUnrealizedCastsPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); +} + // NOMUTANTS -- Add tests for individual passes with migration below. // Serializes the StableHLO module into a tf.XlaCallModuleOp for compatibility // with passes that expect TF format. This also allows the StableHLO ops to be // exported as a TF SavedModel. -void AddCallModuleSerializationPasses(mlir::PassManager &pm) { +void AddCallModuleSerializationPasses(mlir::OpPassManager &pm) { + AddShapeLegalizationPasses(pm); pm.addPass( mlir::quant::stablehlo:: createReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass()); + // ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass may create + // duplicate constants. Add canonicalizer to deduplicate. 
+ pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::TF::CreateXlaCallModuleSerializationPass()); } } // namespace void AddQuantizeQatPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options, + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, std::optional mlir_dump_file_prefix) { pm.addNestedPass( mlir::quant::CreateConvertFakeQuantToQdqPass()); @@ -123,7 +142,7 @@ void AddQuantizeQatPasses( } void AddQuantizePtqDynamicRangePasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options, + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, std::optional mlir_dump_file_prefix) { pm.addNestedPass( mlir::TF::CreateUnrollBatchMatMulPassPass()); @@ -167,7 +186,7 @@ void AddQuantizePtqDynamicRangePasses( } void AddQuantizePtqPreCalibrationPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options) { + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options) { if (quantization_options.op_set() == OpSet::UNIFORM_QUANTIZED) { pm.addNestedPass( mlir::TF::CreateUnrollBatchMatMulPassPass()); @@ -195,7 +214,7 @@ void AddQuantizePtqPreCalibrationPasses( } void AddQuantizePtqPostCalibrationPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options, + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, std::optional mlir_dump_file_prefix) { pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); @@ -228,14 +247,12 @@ void AddQuantizePtqPostCalibrationPasses( } // StableHLO Quantization passes that are ran if StableHLO opset is selected. -// TODO: b/298581932 - Add tests for passes below once migration is complete. void AddQuantizePtqPreCalibrationStablehloPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options) { + mlir::OpPassManager &pm, const CalibrationOptions &calibration_options) { pm.addPass( mlir::quant::stablehlo::createLiftQuantizableSpotsAsFunctionsPass()); pm.addNestedPass( - mlir::quant::CreateInsertCustomAggregationOpsPass( - quantization_options.calibration_options())); + mlir::quant::CreateInsertCustomAggregationOpsPass(calibration_options)); pm.addPass(mlir::quant::CreateIssueIDsOfCustomAggregationOpsPass()); // NOMUTANTS -- Add tests after all passes in function below are migrated. // StableHLO Quantizer currently uses TF's calibration passes. Serialize @@ -243,25 +260,29 @@ void AddQuantizePtqPreCalibrationStablehloPasses( AddCallModuleSerializationPasses(pm); } -// TODO: b/298581932 - Migrate and add passes below. void AddQuantizePtqPostCalibrationStablehloPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options, + mlir::OpPassManager &pm, std::optional mlir_dump_file_prefix) { // Deserializes the StableHLO module embedded in tf.XlaCallModule and lifts // the StableHLO functions to the top level module. This is needed for // StableHLO quantization. + // + // Calibration may result in partial shape information loss. Add this pass to + // populate shape information based on the known information. 
+ pm.addPass(mlir::quant::stablehlo::createPopulateShapePass()); pm.addPass(mlir::TF::CreateXlaCallModuleDeserializationPass()); pm.addPass(mlir::quant::stablehlo::createRestoreFunctionNamePass()); + pm.addPass(mlir::quant::stablehlo::createUnwrapXlaCallModuleOpPass()); + pm.addPass(mlir::createSymbolDCEPass()); pm.addNestedPass( mlir::quant::CreateConvertCustomAggregationOpToQuantStatsPass()); - AddStaticRangeQuantizationPass(pm, quantization_options, - mlir_dump_file_prefix); + AddStaticRangeQuantizationPass(pm, mlir_dump_file_prefix); AddStablehloQuantToIntPasses(pm); AddCallModuleSerializationPasses(pm); } void AddQuantizeWeightOnlyPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options, + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, std::optional mlir_dump_file_prefix) { pm.addPass(mlir::TF::CreateTFShapeInferencePass()); // Add PrepareLiftingPass to utilize its functionalities like folding batch diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h index 3aef23b5667d51..5d757b4c944441 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h @@ -26,35 +26,35 @@ namespace quantization { // mlir_dump_file_prefix is an optional field that is used for debugging to save // mlir dump files. -void AddQuantizeQatPasses(mlir::PassManager &pm, +void AddQuantizeQatPasses(mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, std::optional mlir_dump_file_prefix = std::nullopt); void AddQuantizePtqDynamicRangePasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options, + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, std::optional mlir_dump_file_prefix = std::nullopt); void AddQuantizeWeightOnlyPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options, + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, std::optional mlir_dump_file_prefix = std::nullopt); void AddQuantizePtqPreCalibrationPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options); + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options); void AddQuantizePtqPostCalibrationPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options, + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, std::optional mlir_dump_file_prefix = std::nullopt); // StableHLO Quantization passes that are ran if StableHLO opset is selected. 
void AddQuantizePtqPreCalibrationStablehloPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options); + mlir::OpPassManager &pm, const CalibrationOptions &quantization_options); void AddQuantizePtqPostCalibrationStablehloPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options, + mlir::OpPassManager &pm, std::optional mlir_dump_file_prefix = std::nullopt); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir index fa747357169f55..b8ed5d5f361d36 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_custom_aggregation_ops.mlir @@ -1,62 +1,60 @@ -// RUN: tf-quant-opt %s -quant-insert-custom-aggregation-ops='test-case=MIN_MAX' | FileCheck --check-prefix=MIN-MAX-CHECK %s -// RUN: tf-quant-opt %s -quant-insert-custom-aggregation-ops='test-case=AVERAGE_MIN_MAX' | FileCheck --check-prefix=AVERAGE-MIN-MAX-CHECK %s -// RUN: tf-quant-opt %s -quant-insert-custom-aggregation-ops='test-case=HISTOGRAM_PERCENTILE' | FileCheck --check-prefix=HISTOGRAM-PERCENTILE-CHECK %s -// RUN: tf-quant-opt %s -quant-insert-custom-aggregation-ops='test-case=HISTOGRAM_MSE_BRUTEFORCE' | FileCheck --check-prefix=HISTOGRAM-MSE-BRUTEFORCE-CHECK %s -// RUN: tf-quant-opt %s -quant-insert-custom-aggregation-ops='test-case=HISTOGRAM_MSE_MAX_FREQUENCY' | FileCheck --check-prefix=HISTOGRAM-MSE-MAX-FREQUENCY-CHECK %s -// RUN: tf-quant-opt %s -quant-insert-custom-aggregation-ops='test-case=HISTOGRAM_MSE_SYMMETRIC' | FileCheck --check-prefix=HISTOGRAM-MSE-SYMMETRIC-CHECK %s +// RUN: tf-quant-opt %s -quant-insert-custom-aggregation-ops='test-case=MIN_MAX' -split-input-file | FileCheck --check-prefix=MIN-MAX-CHECK %s +// RUN: tf-quant-opt %s -quant-insert-custom-aggregation-ops='test-case=AVERAGE_MIN_MAX' -split-input-file | FileCheck --check-prefix=AVERAGE-MIN-MAX-CHECK %s +// RUN: tf-quant-opt %s -quant-insert-custom-aggregation-ops='test-case=HISTOGRAM_PERCENTILE' -split-input-file | FileCheck --check-prefix=HISTOGRAM-PERCENTILE-CHECK %s +// RUN: tf-quant-opt %s -quant-insert-custom-aggregation-ops='test-case=HISTOGRAM_MSE_BRUTEFORCE' -split-input-file | FileCheck --check-prefix=HISTOGRAM-MSE-BRUTEFORCE-CHECK %s +// RUN: tf-quant-opt %s -quant-insert-custom-aggregation-ops='test-case=HISTOGRAM_MSE_MAX_FREQUENCY' -split-input-file | FileCheck --check-prefix=HISTOGRAM-MSE-MAX-FREQUENCY-CHECK %s +// RUN: tf-quant-opt %s -quant-insert-custom-aggregation-ops='test-case=HISTOGRAM_MSE_SYMMETRIC' -split-input-file | FileCheck --check-prefix=HISTOGRAM-MSE-SYMMETRIC-CHECK %s module { - func.func @add_custom_ops(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %add = "tf.AddV2"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - func.return %add : tensor<*xf32> + func.func @wrap_composite_func(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.PartitionedCall"(%arg0, %arg1) <{f = @composite_conv2d_with_relu6_fn}> {_tfl_quant_trait = "fully_quantizable"} + : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> } - func.func @no_custom_ops_on_non_f32_type(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { - %add = "tf.AddV2"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - func.return %add : tensor<*xi32> + func.func 
@no_composite_func(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %add = "tf.AddV2"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + func.return %add : tensor<*xf32> } - func.func @composite_conv2d_with_bias_and_relu6_fn(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + func.func @composite_conv2d_with_relu6_fn(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> - %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> - func.return %2 : tensor<*xf32> + %1 = "tf.Relu6"(%0) : (tensor<*xf32>) -> tensor<*xf32> + func.return %1 : tensor<*xf32> } } // CalibrationOptions(calibration_method=CALIBRATION_METHOD_MIN_MAX) -// MIN-MAX-CHECK: func @add_custom_ops +// MIN-MAX-CHECK: func @wrap_composite_func // MIN-MAX-CHECK-NEXT: [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> // MIN-MAX-CHECK-NEXT: [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> -// MIN-MAX-CHECK-NEXT: [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// MIN-MAX-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) // MIN-MAX-CHECK-NEXT: [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> // MIN-MAX-CHECK-NEXT: return [[res]] : tensor<*xf32> -// MIN-MAX-CHECK: func @no_custom_ops_on_non_f32_type +// MIN-MAX-CHECK: func @no_composite_func // MIN-MAX-CHECK-NEXT: "tf.AddV2" // MIN-MAX-CHECK-NEXT: return -// MIN-MAX-CHECK: func @composite_conv2d_with_bias_and_relu6_fn +// MIN-MAX-CHECK: func @composite_conv2d_with_relu6_fn // MIN-MAX-CHECK-NEXT: "tf.Conv2D" -// MIN-MAX-CHECK-NEXT: "tf.BiasAdd" // MIN-MAX-CHECK-NEXT: "tf.Relu6" // MIN-MAX-CHECK-NEXT: return // CalibrationOptions(calibration_method=CALIBRATION_METHOD_AVERAGE_MIN_MAX) -// AVERAGE-MIN-MAX-CHECK: func @add_custom_ops +// AVERAGE-MIN-MAX-CHECK: func @wrap_composite_func // AVERAGE-MIN-MAX-CHECK-NEXT: [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> // AVERAGE-MIN-MAX-CHECK-NEXT: [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> -// AVERAGE-MIN-MAX-CHECK-NEXT: [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// AVERAGE-MIN-MAX-CHECK-NEXT: 
[[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) // AVERAGE-MIN-MAX-CHECK-NEXT: [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 2 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> // AVERAGE-MIN-MAX-CHECK-NEXT: return [[res]] : tensor<*xf32> -// AVERAGE-MIN-MAX-CHECK: func @no_custom_ops_on_non_f32_type +// AVERAGE-MIN-MAX-CHECK: func @no_composite_func // AVERAGE-MIN-MAX-CHECK-NEXT: "tf.AddV2" // AVERAGE-MIN-MAX-CHECK-NEXT: return -// AVERAGE-MIN-MAX-CHECK: func @composite_conv2d_with_bias_and_relu6_fn +// AVERAGE-MIN-MAX-CHECK: func @composite_conv2d_with_relu6_fn // AVERAGE-MIN-MAX-CHECK-NEXT: "tf.Conv2D" -// AVERAGE-MIN-MAX-CHECK-NEXT: "tf.BiasAdd" // AVERAGE-MIN-MAX-CHECK-NEXT: "tf.Relu6" // AVERAGE-MIN-MAX-CHECK-NEXT: return @@ -64,20 +62,19 @@ module { // calibration_method=CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, // calibration_parameters=CalibrationParameters(initial_num_bins=256, min_percentile=0.001, max_percentile=99.999) // ) -// HISTOGRAM-PERCENTILE-CHECK: func @add_custom_ops +// HISTOGRAM-PERCENTILE-CHECK: func @wrap_composite_func // HISTOGRAM-PERCENTILE-CHECK-NEXT: [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> tensor<*xf32> // HISTOGRAM-PERCENTILE-CHECK-NEXT: [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> tensor<*xf32> -// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) // HISTOGRAM-PERCENTILE-CHECK-NEXT: [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 3 : i32, initial_num_bins = 256 : i32, max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32} : (tensor<*xf32>) -> tensor<*xf32> // HISTOGRAM-PERCENTILE-CHECK-NEXT: return [[res]] : tensor<*xf32> -// HISTOGRAM-PERCENTILE-CHECK: func @no_custom_ops_on_non_f32_type +// HISTOGRAM-PERCENTILE-CHECK: func @no_composite_func // HISTOGRAM-PERCENTILE-CHECK-NEXT: "tf.AddV2" // HISTOGRAM-PERCENTILE-CHECK-NEXT: return -// HISTOGRAM-PERCENTILE-CHECK: func @composite_conv2d_with_bias_and_relu6_fn +// HISTOGRAM-PERCENTILE-CHECK: func @composite_conv2d_with_relu6_fn // HISTOGRAM-PERCENTILE-CHECK-NEXT: "tf.Conv2D" -// HISTOGRAM-PERCENTILE-CHECK-NEXT: "tf.BiasAdd" // HISTOGRAM-PERCENTILE-CHECK-NEXT: "tf.Relu6" // HISTOGRAM-PERCENTILE-CHECK-NEXT: return @@ -85,20 +82,19 @@ module { // calibration_method=CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, // calibration_parameters=CalibrationParameters(initial_num_bins=256) // ) -// HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @add_custom_ops +// HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @wrap_composite_func // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 
: f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> -// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 4 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: return [[res]] : tensor<*xf32> -// HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @no_custom_ops_on_non_f32_type +// HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @no_composite_func // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: "tf.AddV2" // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: return -// HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @composite_conv2d_with_bias_and_relu6_fn +// HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @composite_conv2d_with_relu6_fn // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: "tf.Conv2D" -// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: "tf.BiasAdd" // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: "tf.Relu6" // HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: return @@ -106,20 +102,19 @@ module { // calibration_method=CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, // calibration_parameters=CalibrationParameters(initial_num_bins=256) // ) -// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @add_custom_ops +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @wrap_composite_func // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> -// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 5 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: return [[res]] : tensor<*xf32> -// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @no_custom_ops_on_non_f32_type +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @no_composite_func // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: "tf.AddV2" // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: return -// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @composite_conv2d_with_bias_and_relu6_fn +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @composite_conv2d_with_relu6_fn // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: "tf.Conv2D" -// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: "tf.BiasAdd" // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: "tf.Relu6" // HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: return @@ -127,20 +122,56 @@ module { // calibration_method=CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, // calibration_parameters=CalibrationParameters(initial_num_bins=256) // ) -// HISTOGRAM-MSE-SYMMETRIC-CHECK: func @add_custom_ops +// HISTOGRAM-MSE-SYMMETRIC-CHECK: func 
@wrap_composite_func // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[rhs:%.*]] = "tf.CustomAggregator"(%arg1) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[lhs:%.*]] = "tf.CustomAggregator"(%arg0) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> -// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[add:%.*]] = "tf.AddV2"([[lhs]], [[rhs]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[res:%.*]] = "tf.CustomAggregator"([[add]]) <{id = ""}> {calibration_method = 6 : i32, initial_num_bins = 256 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: return [[res]] : tensor<*xf32> -// HISTOGRAM-MSE-SYMMETRIC-CHECK: func @no_custom_ops_on_non_f32_type +// HISTOGRAM-MSE-SYMMETRIC-CHECK: func @no_composite_func // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: "tf.AddV2" // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: return -// HISTOGRAM-MSE-SYMMETRIC-CHECK: func @composite_conv2d_with_bias_and_relu6_fn +// HISTOGRAM-MSE-SYMMETRIC-CHECK: func @composite_conv2d_with_relu6_fn // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: "tf.Conv2D" -// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: "tf.BiasAdd" // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: "tf.Relu6" // HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: return + +// ----- + +module { + // CHECK-LABEL: func.func @main + func.func @main(%arg0: tensor, %arg1: tensor<100352x10xf32>) -> tensor { + // CHECK-DAG: %[[ARG0_ID:.*]] = "tf.Identity"(%arg0) + // CHECK-DAG: %[[ARG1_ID:.*]] = "tf.Identity"(%arg1) + // CHECK-DAG: %[[ARG0_AGG:.*]] = "tf.CustomAggregator"(%[[ARG0_ID]]) + // CHECK-DAG: %[[ARG1_AGG:.*]] = "tf.CustomAggregator"(%[[ARG1_ID]]) + // CHECK: %[[RES:.*]] = "tf.XlaCallModule"(%[[ARG0_AGG]], %[[ARG1_AGG]]) + // CHECK: %[[RES_AGG:.*]] = "tf.CustomAggregator"(%[[RES]]) + // CHECK-DAG: %[[RES_ID:.*]] = "tf.Identity"(%[[RES_AGG]]) + // CHECK: return %[[RES_ID]] : tensor + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%arg1) {device = ""} : (tensor<100352x10xf32>) -> tensor<100352x10xf32> + %2 = "tf.XlaCallModule"(%0, %1) <{ + Sout = [#tf_type.shape], dim_args_spec = [], + disabled_checks = [], function_list = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + }> { + _entry_function = @composite_dot_general_fn_1, + _original_entry_function = "composite_dot_general_fn_1", + _tfl_quant_trait = "fully_quantizable" + } : (tensor, tensor<100352x10xf32>) -> tensor + %3 = "tf.Identity"(%2) {device = ""} : (tensor) -> tensor + return %3 : tensor + } + + // CHECK-LABEL: func.func private @composite_dot_general_fn_1 + func.func private @composite_dot_general_fn_1(%arg0: tensor, %arg1: tensor<100352x10xf32>) -> tensor { + // CHECK-NOT: tf.CustomAggregator + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<100352x10xf32>) -> tensor + return %0 : tensor + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD index e3b67f59a5e829..e7b42fcd09e2aa 
100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD @@ -1,4 +1,3 @@ -load("@llvm-project//mlir:tblgen.bzl", "td_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") @@ -34,39 +33,6 @@ cc_library( ], ) -td_library( - name = "lift_as_function_call_utils_td_files", - srcs = [ - "lift_as_function_call_utils.td", - ], - compatible_with = get_compatible_with_portable(), - deps = [ - "@llvm-project//mlir:FuncTdFiles", - ], -) - -cc_library( - name = "lift_as_function_call_utils", - srcs = ["lift_as_function_call_utils.cc"], - hdrs = ["lift_as_function_call_utils.h"], - compatible_with = get_compatible_with_portable(), - deps = [ - "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", - "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", - "//tensorflow/compiler/mlir/quantization/tensorflow/cc:quantization_unit_loc", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", - "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", - "//tensorflow/core:framework_lite", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - ], -) - cc_library( name = "tf_to_uniform_attribute_utils", srcs = ["tf_to_uniform_attribute_utils.cc"], @@ -74,7 +40,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", - "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow/ops:uniform_op_quant_spec", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", @@ -112,7 +78,7 @@ cc_library( hdrs = ["tf_to_xla_attribute_utils.h"], compatible_with = get_compatible_with_portable(), deps = [ - "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:constant_fold", "//tensorflow/lite/kernels:padding", "@com_google_absl//absl/algorithm:container", @@ -129,7 +95,7 @@ tf_cc_test( deps = [ ":tf_to_xla_attribute_utils", "//tensorflow/c/eager:c_api", - "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h index 35a00db79e368f..922729d9c8c3a6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h @@ -21,7 +21,7 @@ limitations under the License. 
#include "llvm/ADT/StringMap.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" namespace mlir::quant { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc index 168c99a7b2cf86..f1d7a6ae576c7b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc @@ -20,8 +20,8 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/str_format.h" #include "llvm/ADT/ArrayRef.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "xla/xla_data.pb.h" #include "tensorflow/lite/kernels/padding.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.h index 52dcdcbc780325..80212b9acec5fb 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.h @@ -19,7 +19,6 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TO_XLA_ATTRIBUTE_UTILS_H_ #include "mlir/IR/Builders.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" namespace mlir::quant { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils_test.cc index b71ccae8f7c0a0..cc4bbb344026da 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils_test.cc @@ -29,7 +29,7 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" namespace mlir::quant { diff --git a/tensorflow/compiler/mlir/stablehlo/BUILD b/tensorflow/compiler/mlir/stablehlo/BUILD index f1265510044fdd..c16a7118cecce0 100644 --- a/tensorflow/compiler/mlir/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/stablehlo/BUILD @@ -53,5 +53,6 @@ py_strict_test( python_version = "PY3", deps = [ ":stablehlo", + #internal proto upb dep ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index a24ea5e7a8fe63..77530c113b9be2 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -353,6 +353,7 @@ cc_library( ":attribute_utils", ":convert_type", ":dynamic_shape_utils", + ":side_effect_analysis_util", ":tensorflow_all_ops_inc_gen", ":tensorflow_attributes", ":tensorflow_op_interfaces", @@ -407,6 +408,7 @@ cc_library( deps = [ ":attribute_utils", ":serialize_mlir_module_utils", + ":side_effect_analysis_util", ":tensorflow_attributes", ":tensorflow_op_interfaces", ":tensorflow_op_interfaces_inc_gen", @@ -451,6 +453,7 @@ cc_library( "ir/tf_remaining_ops.h.inc", ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], deps = [ + ":side_effect_analysis_util", ":tensorflow_attributes", ":tensorflow_op_interfaces", ":tensorflow_op_interfaces_inc_gen", @@ -492,6 +495,7 @@ cc_library( "ir/tfrt_ops.h", ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], deps = [ + ":side_effect_analysis_util", ":tensorflow_all_ops_inc_gen", ":tensorflow_attributes", ":tensorflow_op_interfaces", @@ -760,133 +764,6 @@ cc_library( ], ) -cc_library( - name = "upgrade_graph", - srcs = ["translate/upgrade_graph.cc"], - hdrs = ["translate/upgrade_graph.h"], - deps = [ - ":attribute_utils", - "//tensorflow/compiler/tf2xla:functionalize_control_flow", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core/common_runtime:device", - "//tensorflow/core/common_runtime:device_factory", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler:grappler_item_builder", - "//tensorflow/core/grappler/clusters:virtual_cluster", - "//tensorflow/core/grappler/optimizers:meta_optimizer", - "//tensorflow/core/protobuf:for_core_protos_cc", - "@llvm-project//llvm:Support", - ], -) - -cc_library( - name = "export_graphdef", - srcs = [ - "translate/export_graphdef.cc", - ], - hdrs = [ - "translate/export_graphdef.h", - ], - visibility = ["//visibility:public"], - deps = [ - ":convert_type", - ":error_util", - ":export_tf_dialect_op", - ":export_utils", - ":mlir_roundtrip_flags", - ":tensorflow", - ":translate_utils", - ":verify_suitable_for_graph_export", - "//tensorflow/compiler/mlir:op_or_arg_name_mapper", - "//tensorflow/compiler/mlir/utils:name_utils", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/graph/regularization:util", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - 
"@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@local_xla//xla:status_macros", - ], -) - -cc_library( - name = "import_model", - srcs = [ - "translate/import_model.cc", - ], - hdrs = [ - "translate/export_graphdef.h", - "translate/import_model.h", - ], - deps = [ - ":attribute_utils", - ":convert_attr", - ":convert_tensor", - ":convert_type", - ":dump_mlir_util", - ":dynamic_shape_utils", - ":error_util", - ":mangling_util", - ":mlir_import_options", - ":mlir_roundtrip_flags", - ":tensorflow", - ":tensorflow_attributes", - ":tensorflow_types", - ":translate_utils", - ":upgrade_graph", - "//tensorflow/cc/saved_model:bundle_v2", - "//tensorflow/cc/saved_model:constants", - "//tensorflow/cc/saved_model:loader_lite", - "//tensorflow/cc/saved_model:loader_util", - "//tensorflow/compiler/jit:shape_inference_helpers", - "//tensorflow/compiler/mlir:op_or_arg_name_mapper", - "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", - "//tensorflow/compiler/mlir/tensorflow/transforms:initialize_variables_in_session_init", - "//tensorflow/compiler/mlir/tensorflow/transforms:lift_variables_lib", - "//tensorflow/compiler/mlir/tensorflow/transforms:mark_initialized_variables_lib", - "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", - "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler/utils:transitive_fanin", - "//tensorflow/core/platform:crash_analysis", - "//tensorflow/core/platform:types", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@local_xla//xla:status_macros", - "@local_xla//xla/client:sharding_builder", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:hlo_parser", - ], -) - cc_library( name = "parse_text_proto", srcs = ["utils/parse_text_proto.cc"], @@ -912,20 +789,6 @@ cc_library( ], ) -tf_cc_test( - name = "tf_mlir_translate_registration_test", - size = "small", - srcs = ["translate/tf_mlir_translate_registration_test.cc"], - deps = [ - ":translate_registration", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:TranslateLib", - ], -) - cc_library( name = "export_utils", srcs = [ @@ -1002,92 +865,6 @@ cc_library( ], ) -cc_library( - name = "export_tf_dialect_op", - srcs = [ - "translate/export_tf_dialect_op.cc", - ], - hdrs = [ - "translate/export_tf_dialect_op.h", - ], - deps = [ - ":convert_type", - ":export_utils", - ":tensorflow", - "//tensorflow/compiler/mlir/utils:string_container_utils", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - 
"@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:DerivedAttributeOpInterface", - "@llvm-project//mlir:IR", - "@local_xla//xla:status_macros", - ], -) - -cc_library( - name = "translate_tf_dialect_op", - srcs = ["translate/translate_tf_dialect_op.cc"], - deps = [ - ":export_tf_dialect_op", - ":tensorflow", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TranslateLib", - ], - alwayslink = 1, -) - -cc_library( - name = "mlir_roundtrip_pass", - srcs = ["translate/mlir_roundtrip_pass.cc"], - hdrs = ["translate/mlir_roundtrip_pass.h"], - deps = [ - ":error_util", - ":export_graphdef", - ":import_model", - ":mlir_roundtrip_flags", - "//tensorflow/core:core_cpu_lib", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@local_xla//xla:status_macros", - ], -) - -cc_library( - name = "mlir_roundtrip_pass_registration", - srcs = ["translate/mlir_roundtrip_pass_registration.cc"], - deps = [ - ":mlir_roundtrip_pass", - ], - alwayslink = 1, -) - -cc_library( - name = "mlir_roundtrip_flags", - srcs = ["translate/mlir_roundtrip_flags.cc"], - hdrs = ["translate/mlir_roundtrip_flags.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:types", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@local_xla//xla:status_macros", - ], -) - cc_library( name = "convert_attr", srcs = ["utils/convert_attr.cc"], @@ -1249,90 +1026,6 @@ cc_library( ], ) -cc_library( - name = "mlir_import_options", - hdrs = ["translate/mlir_import_options.h"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "translate_lib", - srcs = ["translate/tf_mlir_translate.cc"], - hdrs = ["translate/tf_mlir_translate.h"], - visibility = ["//visibility:public"], - deps = [ - ":error_util", - ":import_model", - ":import_utils", - ":mangling_util", - ":mlir_import_options", - ":mlir_roundtrip_flags", - "//tensorflow/cc/saved_model:bundle_v2", - "//tensorflow/cc/saved_model:loader_lite", - "//tensorflow/cc/saved_model:reader", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:ops", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler/utils:transitive_fanin", - "//tensorflow/core/util/tensor_bundle:byteswaptensor", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - ], -) - -cc_library( - name = "translate_cl_options", - srcs = [ - "translate/tf_mlir_translate_cl.cc", - ], - hdrs = [ - "translate/tf_mlir_translate_cl.h", - ], - deps = [ - "@llvm-project//llvm:Support", - ], - alwayslink = 1, -) - -cc_library( - name = "translate_registration", - srcs = [ - "translate/tf_mlir_translate_registration.cc", - ], - deps = [ - ":export_graphdef", - ":mlir_roundtrip_flags", - ":tensorflow", - ":translate_cl_options", - ":translate_lib", - 
"//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:TranslateLib", - "@local_xla//xla/client:client_library", - "@local_xla//xla/client:compile_only_client", - "@local_xla//xla/service/cpu:cpu_compiler", - "@local_xla//xla/service/cpu:cpu_transfer_manager", - "@local_xla//xla/stream_executor", - "@local_xla//xla/stream_executor/host:host_platform", - "@local_xla//xla/stream_executor/host:host_platform_id", - ], - alwayslink = 1, -) - tf_cc_test( name = "error_util_test", srcs = ["utils/error_util_test.cc"], @@ -1488,6 +1181,7 @@ cc_library( ":device_util", ":tensorflow", ":tensorflow_types", + "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir/utils:string_container_utils", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -1513,6 +1207,7 @@ tf_cc_test( ":serialize_mlir_module_utils", ":tensorflow", ":tpu_rewrite_device_util", + "//tensorflow/compiler/jit:flags", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -1937,27 +1632,6 @@ cc_library( ], ) -cc_library( - name = "split_into_island_per_op_pass", - srcs = ["translate/split_into_island_per_op_pass.cc"], - hdrs = [ - "ir/tf_executor.h", - "translate/split_into_island_per_op_pass.h", - ], - deps = [ - ":tensorflow", - ":tensorflow_executor_inc_gen", - ":tensorflow_types", - "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Dialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:InferTypeOpInterface", - "@llvm-project//mlir:Pass", - ], -) - tf_cc_test( name = "xla_rewrite_util_test", size = "small", @@ -1968,6 +1642,7 @@ tf_cc_test( ":tensorflow", ":tpu_rewrite_device_util", ":xla_rewrite_util", + "//tensorflow/compiler/jit:flags", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -1978,6 +1653,22 @@ tf_cc_test( ], ) +cc_library( + name = "side_effect_analysis_util", + srcs = [ + "utils/side_effect_analysis_util.cc", + ], + hdrs = [ + "utils/side_effect_analysis_util.h", + ], + deps = [ + "tensorflow_side_effects", + "tensorflow_types", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + build_test( name = "tensorflow_build_test", targets = [ @@ -2008,3 +1699,30 @@ build_test( # ) # # copybara:uncomment_end(google-only) + +# Required as we created the transforms subpackage and need to update +# these BUILD targets in a follow up. 
+aliased_targets = [ + "export_graphdef", + "import_model", + "export_tf_dialect_op", + "translate_tf_dialect_op", + "mlir_roundtrip_pass", + "mlir_roundtrip_pass_registration", + "mlir_roundtrip_flags", + "mlir_import_options", + "translate_lib", + "translate_cl_options", + "translate_registration", + "split_into_island_per_op_pass", + "upgrade_graph", +] + +[ + alias( + name = target, + actual = "//tensorflow/compiler/mlir/tensorflow/translate:%s" % target, + visibility = ["//visibility:public"], + ) + for target in aliased_targets +] diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index 744fa37a914de0..b0d730898316d5 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -346,6 +347,7 @@ class OpSideEffectCollector { } bool IsCallToPureFunction(Operation* callOp) const; + bool IsPureFunction(func::FuncOp func_op) const; private: // Adds op-based side effects from all ops in `region` to `op` side effects. @@ -510,18 +512,42 @@ bool OpSideEffectCollector::IsCallToPureFunction(Operation* callOp) const { return false; // not a call func::FuncOp func_op = dyn_cast(call.resolveCallable( &symbol_table_collection_)); + return IsPureFunction(func_op); +} + +bool OpSideEffectCollector::IsPureFunction(func::FuncOp func_op) const { auto it = is_pure_function_.find(func_op); if (it == is_pure_function_.end()) { bool is_pure = true; + is_pure_function_[func_op] = is_pure; // prevent infinite recursion func_op->walk([&](Operation* op) { - if (op == func_op) return WalkResult::advance(); + if (op == func_op) { + return WalkResult::advance(); + } + // AssertOp is not, technically, pure. However, we treat functions + // that contain an assert as pure, so that graphs with and without + // assert don't have different side effect semantics. Also see + // b/309824992 for the challenges associated with improving the side + // effect modelling of Assert on the op level. 
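The new IsPureFunction helper walks a function and treats it as pure when every op in it is side-effect free, is an Assert, is an If whose branches are pure, or is a call to another pure function; the cache entry is pre-seeded before the walk to break recursion cycles. A minimal TF-dialect sketch of what now qualifies as pure (the function name and tensor shapes here are illustrative, not taken from this patch):

// Reported as pure after this change: tf.Assert is the only op with a
// side effect in the body, and Assert is deliberately ignored when
// classifying functions.
func.func @assert_only(%pred: tensor<i1>, %data: tensor<1xf32>) -> tensor<1xf32> {
  "tf.Assert"(%pred, %data) {summarize = 3 : i64} : (tensor<i1>, tensor<1xf32>) -> ()
  func.return %data : tensor<1xf32>
}

A tf.StatefulPartitionedCall to such a function therefore no longer shows up as a sink in side-effect analysis, which is what the cases added to side-effect-analysis-test.mlir later in this patch exercise.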
+ if (llvm::isa(op)) { + return WalkResult::advance(); + } + if (auto if_op = llvm::dyn_cast(op)) { + if (IsPureFunction(if_op.then_function()) && + IsPureFunction(if_op.else_function())) { + return WalkResult::advance(); + } + } + if (IsCallToPureFunction(op)) { + return WalkResult::advance(); + } if (TensorFlowDialect::CanHaveSideEffects(op)) { is_pure = false; return WalkResult::interrupt(); } return WalkResult::advance(); }); - is_pure_function_.insert({func_op, is_pure}); + is_pure_function_[func_op] = is_pure; } return is_pure_function_[func_op]; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h index 95d766359b1d05..a3c95bdf2332a8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h @@ -52,11 +52,13 @@ class TensorFlowExecutorDialect : public Dialect { class ControlType : public Type::TypeBase { public: using Base::Base; + static constexpr ::mlir::StringLiteral name = "tf_executor.control"; }; class TokenType : public Type::TypeBase { public: using Base::Base; + static constexpr ::mlir::StringLiteral name = "tf_executor.token"; }; } // namespace tf_executor diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index f69d3d4f9c97f4..553794ecd25b90 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1198,7 +1198,9 @@ It is computed as: Arg, [{2-D or higher with shape `[..., r_y, c_y]`.}]>:$y, DefaultValuedOptionalAttr:$adj_x, - DefaultValuedOptionalAttr:$adj_y + DefaultValuedOptionalAttr:$adj_y, + DefaultValuedOptionalAttr:$grad_x, + DefaultValuedOptionalAttr:$grad_y ); let results = (outs @@ -1245,7 +1247,9 @@ about broadcasting Arg, [{2-D or higher with shape `[..., r_y, c_y]`.}]>:$y, DefaultValuedOptionalAttr:$adj_x, - DefaultValuedOptionalAttr:$adj_y + DefaultValuedOptionalAttr:$adj_y, + DefaultValuedOptionalAttr:$grad_x, + DefaultValuedOptionalAttr:$grad_y ); let results = (outs @@ -1292,7 +1296,9 @@ about broadcasting Arg, [{2-D or higher with shape `[..., r_y, c_y]`.}]>:$y, DefaultValuedOptionalAttr:$adj_x, - DefaultValuedOptionalAttr:$adj_y + DefaultValuedOptionalAttr:$adj_y, + DefaultValuedOptionalAttr:$grad_x, + DefaultValuedOptionalAttr:$grad_y ); let results = (outs @@ -2095,7 +2101,7 @@ def TF_CeilOp : TF_Op<"Ceil", [Pure, TF_Idempotent, TF_SameOperandsAndResultType }]; } -def TF_CheckNumericsOp : TF_Op<"CheckNumerics", [TF_SameOperandsAndResultTypeResolveRef]> { +def TF_CheckNumericsOp : TF_Op<"CheckNumerics", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, TF_SameOperandsAndResultTypeResolveRef]> { let summary = "Checks a tensor for NaN and Inf values."; let description = [{ @@ -8615,7 +8621,9 @@ cublas. 
TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$b, DefaultValuedOptionalAttr:$transpose_a, - DefaultValuedOptionalAttr:$transpose_b + DefaultValuedOptionalAttr:$transpose_b, + DefaultValuedOptionalAttr:$grad_a, + DefaultValuedOptionalAttr:$grad_b ); let results = (outs @@ -22008,7 +22016,7 @@ a u64[2] and for PHILOX a u64[3].}]>:$initial_state, let results = (outs TF_Uint64Tensor:$output_key, - TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output + TensorOf<[TF_Int32, TF_Int64, TF_Int8, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr Tshape = TF_DerivedOperandTypeAttr<2>; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index cee0e40f9cfeb5..8b8b069ea6f40d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -67,6 +67,7 @@ limitations under the License. #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -85,6 +86,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h" #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/status.h" @@ -1090,6 +1092,24 @@ OpFoldResult CastOp::fold(FoldAdaptor) { return {}; } +//===----------------------------------------------------------------------===// +// CheckNumericsOp +//===----------------------------------------------------------------------===// + +void CheckNumericsOp::getEffects( + SmallVectorImpl>& + effects) { + effects.emplace_back(MemoryEffects::Write::get(), + ResourceEffects::CheckNumerics::get()); + MarkResourceAsReadOnly(getTensor(), effects); +} + +// For `CheckNumerics` ops the `device` attribute corresponds to the resource +// instance. +std::optional CheckNumericsOp::GetResourceInstanceStr() { + return GetDeviceAttrAsResourceInstanceStr(*this); +} + //===----------------------------------------------------------------------===// // CollectiveReduceV2Op //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 01cbbb9a46967c..122677ee4ad6da 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -88,6 +88,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h" namespace mlir { namespace TF { @@ -2344,21 +2345,13 @@ void TPUExecuteOp::getEffects( effects.emplace_back(MemoryEffects::Write::get(), ResourceEffects::TPUExecute::get()); + // Conservatively mark resource handles as read and write, as without + // analyzing TPUCompile, there is not sufficient information to determine + // effects on resources. For the MLIR bridge, this op will never be + // populated with resource handles and tf.TPUExecuteAndUpdateVariables is + // used instead. for (Value value : getArgs()) { - if (value.getType() - .cast() - .getElementType() - .isa()) { - // Conservatively mark resource handles as read and write, as without - // analyzing TPUCompile, there is not sufficient information to determine - // effects on resources. For the MLIR bridge, this op will never be - // populated with resource handles and tf.TPUExecuteAndUpdateVariables is - // used instead. - effects.emplace_back(MemoryEffects::Read::get(), value, - ResourceEffects::Variable::get()); - effects.emplace_back(MemoryEffects::Write::get(), value, - ResourceEffects::Variable::get()); - } + MarkResourceAsReadAndWrite(value, effects); } } @@ -2373,19 +2366,11 @@ void _XlaRunOp::getEffects( effects.emplace_back(MemoryEffects::Write::get(), ResourceEffects::_XlaRun::get()); + // Conservatively mark resource handles as read and write, as without + // analyzing _XlaCompile, there is not sufficient information to determine + // effects on resources. for (Value value : getArgs()) { - if (value.getType() - .cast() - .getElementType() - .isa()) { - // Conservatively mark resource handles as read and write, as without - // analyzing _XlaCompile, there is not sufficient information to determine - // effects on resources. - effects.emplace_back(MemoryEffects::Read::get(), value, - ResourceEffects::Variable::get()); - effects.emplace_back(MemoryEffects::Write::get(), value, - ResourceEffects::Variable::get()); - } + MarkResourceAsReadAndWrite(value, effects); } } @@ -3059,35 +3044,22 @@ LogicalResult XlaCallModuleOp::verifySymbolUses( void XlaLaunchOp::getEffects( SmallVectorImpl> &effects) { - effects.reserve(getArgs().size() + 1); + effects.reserve(2 * getArgs().size() + 1); effects.emplace_back(MemoryEffects::Write::get(), ResourceEffects::XlaLaunch::get()); + // Conservatively mark resource handles as read and write, as without + // analyzing XlaLaunch, there is not sufficient information to determine + // effects on resources. for (Value value : getArgs()) { - if (value.getType() - .cast() - .getElementType() - .isa()) { - // Conservatively mark resource handles as read and write, as without - // analyzing XlaLaunch, there is not sufficient information to determine - // effects on resources. - effects.emplace_back(MemoryEffects::Read::get(), value, - ResourceEffects::Variable::get()); - effects.emplace_back(MemoryEffects::Write::get(), value, - ResourceEffects::Variable::get()); - } + MarkResourceAsReadAndWrite(value, effects); } } // For `XlaLaunch` ops the `device` attribute corresponds to the resource // instance. std::optional XlaLaunchOp::GetResourceInstanceStr() { - auto device_attr = (*this)->getAttrOfType("device"); - // Treat missing device attribute like unspecified (= empty string) attribute. 
- // Note that different op instances with the same string (including empty - // string) are seen as dependent (same resource instance). - if (!device_attr) return ""; - return device_attr.str(); + return GetDeviceAttrAsResourceInstanceStr(*this); } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h index 6384d0770a3358..9bcc75fbe1e424 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h @@ -121,6 +121,11 @@ struct _XlaRun : public ::mlir::SideEffects::Resource::Base<_XlaRun> { StringRef getName() final { return "_XlaRun"; } }; +struct CheckNumerics + : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "CheckNumerics"; } +}; + // Returns true iff resource type with given ID is only self-dependent, i.e., // there are no dependencies to other resource types (including unknown resource // type). diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 612f01ce23ce8a..13cdaa1a445842 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -27,7 +27,7 @@ func.func @testGatherToV2(%params: tensor<4x3xf32>, %indices: tensor<1x2xi32>) - // CHECK-LABEL: testBatchMatMulToV2 func.func @testBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>) -> tensor<2x3x7xf32> { - // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} + // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} %0 = "tf.BatchMatMul"(%arg0, %arg1) <{adj_x = false, adj_y = false}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32> func.return %0: tensor<2x3x7xf32> } @@ -41,7 +41,7 @@ func.func @testDynamicBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor, %arg1: tensor<3x2xf32>) -> tensor<2x2xf32> { - // CHECK: %0 = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>, tensor<3x2xf32>) -> tensor<2x2xf32> + // CHECK: %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>, tensor<3x2xf32>) -> tensor<2x2xf32> %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<2x3xf32>, tensor<3x2xf32>) -> tensor<2x2xf32> // CHECK: return %0 func.return %0: tensor<2x2xf32> @@ -49,7 +49,7 @@ func.func @testBatchMatMulToMatMul(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32 // CHECK-LABEL: testBatchMatMulV2ToMatMul func.func @testBatchMatMulV2ToMatMul(%arg0: tensor<4x3xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> { - // CHECK: %0 = "tf.MatMul"(%arg0, %arg1) <{transpose_a = true, transpose_b = false}> {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<4x3xf32>, tensor<4x5xf32>) -> tensor<3x5xf32> + // CHECK: %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = false}> {device = 
"/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<4x3xf32>, tensor<4x5xf32>) -> tensor<3x5xf32> %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = false, device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<4x3xf32>, tensor<4x5xf32>) -> tensor<3x5xf32> // CHECK: return %0 func.return %0: tensor<3x5xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir index c5cf58971296fb..4285b57a322217 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -pass-pipeline='builtin.module(tf-functional-control-flow-to-regions{allow-passthrough-args})' -split-input-file | FileCheck %s +// RUN: tf-opt %s -pass-pipeline='builtin.module(tf-functional-control-flow-to-regions{allow-passthrough-args})' -split-input-file -verify-diagnostics | FileCheck %s // Simple If // CHECK: func private @testIf1Then{{.+}} @@ -298,3 +298,73 @@ func.func @testWhileDevice() { // CHECK: device = "/device:CPU:0" func.return } + +// ----- + +// CHECK-LABEL: func @init +func.func @init(%arg0: tensor<4xf32>) -> tensor<7xf32> { + %0 = builtin.unrealized_conversion_cast to tensor<7xf32> + return %0 : tensor<7xf32> +} + +// CHECK-LABEL: func @next +func.func @next(%arg0: tensor<7xf32>, %arg1: tensor<3xf32>) -> tensor<6xf32> { + %0 = builtin.unrealized_conversion_cast to tensor<6xf32> + return %0 : tensor<6xf32> +} + +// CHECK-LABEL: func @finalize +func.func @finalize(%arg0: tensor<6xf32>, %arg1: tensor<2xf32>) -> tensor<5xf32> { + %0 = builtin.unrealized_conversion_cast to tensor<5xf32> + return %0 : tensor<5xf32> +} + +// CHECK-LABEL: func @testGeneratorDataset +func.func @testGeneratorDataset(%arg0: tensor<4xf32>, + %arg1: tensor<3xf32>, + %arg2: tensor, + %arg3: tensor<2xf32>) { + // CHECK-NOT: tf.GeneratorDataset + // CHECK: tf.GeneratorDatasetRegion + // CHECK: ^ + // CHECK-SAME: tensor<4xf32> + // CHECK: func.call @init + // CHECK: ^ + // CHECK-SAME: tensor<7xf32> + // CHECK-SAME: tensor<3xf32> + // CHECK-NOT: tf_type.resource + // CHECK: func.call @next + // CHECK: ^ + // CHECK-SAME: tensor<6xf32> + // CHECK-SAME: tensor<2xf32> + // CHECK: func.call @finalize + // CHECK-NOT: tf.GeneratorDataset + %0 = "tf.GeneratorDataset"(%arg0, %arg1, %arg2, %arg3) { + device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0", + finalize_func = @finalize, + init_func = @init, + next_func = @next, + operandSegmentSizes = array, + output_shapes = [#tf_type.shape<>], + output_types = [!tf_type.string], + metadata = ""} : ( + tensor<4xf32>, + tensor<3xf32>, + tensor, + tensor<2xf32>) -> tensor + return +} + +// ----- + +func.func @testIncompleteGeneratorDataset() { + // expected-error@+1 {{'tf.GeneratorDataset' op failed to convert to region form}} + %0 = "tf.GeneratorDataset"() { + finalize_func = @invalid, + init_func = @invalid, + next_func = @invalid, + output_shapes = [#tf_type.shape<>], + output_types = [!tf_type.string], + metadata = "" } : () -> tensor + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/host_launch_to_outside_compiled.mlir b/tensorflow/compiler/mlir/tensorflow/tests/host_launch_to_outside_compiled.mlir deleted file mode 100644 index d7867332a4812c..00000000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/host_launch_to_outside_compiled.mlir +++ /dev/null @@ -1,192 +0,0 @@ -// RUN: 
tf-opt %s -split-input-file -verify-diagnostics -tf-device-host-launch-to-outside-compiled | FileCheck %s - -// Tests invalid device error returned when invalid device set on module. - -// expected-error@+1 {{not a valid device}} -module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["bad_device"]} { - func.func @bad_device_error() -> () { - "tf_device.cluster"() ({ - "tf.A"() : () -> () - "tf_device.launch"() ({ - "tf.B"() : () -> () - tf_device.return - }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> () - "tf.C"() : () -> () - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () - func.return - } -} - -// ----- - -module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { - - // Tests the unwrap of unreplicated launch of a single outside compiled op with no input or output dependencies. - - // CHECK-LABEL: func @single_op_launch_not_host - func.func @single_op_launch_not_host() -> () { - // CHECK: "tf.A" - // CHECK: "tf_device.launch" - // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:TPU:0" - // CHECK: "tf.B" - // CHECK-NOT: _xla_outside_compilation - // CHECK: "tf.C" - // CHECK-NEXT: tf_device.return - "tf_device.cluster"() ({ - "tf.A"() : () -> () - "tf_device.launch"() ({ - "tf.B"() : () -> () - tf_device.return - }) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : () -> () - "tf.C"() : () -> () - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () - func.return - } - - // CHECK-LABEL: func @single_op_hostlaunch_no_input_output - func.func @single_op_hostlaunch_no_input_output() -> () { - // CHECK: "tf.A" - // CHECK-NOT: "tf_device.launch" - // CHECK-NEXT: "tf.B" - // CHECK-SAME: _xla_outside_compilation - // CHECK: "tf.C" - // CHECK-NEXT: tf_device.return - "tf_device.cluster"() ({ - "tf.A"() : () -> () - "tf_device.launch"() ({ - "tf.B"() : () -> () - tf_device.return - }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> () - "tf.C"() : () -> () - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () - func.return - } - - // CHECK-LABEL: func @single_op_host_launch_input_output - func.func @single_op_host_launch_input_output() -> () { - // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - // CHECK-NOT: "tf_device.launch" - // CHECK-NEXT: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) - // CHECK-SAME: _xla_outside_compilation - // CHECK: "tf.C"(%[[B_OUTPUT]]) - // CHECK-NEXT: tf_device.return - "tf_device.cluster"() ({ - %1 = "tf.A"() : () -> (tensor) - %2 = "tf_device.launch"() ({ - %3 = "tf.B"(%1) : (tensor) -> (tensor) - tf_device.return %3 : tensor - }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor) - %4 = "tf.C"(%2) : (tensor) -> tensor - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () - func.return - } - - // CHECK-LABEL: func @multiple_ops_host_launch_input_output - func.func @multiple_ops_host_launch_input_output() -> () { - // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - // CHECK-NOT: "tf_device.launch" - // CHECK-NEXT: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) - // CHECK-SAME: _xla_outside_compilation - // CHECK-NEXT: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[B_OUTPUT]]) - // CHECK-SAME: _xla_outside_compilation - // CHECK: "tf.C"(%[[D_OUTPUT]]) - // CHECK-NEXT: 
tf_device.return - "tf_device.cluster"() ({ - %1 = "tf.A"() : () -> (tensor) - %2 = "tf_device.launch"() ({ - %3 = "tf.B"(%1) : (tensor) -> (tensor) - %4 = "tf.D"(%3) : (tensor) -> (tensor) - tf_device.return %4 : tensor - }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor) - %5 = "tf.C"(%2) : (tensor) -> tensor - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () - func.return - } - - // Tests a host launch that's called from a tf_device.cluster. - - func.func @called_hostlaunch() -> () { - "tf_device.cluster"() ({ - "tf.PartitionedCall"() {f = @called_hostlaunch_callee} : () -> () - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () - func.return - } - // CHECK-LABEL: func @called_hostlaunch_callee - func.func @called_hostlaunch_callee() -> () { - // CHECK: "tf.A" - // CHECK-NOT: "tf_device.launch" - // CHECK-NEXT: "tf.B" - // CHECK-SAME: _xla_outside_compilation - // CHECK: "tf.C" - "tf.A"() : () -> () - "tf_device.launch"() ({ - "tf.B"() : () -> () - tf_device.return - }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> () - "tf.C"() : () -> () - func.return - } - - // Test that the same outside compiled function cannot be called from two - // different TPU clusters. - - func.func @called_hostlaunch_bad() -> () { - "tf_device.cluster"() ({ - "tf.PartitionedCall"() {f = @called_hostlaunch_bad_callee} : () -> () - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () - "tf_device.cluster"() ({ - "tf.PartitionedCall"() {f = @called_hostlaunch_bad_callee} : () -> () - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () - func.return - } - // expected-error@+1 {{The same function is reachable from multiple TPU Clusters.}} - func.func @called_hostlaunch_bad_callee() -> () { - // CHECK: "tf.A" - // CHECK-NOT: "tf_device.launch" - // CHECK-NEXT: "tf.B" - // CHECK-SAME: _xla_outside_compilation - // CHECK: "tf.C" - "tf.A"() : () -> () - "tf_device.launch"() ({ - "tf.B"() : () -> () - tf_device.return - }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> () - "tf.C"() : () -> () - func.return - } -} - -// ----- - -// Checks that transform to outside compiled occurs when there is model -// parallelism. 
- -module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { - // CHECK-LABEL: func @model_parallelism - func.func @model_parallelism() -> () { - // CHECK: "tf.A" - // CHECK-NOT: "tf_device.launch" - // CHECK-NEXT: "tf.B" - // CHECK-SAME: _xla_outside_compilation - // CHECK: "tf.C" - // CHECK-NEXT: tf_device.return - "tf_device.cluster"() ({ - "tf.A"() : () -> () - "tf_device.launch"() ({ - "tf.B"() : () -> () - tf_device.return - }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> () - "tf.C"() : () -> () - tf_device.return - }) {num_cores_per_replica = 2, topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1]} : () -> () - func.return - } -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir deleted file mode 100644 index c0230b43d1db04..00000000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir +++ /dev/null @@ -1,194 +0,0 @@ -// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-outside-compiled-to-host-launch | FILECHECK_OPTS="" FileCheck %s - -module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { - - // Tests that TPU cluster with no outside compilation does not generate launch op. - - // CHECK-LABEL: func @no_outside_compilation - // CHECK-NOT: "tf_device.launch" - func.func @no_outside_compilation() -> tensor { - %0 = "tf_device.cluster"() ({ - %1 = "tf.A"() : () -> tensor - %2 = "tf.B"(%1) : (tensor) -> tensor - tf_device.return %2 : tensor - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor - func.return %0 : tensor - } - - - // Tests the launch wrap of a single outside compiled cluster with no input or output dependencies. - - // CHECK-LABEL: func @nodep_single_outside_compilation - func.func @nodep_single_outside_compilation() -> () { - // CHECK: "tf.A" - // CHECK: "tf_device.launch" - // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0" - // CHECK-NEXT: "tf.B" - // CHECK-NOT: _xla_outside_compilation - // CHECK-NEXT: tf_device.return - // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = "" - "tf_device.cluster"() ({ - "tf.A"() : () -> () - "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () - "tf.C"() : () -> () - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () - func.return - } - - // Tests the launch wrap of a single outside compiled cluster with data parallelism. 
- - // CHECK-LABEL: func @single_outside_compilation_with_replicate - func.func @single_outside_compilation_with_replicate(%arg0: tensor) -> () { - // CHECK: "tf.A" - // CHECK: tf_device.replicate - // CHECK-NEXT: "tf_device.cluster" - // CHECK-NEXT: "tf.B" - // CHECK-NEXT: "tf_device.launch" - // CHECK-SAME: device = "TPU_REPLICATED_HOST_0" - // CHECK-NEXT: "tf.C" - // CHECK-NOT: _xla_outside_compilation - // CHECK: tf_device.return - // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = "" - %0 = "tf.A"(%arg0) : (tensor) -> tensor - tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { - "tf_device.cluster"() ({ - "tf.B"() : () -> () - "tf.C"(%ri_0) {_xla_outside_compilation = "cluster1"} : (tensor) -> () - "tf.D"() : () -> () - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () - tf_device.return - } - func.return - } - - // Tests launch wrap of a single outside compiled cluster with input/output. - - // CHECK-LABEL: func @single_outside_compilation_input_output - func.func @single_outside_compilation_input_output(%arg0: tensor) -> tensor { - %0 = "tf.A"(%arg0) : (tensor) -> tensor - // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate - // CHECK: "tf_device.cluster" - // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - // CHECK-NEXT: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch" - // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) - // CHECK: tf_device.return %[[B_OUTPUT]] - // CHECK: "tf.C"(%[[LAUNCH_OUTPUT]]) - %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { - %2 = "tf_device.cluster"() ({ - %3 = "tf.A"() : () -> (tensor) - %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor - %5 = "tf.C"(%4) : (tensor) -> tensor - tf_device.return %5 : tensor - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor - tf_device.return %2 : tensor - } - - func.return %1 : tensor - } - - // Tests launch wrap of multiple outside compiled cluster with input/output. 
- - // CHECK-LABEL: func @multiple_outside_compilation_input_output - func.func @multiple_outside_compilation_input_output(%arg0: tensor) -> tensor { - %0 = "tf.A"(%arg0) : (tensor) -> tensor - // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate - // CHECK: "tf_device.cluster" - // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - // CHECK-NEXT: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch" - // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) - // CHECK: tf_device.return %[[B_OUTPUT]] - // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[LAUNCH_OUTPUT]]) - // CHECK-NEXT: %[[LAUNCH_OUTPUT2:[0-9]*]] = "tf_device.launch" - // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]]) - // CHECK: tf_device.return %[[D_OUTPUT]] - // CHECK: %[[LAUNCH_OUTPUT3:[0-9]*]] = "tf_device.launch" - // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[LAUNCH_OUTPUT2]]) - // CHECK: tf_device.return %[[E_OUTPUT]] - // CHECK: "tf.F"(%[[LAUNCH_OUTPUT3]]) - %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { - %2 = "tf_device.cluster"() ({ - %3 = "tf.A"() : () -> (tensor) - %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor - %5 = "tf.C"(%4) : (tensor) -> tensor - %6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor) -> tensor - %7 = "tf.E"(%6) {_xla_outside_compilation = "cluster2"} : (tensor) -> tensor - %8 = "tf.F"(%7) : (tensor) -> tensor - tf_device.return %8 : tensor - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor - tf_device.return %2 : tensor - } - - func.return %1 : tensor - } - - // Tests the launch wrap of an outside compiled op that's called from a tf_device.cluster. - - func.func @called_outside_compilation() -> () { - "tf_device.cluster"() ({ - "tf.PartitionedCall"() {f = @called_outside_compilation_callee} : () -> () - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () - func.return - } - // CHECK-LABEL: func @called_outside_compilation_callee - func.func @called_outside_compilation_callee() -> () { - // CHECK: "tf.A" - // CHECK: "tf_device.launch" - // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0" - // CHECK-NEXT: "tf.B" - // CHECK-NOT: _xla_outside_compilation - // CHECK-NEXT: tf_device.return - "tf.A"() : () -> () - "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () - "tf.C"() : () -> () - func.return - } - - // Test that the same outside compiled function cannot be called from two - // different TPU clusters. - - func.func @called_outside_compilation_bad() -> () { - "tf_device.cluster"() ({ - "tf.PartitionedCall"() {f = @called_outside_compilation_bad_callee} : () -> () - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () - "tf_device.cluster"() ({ - "tf.PartitionedCall"() {f = @called_outside_compilation_bad_callee} : () -> () - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () - func.return - } - // expected-error@+1 {{The same function is reachable from multiple TPU Clusters.}} - func.func @called_outside_compilation_bad_callee() -> () { - "tf.A"() : () -> () - "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () - "tf.C"() : () -> () - func.return - } -} - -// ----- - -// Tests that model parallelism does not affect outside compilation. 
- -module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { - // CHECK-LABEL: func @outside_compilation_model_parallelism - func.func @outside_compilation_model_parallelism() -> () { - // CHECK: "tf.A" - // CHECK: "tf_device.launch" - // CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0" - // CHECK-NEXT: "tf.B" - // CHECK-NOT: _xla_outside_compilation - // CHECK-NEXT: tf_device.return - // CHECK: num_cores_per_replica = 2 : i64 - %0 = "tf_device.cluster"() ({ - "tf.A"() : () -> () - "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () - "tf.C"() : () -> () - tf_device.return - }) {num_cores_per_replica = 2, topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1]} : () -> tensor<2xi32> - func.return - } -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir index c57e07b5e3f74e..7246cdb4513280 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -2964,3 +2964,83 @@ func.func @global_iter_id_effect() -> () { // expected-remark@above {{ID: 6}} // expected-remark@above {{Sinks: {}}} } + +// ----- + +func.func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // expected-remark@above {{ID: 2}} + %sum = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // expected-remark@above {{ID: 0}} + func.return %sum : tensor<1xf32> + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Sinks: {}}} +} + +func.func @intermediary(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // expected-remark@above {{ID: 2}} + %result = "tf.StatefulPartitionedCall"(%arg0, %arg1) {config="", config_proto="", executor_type="", f=@add} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // expected-remark@above {{ID: 0}} + func.return %result : tensor<1xf32> + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Sinks: {}}} +} + +// CHECK-LABEL: func @call_pure_function +func.func @call_pure_function(%arg0: tensor) -> tensor { + // expected-remark@above {{ID: 5}} + %one = "tf.Const"() { value = dense<1.0> : tensor<1xf32> } : () -> tensor<1xf32> + // expected-remark@above {{ID: 0}} + %r1 = "tf.ReadVariableOp"(%arg0) : (tensor) -> tensor<1xf32> + // expected-remark@above {{ID: 1}} + %two = "tf.StatefulPartitionedCall"(%one, %one) {config="", config_proto="", executor_type="", f=@intermediary} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // expected-remark@above {{ID: 2}} + %r2 = "tf.ReadVariableOp"(%arg0) : (tensor) -> tensor<1xf32> + // expected-remark@above {{ID: 3}} + func.return %arg0 : tensor + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Sinks: {1,3}}} +} + +// ----- + +func.func @assert(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor { + // expected-remark@above {{ID: 3}} + %cond = builtin.unrealized_conversion_cast to tensor + // expected-remark@above {{ID: 0}} + "tf.Assert"(%cond, %arg1) {device = "/job:localhost/replica:0/task:0/device:CPU:0", summarize = 3 : i64} : (tensor, tensor<1xf32>) -> () + // expected-remark@above {{ID: 1}} + func.return %cond : tensor + // expected-remark@above {{ID: 2}} + // 
expected-remark@above {{Sinks: {1}}} +} + +func.func @intermediary(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // expected-remark@above {{ID: 3}} + %cond = builtin.unrealized_conversion_cast to tensor + // expected-remark@above {{ID: 0}} + %sum = "tf.If"(%cond, %arg0, %arg1) { + then_branch = @assert, + else_branch = @assert, + is_stateless = false + } : (tensor, tensor<1xf32>, tensor<1xf32>) -> tensor + // expected-remark@-5 {{ID: 1}} + func.return %arg0 : tensor<1xf32> + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Sinks: {1}}} +} + +// CHECK-LABEL: func @assert_within_if +func.func @assert_within_if(%arg0: tensor) -> tensor { + // expected-remark@above {{ID: 5}} + %one = "tf.Const"() { value = dense<1.0> : tensor<1xf32> } : () -> tensor<1xf32> + // expected-remark@above {{ID: 0}} + %r1 = "tf.ReadVariableOp"(%arg0) : (tensor) -> tensor<1xf32> + // expected-remark@above {{ID: 1}} + %result = "tf.StatefulPartitionedCall"(%one, %one) {config="", config_proto="", executor_type="", f=@intermediary} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // expected-remark@above {{ID: 2}} + %r2 = "tf.ReadVariableOp"(%arg0) : (tensor) -> tensor<1xf32> + // expected-remark@above {{ID: 3}} + func.return %arg0 : tensor + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Sinks: {1,3}}} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir index 4333e79e0ee430..a7423b729dd287 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir @@ -32,12 +32,12 @@ func.func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{grad_a = 
false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> @@ -76,12 +76,12 @@ func.func @batchMatMulTwoDimAdjXY(%arg0: tensor<2x3x5x4xf32>, %arg1: tensor<2x3x // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x6x5xf32>, tensor<2xi64>) -> tensor<6x5xf32> // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x6x5xf32>, tensor<2xi64>) -> tensor<6x5xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], 
%[[RHS_6]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> @@ -109,9 +109,9 @@ func.func @batchMatMulOneDim(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> @@ -131,7 +131,7 @@ func.func @batchMatMulSingleBatch(%arg0: tensor<1x4x5xf32>, %arg1: tensor<1x5x6x // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%arg1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) <{axis = 0 : i64}> : (tensor<4x6xf32>) -> tensor<1x4x6xf32> // CHECK: return %[[MATMUL_PACKED]] : tensor<1x4x6xf32> @@ -152,9 +152,9 @@ func.func @batchMatMulUnbatchedLeft(%arg0: tensor<4x5xf32>, %arg1: tensor<3x5x6x // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : 
(tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> @@ -175,9 +175,9 @@ func.func @batchMatMulUnbatchedRight(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6 // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> @@ -190,7 +190,7 @@ func.func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> func.return %0 : tensor<4x6xf32> // CHECK-LABEL: batchMatMulMatrix - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> } @@ -201,7 +201,7 @@ func.func @batchMatMulMatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: 
tensor<6x5xf32> func.return %0 : tensor<4x6xf32> // CHECK-LABEL: batchMatMulMatrixAdjXY - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> } @@ -238,12 +238,12 @@ func.func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6 // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> @@ -282,12 +282,12 @@ func.func @batchMatMulV2TwoDimAdjXY(%arg0: 
tensor<2x3x5x4xf32>, %arg1: tensor<2x // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x6x5xf32>, tensor<2xi64>) -> tensor<6x5xf32> // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x6x5xf32>, tensor<2xi64>) -> tensor<6x5xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> @@ -319,12 +319,12 @@ func.func @batchMatMulV2Broadcast(%arg0: tensor<2x1x4x5xf32>, %arg1: tensor<1x3x // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{transpose_a = false, 
transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> @@ -352,9 +352,9 @@ func.func @batchMatMulV2OneDim(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32 // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{grad_a = false, grad_b = false, 
transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> @@ -374,7 +374,7 @@ func.func @batchMatMulV2SingleBatch(%arg0: tensor<1x4x5xf32>, %arg1: tensor<1x5x // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%arg1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) <{axis = 0 : i64}> : (tensor<4x6xf32>) -> tensor<1x4x6xf32> // CHECK: return %[[MATMUL_PACKED]] : tensor<1x4x6xf32> @@ -395,9 +395,9 @@ func.func @batchMatMulV2UnbatchedLeft(%arg0: tensor<4x5xf32>, %arg1: tensor<3x5x // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> @@ -418,9 +418,9 @@ func.func @batchMatMulV2UnbatchedRight(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5 // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: 
%[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> @@ -433,7 +433,7 @@ func.func @batchMatMulV2Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) - func.return %0 : tensor<4x6xf32> // CHECK-LABEL: batchMatMulV2Matrix - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> } @@ -444,7 +444,7 @@ func.func @batchMatMulV2MatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: tensor<6x5xf3 func.return %0 : tensor<4x6xf32> // CHECK-LABEL: batchMatMulV2MatrixAdjXY - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> } @@ -455,7 +455,7 @@ func.func @batchMatMulV2DynamicSize(%arg0: tensor, %arg1: tensor // CHECK-LABEL: batchMatMulV2DynamicSize - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> : (tensor, tensor) -> tensor + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor, tensor) -> tensor // CHECK: return %[[MATMUL_1]] : tensor } @@ -492,12 +492,12 @@ func.func @batchMatMulV3TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6 // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = false, transpose_b = 
false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> @@ -536,12 +536,12 @@ func.func @batchMatMulV3TwoDimAdjXY(%arg0: tensor<2x3x5x4xf32>, %arg1: tensor<2x // CHECK: %[[RHS_5:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#4, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x6x5xf32>, tensor<2xi64>) -> tensor<6x5xf32> // CHECK: %[[RHS_6:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#5, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x6x5xf32>, tensor<2xi64>) -> tensor<6x5xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> 
tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_4]], %[[RHS_4]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_5]], %[[RHS_5]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_6]], %[[RHS_6]]) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> @@ -573,12 +573,12 @@ func.func @batchMatMulV3Broadcast(%arg0: tensor<2x1x4x5xf32>, %arg1: tensor<1x3x // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_SPLIT]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : 
(tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_4:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_5:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_6:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]], %[[MATMUL_4]], %[[MATMUL_5]], %[[MATMUL_6]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> // CHECK: %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL_PACKED]], %[[RESULT_SHAPE]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> @@ -606,9 +606,9 @@ func.func @batchMatMulV3OneDim(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32 // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> @@ -628,7 +628,7 @@ func.func @batchMatMulV3SingleBatch(%arg0: tensor<1x4x5xf32>, %arg1: tensor<1x5x // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%arg1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]]) <{axis = 0 : i64}> : (tensor<4x6xf32>) -> tensor<1x4x6xf32> // CHECK: return %[[MATMUL_PACKED]] : 
tensor<1x4x6xf32> @@ -649,9 +649,9 @@ func.func @batchMatMulV3UnbatchedLeft(%arg0: tensor<4x5xf32>, %arg1: tensor<3x5x // CHECK: %[[RHS_2:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#1, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> // CHECK: %[[RHS_3:.*]] = "tf.Reshape"(%[[RHS_RESHAPED]]#2, %[[MATMUL_RHS_SHAPE]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[RHS_1]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg0, %[[RHS_2]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[RHS_3]]) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> @@ -672,9 +672,9 @@ func.func @batchMatMulV3UnbatchedRight(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5 // CHECK: %[[LHS_2:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#1, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> // CHECK: %[[LHS_3:.*]] = "tf.Reshape"(%[[LHS_SPLIT]]#2, %[[MATMUL_LHS_SHAPE]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%[[LHS_1]], %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%[[LHS_2]], %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%[[LHS_3]], %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: %[[MATMUL_PACKED:.*]] = "tf.Pack"(%[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]]) <{axis = 0 : i64}> : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> // CHECK: return %[[MATMUL_PACKED]] : tensor<3x4x6xf32> @@ -687,7 +687,7 @@ func.func @batchMatMulV3Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) - func.return %0 : tensor<4x6xf32> // CHECK-LABEL: 
batchMatMulV3Matrix - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> } @@ -698,7 +698,7 @@ func.func @batchMatMulV3MatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: tensor<6x5xf3 func.return %0 : tensor<4x6xf32> // CHECK-LABEL: batchMatMulV3MatrixAdjXY - // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = true, transpose_b = true}> : (tensor<5x4xf32>, tensor<6x5xf32>) -> tensor<4x6xf32> // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index 58c3338ad2264e..3114f0d9546a5f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -477,7 +477,6 @@ cc_library( "guarantee_all_funcs_one_use.cc", "hoist_loop_invariant.cc", "hoist_replicate_invariant_resource_writes.cc", - "host_launch_to_outside_compiled.cc", "init_text_file_to_import.cc", "launch_to_device_attribute.cc", "layout_optimization.cc", @@ -490,7 +489,6 @@ cc_library( "name_anonymous_iterators.cc", "optimize.cc", "order_by_dialect.cc", - "outside_compiled_to_host_launch.cc", "parallel_execute_to_islands.cc", "prepare_tpu_computation_for_tf_export.cc", "print.cc", @@ -566,9 +564,7 @@ cc_library( ":cluster_formation", ":decompose_resource_ops", ":decompose_resource_ops_inc_gen", - ":extract_outside_compilation", ":lower_tf_lib", - ":mark_ops_for_outside_compilation", ":shape_inference_pass", ":tensorflow_optimize_inc_gen", ":tf_data_optimization", @@ -578,7 +574,6 @@ cc_library( ":tfe_legalize_tfg", ":unroll_batch_matmul_pass", ":verify_no_outside_compilation_markers_pass", - ":xla_cluster_formation", "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite:validators", @@ -617,6 +612,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", "//tensorflow/compiler/mlir/tensorflow:xla_rewrite_util", "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", + "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:tpu_metadata_utils", "//tensorflow/compiler/mlir/tf2xla/internal/inference:inference_metrics_pass", "//tensorflow/compiler/mlir/tf2xla/transforms:legalization_op_config", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", @@ -705,138 +701,24 @@ cc_library( ], ) -cc_library( - name = "xla_cluster_formation", - srcs = ["xla_cluster_formation.cc"], - textual_hdrs = [ - "tf_passes.h.inc", - ], - deps = [ - ":tf_device_pass_inc_gen", - ":tf_pass_inc_gen", - ":verify_no_outside_compilation_markers_pass", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:attribute_utils", - "//tensorflow/compiler/mlir/tensorflow:call_graph_util", - "//tensorflow/compiler/mlir/tensorflow:cluster_util", - "//tensorflow/compiler/mlir/tensorflow:string_util", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - 
"//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core:portable_gif_internal", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - ], -) - -cc_library( - name = "extract_outside_compilation", - srcs = ["extract_outside_compilation.cc"], - textual_hdrs = [ - "tf_passes.h.inc", - ], - deps = [ - ":lower_tf_lib", - ":shape_inference_pass", - ":tf_pass_inc_gen", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:attribute_utils", - "//tensorflow/compiler/mlir/tensorflow:device_util", - "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", - "//tensorflow/compiler/mlir/tensorflow:string_util", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", - "//tensorflow/compiler/mlir/tf2xla/transforms:legalization_op_config", - "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Rewrite", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - ], -) - -cc_library( - name = "mark_ops_for_outside_compilation", - srcs = ["mark_ops_for_outside_compilation.cc"], - textual_hdrs = [ - "tf_passes.h.inc", - ], - deps = [ - ":lower_tf_lib", - ":tf_pass_inc_gen", - ":verify_no_outside_compilation_markers_pass", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:attribute_utils", - "//tensorflow/compiler/mlir/tensorflow:string_util", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", - "//tensorflow/compiler/mlir/tf2xla/transforms:legalization_op_config", - "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Rewrite", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - ], -) - cc_library( name = "bridge", srcs = ["bridge.cc"], hdrs = ["bridge.h"], deps = [ ":tensorflow_passes", - "//tensorflow/compiler/jit:flags_headers", - "//tensorflow/compiler/mlir/tensorflow:bridge_logger", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops", - "//tensorflow/compiler/mlir/tf2xla/api/v1:tf_dialect_to_executor", - "//tensorflow/compiler/mlir/tf2xla/api/v2:cluster_tf", - 
"//tensorflow/compiler/mlir/tf2xla/api/v2:device_type_proto_cc", - "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_dialect_to_executor", - "//tensorflow/compiler/mlir/tf2xla/internal:clustering_bridge_passes", "//tensorflow/compiler/mlir/tf2xla/internal:logging_hooks", - "//tensorflow/compiler/mlir/tf2xla/internal/inference:inference_metrics_pass", - "//tensorflow/core:framework", "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core/platform:error_payloads", - "//tensorflow/core/platform:stacktrace", "//tensorflow/core/platform:status", "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/platform:error_logging", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index a7f1037b312544..07f399e53d1a3d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -15,32 +15,16 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" -#include -#include -#include - +#include "absl/log/log.h" #include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.h" -#include "tensorflow/compiler/mlir/tf2xla/internal/inference/inference_passes.h" #include "tensorflow/compiler/mlir/tf2xla/internal/logging_hooks.h" -#include "tensorflow/core/framework/metrics.h" -#include "tensorflow/core/platform/error_payloads.h" -#include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/core_platform_payloads.pb.h" -#include "tensorflow/core/util/debug_data_dumper.h" -#include "tsl/platform/error_logging.h" namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index 4525bb7ce7db81..e403e4c6f7e960 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -87,13 +87,13 @@ def GatherToV2 : Pat< // with V1. 
def BatchMatMulToV2 : Pat< (TF_BatchMatMulOp:$src AnyStaticShapeTensor:$x, AnyStaticShapeTensor:$y, - $adj_x, $adj_y), - (TF_BatchMatMulV2Op:$dest $x, $y, $adj_x, $adj_y), + $adj_x, $adj_y, $grad_x, $grad_y), + (TF_BatchMatMulV2Op:$dest $x, $y, $adj_x, $adj_y, $grad_x, $grad_y), [], [(CopyAttrs $src, $dest)]>; def BatchMatMulToMatMul : Pat< - (TF_BatchMatMulOp:$src $x, $y, $adj_x, $adj_y), - (TF_MatMulOp:$dest $x, $y, $adj_x, $adj_y), + (TF_BatchMatMulOp:$src $x, $y, $adj_x, $adj_y, $grad_x, $grad_y), + (TF_MatMulOp:$dest $x, $y, $adj_x, $adj_y, $grad_x, $grad_y), [(IsRank2Tensor $x), (IsRank2Tensor $y)], [(CopyAttrs $src, $dest)]>; @@ -102,8 +102,8 @@ def BatchMatMulToMatMul : Pat< //===----------------------------------------------------------------------===// def BatchMatMulV2ToMatMul : Pat< - (TF_BatchMatMulV2Op:$src $x, $y, $adj_x, $adj_y), - (TF_MatMulOp:$dest $x, $y, $adj_x, $adj_y), + (TF_BatchMatMulV2Op:$src $x, $y, $adj_x, $adj_y, $grad_x, $grad_y), + (TF_MatMulOp:$dest $x, $y, $adj_x, $adj_y, $grad_x, $grad_y), [(IsRank2Tensor $x), (IsRank2Tensor $y)], [(CopyAttrs $src, $dest)]>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc index 8f902c1eff7a0a..51afea6d84671e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc @@ -361,15 +361,14 @@ IslandOp CreateNewIsland(const MergedIsland& merged_island, // Creates respective YieldOp for the new merged island. YieldOp CreateNewIslandYieldOp(IslandOp new_island, - llvm::ArrayRef results) { + llvm::MutableArrayRef results) { llvm::SmallVector yield_operands; yield_operands.reserve(results.size()); - for (auto ret_vals : llvm::zip(results, new_island.getOutputs())) { - const auto& old_result = std::get<0>(ret_vals); - + for (auto [old_result, new_island] : + llvm::zip(results, new_island.getOutputs())) { // Replace original island result with new island result. - old_result.island_result.replaceAllUsesWith(std::get<1>(ret_vals)); + old_result.island_result.replaceAllUsesWith(new_island); // Add associated inner op result to operands of the YieldOp. yield_operands.push_back(old_result.inner_op_result); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc index d88550d7920ab6..125cbbd6163c33 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc @@ -102,7 +102,9 @@ Value ConvertConditionToBoolean(Operation* op, Value cond) { return cond; OpBuilder builder(op); - return builder.create(op->getLoc(), cond); + Value to_bool = builder.create(op->getLoc(), cond); + CopyDeviceAndUnderscoredAttributes(op, to_bool.getDefiningOp()); + return to_bool; } // Transform a functional IfOp to a region based IfRegionOp. 
@@ -171,6 +173,48 @@ LogicalResult ConvertWhileOp(WhileOp while_op, bool allow_passthrough_args) { return success(); } +LogicalResult ConvertGeneratorDatasetOp(GeneratorDatasetOp generator_op) { + auto generator_region = + OpBuilder(generator_op) + .create( + generator_op.getLoc(), generator_op->getResultTypes(), + generator_op.getInitFuncOtherArgs(), + generator_op.getNextFuncOtherArgs(), + generator_op.getFinalizeFuncOtherArgs(), + generator_op.getOutputTypes(), generator_op.getOutputShapes(), + generator_op.getMetadata()); + CopyDeviceAndUnderscoredAttributes(generator_op, generator_region); + + func::FuncOp init_function = + SymbolTable::lookupNearestSymbolFrom( + generator_op, generator_op.getInitFunc()); + func::FuncOp next_function = + SymbolTable::lookupNearestSymbolFrom( + generator_op, generator_op.getNextFunc()); + func::FuncOp finalize_function = + SymbolTable::lookupNearestSymbolFrom( + generator_op, generator_op.getFinalizeFunc()); + + if (!init_function || !next_function || !finalize_function) { + return failure(); + } + + CreateCall(generator_op, init_function, generator_region.getInit(), + generator_region.getInitFuncOtherArgs(), + /*use_region_args=*/true, /*forward_block_args=*/false); + CreateCall(generator_op, next_function, generator_region.getNext(), + generator_region.getNextFuncOtherArgs(), + /*use_region_args=*/true, /*forward_block_args=*/false); + CreateCall(generator_op, finalize_function, generator_region.getFinalize(), + generator_region.getFinalizeFuncOtherArgs(), + /*use_region_args=*/true, /*forward_block_args=*/false); + + generator_op->replaceAllUsesWith(generator_region->getResults()); + generator_op->erase(); + + return success(); +} + void FunctionalControlFlowToRegions::runOnOperation() { ModuleOp module = getOperation(); auto result = module.walk([&](Operation* op) { @@ -189,6 +233,13 @@ void FunctionalControlFlowToRegions::runOnOperation() { op->emitOpError() << "failed to convert to region form"; return WalkResult::interrupt(); } + } else if (auto generator_op = llvm::dyn_cast(op)) { + if (allow_passthrough_args_) { + if (failed(ConvertGeneratorDatasetOp(generator_op))) { + op->emitOpError() << "failed to convert to region form"; + return WalkResult::interrupt(); + } + } } return WalkResult::advance(); }); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_launch_to_outside_compiled.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_launch_to_outside_compiled.cc deleted file mode 100644 index 1c4383326e7625..00000000000000 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_launch_to_outside_compiled.cc +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/FormatVariadic.h" -#include "mlir/Analysis/CallGraph.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" - -namespace mlir { -namespace TFDevice { - -namespace { - -constexpr char kDeviceAttr[] = "device"; -constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; - -#define GEN_PASS_DEF_HOSTLAUNCHTOOUTSIDECOMPILEDPASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.h.inc" - -struct HostLaunchToOutsideCompiledPass - : public impl::HostLaunchToOutsideCompiledPassBase< - HostLaunchToOutsideCompiledPass> { - void runOnOperation() override; -}; - -// Assign all ops in region with _xla_outside_compilation attribute. -void MarkOutsideCompiledInRegion(Region& region) { - region.walk([&](Operation* op) { - op->setAttr(kXlaOutsideCompilationAttr, - StringAttr::get(op->getContext(), "from_launch")); - }); -} - -void HoistOpsAndAnnotateWithOutsideCompilation(tf_device::LaunchOp launch) { - // Forward launch inner op results to launch op results. - launch.replaceAllUsesWith(launch.GetBody().getTerminator()->getOperands()); - - // For all inner ops, assign the launch device as a `device` attribute. - MarkOutsideCompiledInRegion(launch.getBody()); - - // Move all inner ops of the launch to the block containing the launch. - auto body = launch.GetBody().without_terminator(); - Operation* launch_op = launch.getOperation(); - launch_op->getBlock()->getOperations().splice( - launch_op->getIterator(), launch.GetBody().getOperations(), body.begin(), - body.end()); - - launch.erase(); -} - -void HostLaunchToOutsideCompiledPass::runOnOperation() { - auto traverse_op = [&](Operation* op, tf_device::ClusterOp tpu_cluster, - std::optional host_device) { - // Hoist launch. 
- if (tf_device::LaunchOp launch = dyn_cast(op)) { - StringAttr device_attr = launch->getAttrOfType(kDeviceAttr); - if (host_device && device_attr && - device_attr.getValue().equals(*host_device)) - HoistOpsAndAnnotateWithOutsideCompilation(launch); - } - return WalkResult::advance(); - }; - - ModuleOp module = getOperation(); - if (failed(TFTPU::WalkReachableFromTpuCluster(module, traverse_op))) - return signalPassFailure(); -} - -} // anonymous namespace - -std::unique_ptr> -CreateHostLaunchToOutsideCompiledPass() { - return std::make_unique(); -} - -} // namespace TFDevice -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD index 359dd5c4624712..aa5097f19dbd5d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD @@ -75,3 +75,60 @@ tf_cc_test( "@llvm-project//mlir:Pass", ], ) + +cc_library( + name = "tpu_metadata_utils", + srcs = [ + "tpu_metadata_utils.cc", + ], + hdrs = [ + "tpu_metadata_utils.h", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:convert_tensor", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla:xla_proto_cc", + ], +) + +tf_cc_test( + name = "tpu_metadata_utils_test", + srcs = ["tpu_metadata_utils_test.cc"], + data = [ + "testdata/basic_cluster.mlir", + "testdata/spmd.mlir", + ], + deps = [ + ":tpu_metadata_utils", + "//tensorflow/compiler/mlir:register_common_dialects", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/tf2xla:xla_op_registry", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core/platform:resource_loader", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc index ab9a56e3b6db8e..3e3e8db504f1da 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc @@ -50,7 +50,6 @@ namespace { using mlir::DialectRegistry; using mlir::MLIRContext; using mlir::ModuleOp; -using mlir::OpPassManager; using mlir::OwningOpRef; using mlir::func::FuncOp; using ::tensorflow::monitoring::testing::CellReader; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/testdata/spmd.mlir 
b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/testdata/spmd.mlir new file mode 100644 index 00000000000000..21e27e013832f3 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/testdata/spmd.mlir @@ -0,0 +1,9 @@ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:2", "/job:localhost/replica:0/task:0/device:TPU:3", "/job:localhost/replica:0/task:0/device:TPU:4", "/job:localhost/replica:0/task:0/device:TPU:5", "/job:localhost/replica:0/task:0/device:TPU:6", "/job:localhost/replica:0/task:0/device:TPU:7"]} { + func.func @main(%arg0: tensor<*xf32> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) { + "tf_device.cluster_func"(%arg0) <{func = @empty_func}> {_dynamic_arg_index = [], _replication_info = "cluster", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], host_compute_core = [], input_sharding_configuration = ["{devices=[2,1]0,1}"], num_cores_per_replica = 2 : i64, output_sharding_configuration = [""], padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor<*xf32>) -> (tensor<*xf32>) + func.return + } + func.func @empty_func(%arg0: tensor<*xf32>) -> tensor<*xf32> { + func.return %arg0 : tensor<*xf32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.cc new file mode 100644 index 00000000000000..767d5cf7f0cf8c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.cc @@ -0,0 +1,250 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.h" + +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" + +namespace mlir { +namespace TFTPU { +namespace { +constexpr char kStepMarkerLocationAttr[] = "step_marker_location"; +constexpr char kUseXlaSpmdAttr[] = "use_spmd_for_xla_partitioning"; + +constexpr char kBadStringArrayElementMsg[] = + "bad '{0}' attribute at index {1}, not a string"; +constexpr char kBadArrayElementMsg[] = + "bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}"; +constexpr char kBadArrayAttrLengthMsg[] = + "bad '{0}' attribute, expected array attribute of size {1}, got size {2}"; + +// Creates a missing attribute error message. +std::string CreateMissingAttributeMsg(llvm::StringRef attribute) { + return llvm::formatv("requires attribute '{0}'", attribute).str(); +} + +// Populates a TPUCompileMetadataProto with StepMarkerLocation from a +// `tf_device::ClusterFuncOp`. +LogicalResult SetMetadataProtoStepMarkerLocation( + tf_device::ClusterFuncOp op, + tensorflow::tpu::TPUCompileMetadataProto* metadata) { + auto step_marker_location = + op->getAttrOfType(kStepMarkerLocationAttr); + if (!step_marker_location) + return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr)); + + // Default to `STEP_MARK_AT_ENTRY` for step marker location if attribute is + // empty. + xla::DebugOptions::StepMarkerLocation location = + xla::DebugOptions::STEP_MARK_AT_ENTRY; + if (!step_marker_location.getValue().empty() && + !xla::DebugOptions::StepMarkerLocation_Parse( + std::string(step_marker_location.getValue()), &location)) + return op.emitOpError(llvm::formatv("bad '{0}' attribute with value '{1}'", + kStepMarkerLocationAttr, + step_marker_location.getValue())); + + metadata->set_step_marker_location(location); + + return success(); +} + +// Parses a xla::OpSharding from a string attribute. 
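A minimal sketch of the attribute-to-enum mapping that SetMetadataProtoStepMarkerLocation above relies on, using only the protobuf-generated parser; the function name ParseStepMarkerOrDefault is illustrative and not part of this change:

```cpp
#include <string>

#include "xla/xla.pb.h"

// Maps the `step_marker_location` string attribute to the proto enum.
// An empty string keeps the default, STEP_MARK_AT_ENTRY; an unparseable
// string is reported as an op error by the real caller (here we just fall
// back to the default).
xla::DebugOptions::StepMarkerLocation ParseStepMarkerOrDefault(
    const std::string& attr_value) {
  xla::DebugOptions::StepMarkerLocation location =
      xla::DebugOptions::STEP_MARK_AT_ENTRY;
  if (!attr_value.empty() &&
      !xla::DebugOptions::StepMarkerLocation_Parse(attr_value, &location)) {
    return xla::DebugOptions::STEP_MARK_AT_ENTRY;
  }
  return location;
}
```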
+LogicalResult SetOpSharding(Operation* op, Attribute attr, llvm::StringRef name, + int index, xla::OpSharding* sharding_ptr) { + auto sharding_attr = attr.dyn_cast(); + if (!sharding_attr) + return op->emitOpError( + llvm::formatv(kBadStringArrayElementMsg, name, index)); + if (tensorflow::DecodeShardingAttribute(sharding_attr, *sharding_ptr) + .failed()) { + return op->emitOpError(llvm::formatv(kBadArrayElementMsg, name, index, + sharding_attr.getValue(), + "xla::OpSharding")); + } + return success(); +} + +// Populates a TPUCompileMetadataProto with argument types and sharding from a +// `tf_device::ClusterFuncOp`. +LogicalResult SetMetadataProtoArgs( + tf_device::ClusterFuncOp op, + tensorflow::tpu::TPUCompileMetadataProto* metadata) { + auto input_shardings = + op->getAttrOfType(tensorflow::kInputShardingAttr); + if (!input_shardings) + return op.emitOpError( + CreateMissingAttributeMsg(tensorflow::kInputShardingAttr)); + + if (input_shardings.size() != op.getNumOperands()) + return op.emitOpError( + llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kInputShardingAttr, + op.getNumOperands(), input_shardings.size())); + + // Set args metadata in proto. + mlir::StringAttr replication_attr_name = mlir::StringAttr::get( + op.getContext(), "mhlo.is_same_data_across_replicas"); + + auto dynamic_arg_idx = op->getAttrOfType(TF::kDynamicArgIndexAttr); + llvm::SmallSet dynamic_arg_idx_set; + if (dynamic_arg_idx) { + for (auto idx : dynamic_arg_idx.getValue()) { + dynamic_arg_idx_set.insert(idx.dyn_cast().getInt()); + } + } + + for (auto operand_type_and_idx : llvm::enumerate(op.getOperandTypes())) { + Type operand_type = operand_type_and_idx.value(); + int index = operand_type_and_idx.index(); + tensorflow::tpu::TPUCompileMetadataProto::Arg* arg = metadata->add_args(); + tensorflow::DataType dtype; + tensorflow::Status status = + tensorflow::ConvertToDataType(operand_type, &dtype); + if (!status.ok()) + return op.emitOpError( + llvm::formatv("failed to determine operand type at index {0}: {1}", + index, status.message())); + + arg->set_dtype(dtype); + // TODO(lyandy): Support other arg kinds. + if (dtype == tensorflow::DT_RESOURCE) + arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::VARIABLE); + else + arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER); + + // Populate argument shapes. + *arg->mutable_shape() = tensorflow::TensorShapeProto(); + if (auto ranked_tensor_type = operand_type.dyn_cast()) { + tensorflow::TensorShapeProto shape_proto; + ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto); + *arg->mutable_shape() = std::move(shape_proto); + } else { + arg->mutable_shape()->set_unknown_rank(true); + } + + if (failed(SetOpSharding(op, input_shardings.getValue()[index], + tensorflow::kInputShardingAttr, index, + arg->mutable_sharding()))) + return failure(); + + // Populate set_is_same_data_across_replicas + // Note: this information is duplicated and can be removed from the proto + // and here once MLIR bridge phase 2 doesn't fallback to the old bridge. + auto attr = op.getFuncOp().getArgAttrOfType( + index, replication_attr_name); + arg->set_is_same_data_across_replicas(attr != nullptr && attr.getValue()); + + // Currently only support first dimension to be bounded dynamic. + arg->mutable_is_bounded_dynamic_dim()->Add( + dynamic_arg_idx_set.contains(index)); + } + + return success(); +} + +// Populates a TPUCompileMetadataProto with result sharding from a +// `tf_device::ClusterFuncOp`. 
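For concreteness, a sketch of the Arg entry that SetMetadataProtoArgs above is expected to produce for the single tensor<*xf32> operand of testdata/spmd.mlir, whose input sharding is "{devices=[2,1]0,1}". This mirrors the expectation in tpu_metadata_utils_test.cc further below; ExpectedSpmdArg is an illustrative helper name only:

```cpp
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "xla/xla_data.pb.h"

tensorflow::tpu::TPUCompileMetadataProto::Arg ExpectedSpmdArg() {
  tensorflow::tpu::TPUCompileMetadataProto::Arg arg;
  arg.set_dtype(tensorflow::DT_FLOAT);          // from ConvertToDataType
  arg.mutable_shape()->set_unknown_rank(true);  // operand type is unranked
  arg.set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER);
  // "{devices=[2,1]0,1}" decodes to a 2x1 tile assignment over cores 0 and 1.
  xla::OpSharding* sharding = arg.mutable_sharding();
  sharding->set_type(xla::OpSharding::OTHER);
  sharding->add_tile_assignment_dimensions(2);
  sharding->add_tile_assignment_dimensions(1);
  sharding->add_tile_assignment_devices(0);
  sharding->add_tile_assignment_devices(1);
  // No entry in _dynamic_arg_index, so the first dimension is not marked as
  // bounded dynamic.
  arg.add_is_bounded_dynamic_dim(false);
  return arg;
}
```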
+LogicalResult SetMetadataProtoRetvals( + tf_device::ClusterFuncOp op, + tensorflow::tpu::TPUCompileMetadataProto* metadata) { + auto output_shardings = + op->getAttrOfType(tensorflow::kOutputShardingAttr); + if (!output_shardings) + return op.emitOpError( + CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr)); + + if (output_shardings.size() != op.getNumResults()) + return op.emitOpError( + llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kOutputShardingAttr, + op.getNumResults(), output_shardings.size())); + + // Set retvals metadata in proto. + for (auto output_sharding_and_idx : llvm::enumerate(output_shardings)) + if (failed(SetOpSharding(op, output_sharding_and_idx.value(), + tensorflow::kOutputShardingAttr, + output_sharding_and_idx.index(), + metadata->add_retvals()->mutable_sharding()))) + return failure(); + + return success(); +} + +} // namespace + +// Populates a TPUCompileMetadataProto from attributes of a +// `tf_device::ClusterFuncOp`. If any necessary attributes are missing from the +// op, a failure will be returned. +// TODO(lyandy): Support session handle and guaranteed consts. +LogicalResult SetMetadataProtoFromClusterFuncOp( + tf_device::ClusterFuncOp op, int num_replicas, int num_cores_per_replica, + std::optional&& xla_device_assignment, + tensorflow::tpu::TPUCompileMetadataProto* metadata) { + if (auto options_attr = + op->getAttrOfType("tpu_compile_options_proto")) { + if (!metadata->mutable_compile_options()->ParseFromArray( + options_attr.data(), options_attr.size())) { + return failure(); + } + } + metadata->set_num_replicas(num_replicas); + metadata->set_num_cores_per_replica(num_cores_per_replica); + + if (failed(SetMetadataProtoStepMarkerLocation(op, metadata))) + return failure(); + + if (xla_device_assignment.has_value()) + *metadata->mutable_device_assignment() = + std::move(xla_device_assignment.value()); + auto use_spmd_attr = op->getAttrOfType(kUseXlaSpmdAttr); + if (!use_spmd_attr) + return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr)); + metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue()); + + if (failed(SetMetadataProtoArgs(op, metadata))) return failure(); + + return SetMetadataProtoRetvals(op, metadata); +} + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.h b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.h new file mode 100644 index 00000000000000..b58401eb6897d4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
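The simplest call pattern for the utility defined above, mirroring how tpu_metadata_utils_test.cc drives it; FillMetadata is an illustrative wrapper name, and the cluster_func_op is assumed to have been found by walking a parsed module:

```cpp
#include <optional>

#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"

mlir::LogicalResult FillMetadata(
    mlir::tf_device::ClusterFuncOp cluster_func_op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  // No explicit device assignment, so pass an empty optional.
  return mlir::TFTPU::SetMetadataProtoFromClusterFuncOp(
      cluster_func_op, /*num_replicas=*/1, /*num_cores_per_replica=*/1,
      /*xla_device_assignment=*/{}, metadata);
}
```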
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_HOST_RUNTIME_TPU_METADATA_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_HOST_RUNTIME_TPU_METADATA_UTILS_H_ + +#include + +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" + +namespace mlir { +namespace TFTPU { + +// Populates a TPUCompileMetadataProto from attributes of a +// `tf_device::ClusterFuncOp`. If any necessary attributes are missing from the +// op, a failure will be returned. +// TODO(lyandy): Support session handle and guaranteed consts. +LogicalResult SetMetadataProtoFromClusterFuncOp( + tf_device::ClusterFuncOp op, int num_replicas, int num_cores_per_replica, + std::optional&& xla_device_assignment, + tensorflow::tpu::TPUCompileMetadataProto* metadata); +} // namespace TFTPU +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_HOST_RUNTIME_TPU_METADATA_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils_test.cc new file mode 100644 index 00000000000000..50fd035ebb8153 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils_test.cc @@ -0,0 +1,182 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.h" + +#include +#include +#include + +#include +#include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/register_common_dialects.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/core/platform/resource_loader.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" + +namespace mlir { +namespace TFTPU { +namespace { + +using mlir::DialectRegistry; +using mlir::MLIRContext; +using mlir::ModuleOp; +using mlir::OwningOpRef; + +// TODO(b/229726259): Make EqualsProto available in OSS +class ProtoStringMatcher { + public: + explicit ProtoStringMatcher(const tsl::protobuf::Message& expected) + : expected_(expected.SerializeAsString()) {} + + template + bool MatchAndExplain(const Message& p, testing::MatchResultListener*) const { + return p.SerializeAsString() == expected_; + } + + void DescribeTo(::std::ostream* os) const { *os << expected_; } + void DescribeNegationTo(::std::ostream* os) const { + *os << "not equal to expected message: " << expected_; + } + + private: + const std::string expected_; +}; + +inline ::testing::PolymorphicMatcher EqualsProto( + const tsl::protobuf::Message& x) { + return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x)); +} + +std::string TestDataPath() { + return tensorflow::GetDataDependencyFilepath( + "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/testdata/"); +} + +class TpuMetadataUtilsTest : public ::testing::Test { + public: + TpuMetadataUtilsTest() { + mlir::RegisterCommonToolingDialects(registry_); + context_.appendDialectRegistry(registry_); + context_.loadAllAvailableDialects(); + } + + absl::StatusOr> GetClusterFuncOps( + absl::string_view mlir_module_filename) { + TF_RETURN_IF_ERROR(CreateMlirModule(mlir_module_filename)); + std::vector cluster_func_ops; + + mlir_module_->walk([&](mlir::tf_device::ClusterFuncOp op) { + cluster_func_ops.push_back(op); + }); + return cluster_func_ops; + } + + private: + absl::Status CreateMlirModule(absl::string_view mlir_module_filename) { + std::string mlir_module_path = + absl::StrCat(TestDataPath(), mlir_module_filename); + mlir_module_ = + mlir::parseSourceFile(mlir_module_path, &context_); + if (!mlir_module_) { + return absl::Status( + absl::StatusCode::kNotFound, + absl::StrCat("Could not find MLIR module at ", mlir_module_path)); + } + return absl::OkStatus(); + } + + DialectRegistry registry_; + MLIRContext context_; + OwningOpRef mlir_module_; +}; + +TEST_F(TpuMetadataUtilsTest, SingleDevice) { + TF_ASSERT_OK_AND_ASSIGN(auto cluster_func_ops, + GetClusterFuncOps("basic_cluster.mlir")); + mlir::tf_device::ClusterFuncOp cluster_func_op = 
cluster_func_ops.front(); + + tensorflow::tpu::TPUCompileMetadataProto compile_metadata; + + ASSERT_TRUE(mlir::succeeded(SetMetadataProtoFromClusterFuncOp( + cluster_func_op, + /*num_replicas=*/1, /*num_cores_per_replica=*/1, {}, &compile_metadata))); + + tensorflow::tpu::TPUCompileMetadataProto expected_compile_metadata; + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + R"pb( + num_replicas: 1 num_cores_per_replica: 1 + )pb", + &expected_compile_metadata)); + + EXPECT_THAT(compile_metadata, EqualsProto(expected_compile_metadata)); +} + +TEST_F(TpuMetadataUtilsTest, spmd) { + TF_ASSERT_OK_AND_ASSIGN(auto cluster_func_ops, + GetClusterFuncOps("spmd.mlir")); + mlir::tf_device::ClusterFuncOp cluster_func_op = cluster_func_ops.front(); + + tensorflow::tpu::TPUCompileMetadataProto compile_metadata; + + ASSERT_TRUE(mlir::succeeded(SetMetadataProtoFromClusterFuncOp( + cluster_func_op, + /*num_replicas=*/1, /*num_cores_per_replica=*/2, {}, &compile_metadata))); + + tensorflow::tpu::TPUCompileMetadataProto expected_compile_metadata; + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + R"pb( + args { + dtype: DT_FLOAT + shape { unknown_rank: true } + kind: PARAMETER + sharding { + type: OTHER + tile_assignment_dimensions: 2 + tile_assignment_dimensions: 1 + tile_assignment_devices: 0 + tile_assignment_devices: 1 + } + is_bounded_dynamic_dim: false + } + retvals { sharding {} } + num_replicas: 1 + num_cores_per_replica: 2 + use_spmd_for_xla_partitioning: true + compile_options {} + )pb", + &expected_compile_metadata)); + + EXPECT_THAT(compile_metadata, EqualsProto(expected_compile_metadata)); +} + +} // namespace +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/outside_compiled_to_host_launch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/outside_compiled_to_host_launch.cc deleted file mode 100644 index e710e76b03a3c5..00000000000000 --- a/tensorflow/compiler/mlir/tensorflow/transforms/outside_compiled_to_host_launch.cc +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
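The ProtoStringMatcher above compares serialized bytes, which the TODO notes is a stopgap until EqualsProto is available in OSS. If byte-for-byte serialization ever proves too brittle, protobuf's MessageDifferencer is one possible alternative; this is a sketch of that option, not something the test currently uses:

```cpp
#include "google/protobuf/util/message_differencer.h"
#include "tsl/platform/protobuf.h"

// Structural comparison instead of comparing serialized strings.
bool ProtosEquivalent(const tsl::protobuf::Message& a,
                      const tsl::protobuf::Message& b) {
  return tsl::protobuf::util::MessageDifferencer::Equals(a, b);
}
```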
-==============================================================================*/ - -#include "llvm/ADT/SmallVector.h" -#include "mlir/Analysis/CallGraph.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" - -namespace mlir { -namespace TFDevice { - -namespace { - -constexpr char kDeviceAttr[] = "device"; -constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; - -#define GEN_PASS_DEF_OUTSIDECOMPILEDTOHOSTLAUNCHPASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.h.inc" - -struct OutsideCompiledToHostLaunchPass - : public impl::OutsideCompiledToHostLaunchPassBase< - OutsideCompiledToHostLaunchPass> { - void runOnOperation() override; -}; - -void WrapOpInLaunch(Operation* host_op, llvm::StringRef host_device) { - OpBuilder builder(host_op); - - auto launch_op = builder.create( - host_op->getLoc(), builder.getStringAttr(host_device), - /*result_types=*/host_op->getResultTypes()); - host_op->replaceAllUsesWith(launch_op); - - launch_op.getBody().push_back(new Block); - builder.setInsertionPointToEnd(&launch_op.GetBody()); - auto* return_op = - builder - .create(host_op->getLoc(), host_op->getResults()) - .getOperation(); - MLIRContext* context = launch_op.getContext(); - host_op->removeAttr(StringAttr::get(context, kXlaOutsideCompilationAttr)); - host_op->removeAttr(StringAttr::get(context, kDeviceAttr)); - host_op->moveBefore(return_op); -} - -void OutsideCompiledToHostLaunchPass::runOnOperation() { - // traverse_op is applied to each op reachable from each tf_device::ClusterOp - // in the module returned by getOperation(). - auto traverse_op = [&](Operation* op, tf_device::ClusterOp tpu_cluster, - std::optional host_device) { - // Apply WrapOpInLaunch when the op has _xla_outside_compilation. - if (op->hasAttrOfType(kXlaOutsideCompilationAttr)) { - if (!host_device) { - tpu_cluster.emitOpError( - "outside compilation is not supported with model parallelism."); - return WalkResult::interrupt(); - } - WrapOpInLaunch(op, *host_device); - } - return WalkResult::advance(); - }; - if (failed(TFTPU::WalkReachableFromTpuCluster(getOperation(), traverse_op))) - return signalPassFailure(); -} - -} // anonymous namespace - -std::unique_ptr> -CreateOutsideCompiledToHostLaunchPass() { - return std::make_unique(); -} - -} // namespace TFDevice -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 12cf30eea9dfff..00bd6166c63521 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -458,15 +458,6 @@ std::unique_ptr> CreateParallelExecuteToIslandsPass( std::unique_ptr> CreateAnnotateParameterReplicationPass(); -// Creates a pass that marks unsupported ops in device cluster for outside -// compilation. -std::unique_ptr> -CreateMarkOpsForOutsideCompilationPass(); - -// Creates a pass that extract outside compilation (Host ops inside cevice -// cluster) ops to a separate parallel_execute region to run on CPU. -std::unique_ptr> CreateExtractOutsideCompilationPass(); - // Creates a pass that merges control flow with similar predicates. 
std::unique_ptr> CreateMergeControlFlowPass(); @@ -481,24 +472,11 @@ CreateDeviceAttributeToLaunchPass(); std::unique_ptr> CreateLaunchToDeviceAttributePass( bool legacy_graph_export = true); -// Creates a pass that extracts ops in tf_device.launch op with host device -// assignment and adds an `_xla_outside_compilation` attribute value. -std::unique_ptr> -CreateHostLaunchToOutsideCompiledPass(); - -// Creates a pass that wraps ops with the same `_xla_outside_compilation` -// attribute value in a tf_device.launch op with host device assignment. -std::unique_ptr> -CreateOutsideCompiledToHostLaunchPass(); - // Creates a pass to ensure that the `_xla_outside_compilation` and // tf_device.launch op no longer exist after Outside Compilation is complete. std::unique_ptr> CreateVerifyNoOutsideCompilationMarkersPass(); -// Create a pass that encapsulates StatefulPartitionedCallOp within a cluster. -std::unique_ptr> CreateXlaClusterFormationPass(); - // Create a pass that inlines the StatefulPartitionedCallOp op based in the // parent region. std::unique_ptr> CreateXlaInlineDeviceOpsPass(); @@ -677,7 +655,6 @@ enum MoveTransposeDirection { kBegin, kEnd }; #define GEN_PASS_DECL_LOCALIZEVARHANDLESPASS #define GEN_PASS_DECL_LOWERQUANTIZEDPASS #define GEN_PASS_DECL_MARKINPUTOUTPUTALIASESPASS -#define GEN_PASS_DECL_MARKOPSFOROUTSIDECOMPILATIONPASS #define GEN_PASS_DECL_MATERIALIZEPASSTHROUGHOP #define GEN_PASS_DECL_MERGECONTROLFLOWPASS #define GEN_PASS_DECL_MOVETRANSPOSESPASS @@ -706,7 +683,6 @@ enum MoveTransposeDirection { kBegin, kEnd }; #define GEN_PASS_DECL_TPUCOLOCATECOMPOSITERESOURCEOPSPASS #define GEN_PASS_DECL_TPUDEVICEPROPAGATIONPASS #define GEN_PASS_DECL_TPUDYNAMICLAYOUTPASS -#define GEN_PASS_DECL_TPUEXTRACTOUTSIDECOMPILATIONPASS #define GEN_PASS_DECL_TPUHOSTCOMPUTATIONEXPANSIONPASS #define GEN_PASS_DECL_TPUIDENTITYPRUNINGPASS #define GEN_PASS_DECL_TPUMERGEVARIABLESWITHEXECUTEPASS diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index c458b2c6cd8725..afad8871b399ba 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -1896,7 +1896,8 @@ bool ShapeInference::InferShapeForXlaSelectAndScatterOp( bool ShapeInference::InferShapeForXlaGatherOp(XlaGatherOp op) { xla::Shape input_shape = xla::TypeToShape(op.getOperand().getType()); - if (input_shape == xla::Shape()) return false; + if (input_shape == xla::Shape() || input_shape.is_unbounded_dynamic()) + return false; xla::Shape start_indices_shape = xla::TypeToShape(op.getStartIndices().getType()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td index 8bfda6dbb25c55..c89c909375df67 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td @@ -154,80 +154,6 @@ def DeviceAttributeToLaunchPass : Pass<"tf-device-attribute-to-launch", "mlir::f let constructor = "TFDevice::CreateDeviceAttributeToLaunchPass()"; } -def OutsideCompiledToHostLaunchPass : Pass<"tf-outside-compiled-to-host-launch", "ModuleOp"> { - let summary = "Wraps each op with the _xla_outside_compiled attribute in a separate tf_device.launch on replicated host device."; - - let description = [{ - This pass wraps ops with the same `_xla_outside_compilation` - attribute value in a tf_device.launch op with 
host device assignment. The - `_xla_outside_compilation` attribute is deleted from the wrapped ops. - - A simple example: - - ```mlir - "tf_device.cluster"() ( { - "tf.A"() - "tf.B"() {_xla_outside_compilation = "cluster1"} - "tf.C"() - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} - ``` - - Would become the following ops (unimportant attribute, type are omitted): - - ```mlir - "tf_device.cluster"() ( { - "tf.A"() - "tf_device.launch"() { - "tf.B"() // Note xla_outside_compilation attribute deleted. - tf_device.return - } {device = "TPU_REPLICATED_HOST_0"} : () -> () - "tf.C"() - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} - ``` - }]; - - let constructor = "TFDevice::CreateOutsideCompiledToHostLaunchPass()"; -} - -def HostLaunchToOutsideCompiledPass : Pass<"tf-device-host-launch-to-outside-compiled", "ModuleOp"> { - let summary = "Converts each op wrapped in launch op with host device assignnment to op with _xla_outside_compiled attribute."; - - let description = [{ - This pass takes ops wrapped in a tf_device.launch op with host device - assignment extracts them from launch and adds an `_xla_outside_compilation` - attribute. This is the inverse of OutsideCompiledToHostLaunchPass. - - A simple example: - - ```mlir - "tf_device.cluster"() ( { - "tf.A"() - "tf_device.launch"() { - "tf.B"() - tf_device.return - } {device = "TPU_REPLICATED_HOST_0"} : () -> () - "tf.C"() - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} - ``` - - Would become the following ops (unimportant attribute, type are omitted): - - ```mlir - "tf_device.cluster"() ( { - "tf.A"() - "tf.B"() {_xla_outside_compilation = "cluster1"} - "tf.C"() - tf_device.return - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} - ``` - }]; - - let constructor = "TFDevice::CreateHostLaunchToOutsideCompiledPass()"; -} - def VerifyNoOutsideCompilationMarkersPass : Pass<"verify-no-outside-compilation-markers", "mlir::func::FuncOp"> { let summary = "Verifies that after Outside Compilation passes complete, there are no more _xla_outside_compilation attributes and no tf_device.launch ops."; @@ -337,50 +263,6 @@ def LaunchToDeviceAttributePass : Pass<"tf-launch-to-device-attribute", "mlir::f let constructor = "TFDevice::CreateLaunchToDeviceAttributePass()"; } -def XlaClusterFormationPass : Pass<"tf-xla-cluster-formation", "ModuleOp"> { - let summary = "Encapsulate partitioned calls within a Cluster op"; - let description = [{ - This pass clusters `tf.PartitionedCall` and `tf.StatefulPartitionedCall` - with `_xla_compile_device_type` attribute into a `tf_device.cluster`. - Notice this pass will only rewrite the outermost call if there are nested - calls to avoid nested `tf.XlaLaunch` operations from being created later. 
- - For example, the following code - - ```mlir - func.func @main() -> tensor { - %0 = "tf.StatefulPartitionedCall"() {_xla_compile_device_type = "CPU", f = @stateful_pcall_func} : () -> (tensor) - func.return %0 : tensor - } - - func.func @stateful_pcall_func() -> tensor { - %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - func.return %0 : tensor - } - ``` - - will be transformed into, - - ```mlir - func.func @main() -> tensor { - %0 = "tf_device.cluster"() ({ - %1 = "tf.StatefulPartitionedCall"() {_xla_compile_device_type = "CPU", f = @stateful_pcall_func} : () -> tensor - tf_device.return %1 : tensor - }) : () -> tensor - func.return %0 : tensor - } - - func.func @stateful_pcall_func() -> tensor { - %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - func.return %0 : tensor - } - - ``` - }]; - let constructor = "TFDevice::CreateXlaClusterFormationPass()"; - let dependentDialects = ["tf_device::TensorFlowDeviceDialect"]; -} - def XlaInlineDeviceOpsPass : Pass<"tf-xla-inline-device-ops", "ModuleOp"> { let summary = "Inline all Cluster op based in the parent region"; let constructor = "TFDevice::CreateXlaInlineDeviceOpsPass()"; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td index 782232cb3038f7..b8fa543318778c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td @@ -1226,62 +1226,6 @@ def TPUResourceReadForWritePass : Pass<"tf-tpu-resource-read-for-write", "Module let constructor = "TFTPU::CreateTPUResourceReadForWritePass()"; } -def ExtractOutsideCompilationPass : Pass<"tf-extract-outside-compilation", "ModuleOp"> { - let summary = "Extracts device outside compilation computation to a separate tf_device.parallel_execute region."; - - let description = [{ - This pass extracts a CPU computation cluster with `_xla_outside_compilation` - annotation, which denotes ops that should be run on CPU/host, from a device cluster. - Each outside compilation cluster is moved to - a tf_device.parallel_execute region. The device cluster is also moved to a - tf_device.parallel_execute region. Communication ops between device and host are - added to pass inputs/outputs to/from the outside compiled region. 
- - For example, the following tf_device.cluster with an op marked for `xla_outside_compilation`: - - ```mlir - func @outside_compilation() -> tensor { - %0 = "tf_device.cluster"() ( { - %1 = "tf.Const"() {_xla_outside_compilation = "0", value = dense<1.0> : tensor} : () -> (tensor) - %2 = "tf.Identity"(%1) {_xla_outside_compilation = "0"} : (tensor) -> (tensor) - %3 = "tf.AddV2"(%1, %2) : (tensor, tensor) -> (tensor) - tf_device.return %3 : tensor - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor - return %0 : tensor - } - ``` - - will become a tf_device.parallel_execute op with a CPU/host region and - a tf_device.cluster with communication ops to send data to/from device/host: - - ```mlir - func @outside_compilation() -> tensor { - %0 = "tf_device.parallel_execute"() ( { - "tf_device.launch"() ( { - %1 = "tf._XlaCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf_type.string> - %2 = "tf._XlaRecvAtHost"(%1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_args"} : (tensor<3x!tf_type.string>) -> tensor - %3 = "tf.Identity"(%2) : (tensor) -> tensor - "tf._XlaSendFromHost"(%3, %1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_retvals"} : (tensor, tensor<3x!tf_type.string>) -> () - tf_device.return - }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> () - tf_device.return - }, { - %1 = "tf_device.cluster"() ( { - %2 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor - %3 = "tf._XlaHostComputeMlir"(%2) {recv_key = "host_compute_channel_0_0_retvals", send_key = "host_compute_channel_0_0_args", tpu_core = 0 : i64} : (tensor) -> tensor - %4 = "tf.AddV2"(%2, %3) : (tensor, tensor) -> tensor - tf_device.return %4 : tensor - }) {device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor - tf_device.return %1 : tensor - }) : () -> tensor - return %0 : tensor - } - ``` - }]; - - let constructor = "TFDevice::CreateExtractOutsideCompilationPass()"; -} - def HoistReplicateInvariantResourceWritesPass : Pass<"tf-hoist-replicate-invariant-resource-writes", "mlir::func::FuncOp"> { let summary = "Hoists writes to replicate invariant resource variables."; @@ -1301,53 +1245,6 @@ def HoistReplicateInvariantResourceWritesPass : Pass<"tf-hoist-replicate-invaria let constructor = "TF::CreateHoistReplicateInvariantResourceWritesPass()"; } -def MarkOpsForOutsideCompilationPass : Pass<"tf-mark-ops-for-outside-compilation", "ModuleOp"> { - let summary = "Marks ops in device cluster for outside compilation if they are unsupported on device."; - - let description = [{ - This pass marks unsupported ops in a device cluster with - `_xla_outside_compilation` attribute so the operations will run on the host - instead of the device. Unsupported ops are ops that can not be code - generated to run on the device for the cluster including: - - 1. String operations on TPUs. - 2. Operations that don't have a kernel defined for the device. - - This pass is conservative in that it will mark all ops for outside compilation - that can not be compiled for the device. Exceptions for this are added for ops - that will be rewritten or decomposed before compiling on device. 
- - - For example, tf_device.cluster op with an unsupported op, tf.UnsupportedOp: - - ```mlir - func @unsupported_op() -> tensor { - %0 = "tf_device.cluster"() ( { - %1 = "tf.UnsupportedOp"() : () -> tensor - %2 = "tf.Identity"(%1) : (tensor) -> tensor - tf_device.return %2 : tensor - }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor - return %0 : tensor - } - ``` - - will mark tf.UnsupportedOp with `_xla_outside_compilation` attribute: - - ```mlir - func @unsupported_op() -> tensor { - %0 = "tf_device.cluster"() ( { - %1 = "tf.UnsupportedOp"() {_xla_outside_compilation = "auto0"} : () -> tensor - %2 = "tf.Identity"(%1) : (tensor) -> tensor - tf_device.return %2 : tensor - }) {allow_soft_placement = true, device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor - return %0 : tensor - } - ``` - }]; - - let constructor = "TFDevice::CreateMarkOpsForOutsideCompilationPass()"; -} - def FunctionalControlFlowToRegionsPass : Pass<"tf-functional-control-flow-to-regions", "ModuleOp"> { let summary = "Transforms functional control flow operations to their region-based counterparts"; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index f27e1f62f074fe..62afa2b10ed67b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" @@ -158,186 +159,6 @@ LogicalResult EncapsulateFuncAndSerialize(const std::string& module_name, return success(); } -// Populates a TPUCompileMetadataProto with StepMarkerLocation from a -// `tf_device::ClusterFuncOp`. -LogicalResult SetMetadataProtoStepMarkerLocation( - tf_device::ClusterFuncOp op, - tensorflow::tpu::TPUCompileMetadataProto* metadata) { - auto step_marker_location = - op->getAttrOfType(kStepMarkerLocationAttr); - if (!step_marker_location) - return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr)); - - // Default to `STEP_MARK_AT_ENTRY` for step marker location if attribute is - // empty. - xla::DebugOptions::StepMarkerLocation location = - xla::DebugOptions::STEP_MARK_AT_ENTRY; - if (!step_marker_location.getValue().empty() && - !xla::DebugOptions::StepMarkerLocation_Parse( - std::string(step_marker_location.getValue()), &location)) - return op.emitOpError(llvm::formatv("bad '{0}' attribute with value '{1}'", - kStepMarkerLocationAttr, - step_marker_location.getValue())); - - metadata->set_step_marker_location(location); - - return success(); -} - -// Parses a xla::OpSharding from a string attribute. 
-LogicalResult SetOpSharding(Operation* op, Attribute attr, llvm::StringRef name, - int index, xla::OpSharding* sharding_ptr) { - auto sharding_attr = attr.dyn_cast(); - if (!sharding_attr) - return op->emitOpError( - llvm::formatv(kBadStringArrayElementMsg, name, index)); - if (tensorflow::DecodeShardingAttribute(sharding_attr, *sharding_ptr) - .failed()) { - return op->emitOpError(llvm::formatv(kBadArrayElementMsg, name, index, - sharding_attr.getValue(), - "xla::OpSharding")); - } - return success(); -} - -// Populates a TPUCompileMetadataProto with argument types and sharding from a -// `tf_device::ClusterFuncOp`. -LogicalResult SetMetadataProtoArgs( - tf_device::ClusterFuncOp op, - tensorflow::tpu::TPUCompileMetadataProto* metadata) { - auto input_shardings = - op->getAttrOfType(tensorflow::kInputShardingAttr); - if (!input_shardings) - return op.emitOpError( - CreateMissingAttributeMsg(tensorflow::kInputShardingAttr)); - - if (input_shardings.size() != op.getNumOperands()) - return op.emitOpError( - llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kInputShardingAttr, - op.getNumOperands(), input_shardings.size())); - - // Set args metadata in proto. - mlir::StringAttr replication_attr_name = mlir::StringAttr::get( - op.getContext(), "mhlo.is_same_data_across_replicas"); - - auto dynamic_arg_idx = op->getAttrOfType(TF::kDynamicArgIndexAttr); - llvm::SmallSet dynamic_arg_idx_set; - if (dynamic_arg_idx) { - for (auto idx : dynamic_arg_idx.getValue()) { - dynamic_arg_idx_set.insert(idx.dyn_cast().getInt()); - } - } - - for (auto operand_type_and_idx : llvm::enumerate(op.getOperandTypes())) { - Type operand_type = operand_type_and_idx.value(); - int index = operand_type_and_idx.index(); - tensorflow::tpu::TPUCompileMetadataProto::Arg* arg = metadata->add_args(); - tensorflow::DataType dtype; - tensorflow::Status status = - tensorflow::ConvertToDataType(operand_type, &dtype); - if (!status.ok()) - return op.emitOpError( - llvm::formatv("failed to determine operand type at index {0}: {1}", - index, status.message())); - - arg->set_dtype(dtype); - // TODO(lyandy): Support other arg kinds. - if (dtype == tensorflow::DT_RESOURCE) - arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::VARIABLE); - else - arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER); - - // Populate argument shapes. - *arg->mutable_shape() = tensorflow::TensorShapeProto(); - if (auto ranked_tensor_type = operand_type.dyn_cast()) { - tensorflow::TensorShapeProto shape_proto; - ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto); - *arg->mutable_shape() = std::move(shape_proto); - } else { - arg->mutable_shape()->set_unknown_rank(true); - } - - if (failed(SetOpSharding(op, input_shardings.getValue()[index], - tensorflow::kInputShardingAttr, index, - arg->mutable_sharding()))) - return failure(); - - // Populate set_is_same_data_across_replicas - // Note: this information is duplicated and can be removed from the proto - // and here once MLIR bridge phase 2 doesn't fallback to the old bridge. - auto attr = op.getFuncOp().getArgAttrOfType( - index, replication_attr_name); - arg->set_is_same_data_across_replicas(attr != nullptr && attr.getValue()); - - // Currently only support first dimension to be bounded dynamic. - arg->mutable_is_bounded_dynamic_dim()->Add( - dynamic_arg_idx_set.contains(index)); - } - - return success(); -} - -// Populates a TPUCompileMetadataProto with result sharding from a -// `tf_device::ClusterFuncOp`. 
-LogicalResult SetMetadataProtoRetvals( - tf_device::ClusterFuncOp op, - tensorflow::tpu::TPUCompileMetadataProto* metadata) { - auto output_shardings = - op->getAttrOfType(tensorflow::kOutputShardingAttr); - if (!output_shardings) - return op.emitOpError( - CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr)); - - if (output_shardings.size() != op.getNumResults()) - return op.emitOpError( - llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kOutputShardingAttr, - op.getNumResults(), output_shardings.size())); - - // Set retvals metadata in proto. - for (auto output_sharding_and_idx : llvm::enumerate(output_shardings)) - if (failed(SetOpSharding(op, output_sharding_and_idx.value(), - tensorflow::kOutputShardingAttr, - output_sharding_and_idx.index(), - metadata->add_retvals()->mutable_sharding()))) - return failure(); - - return success(); -} - -// Populates a TPUCompileMetadataProto from attributes of a -// `tf_device::ClusterFuncOp`. If any necessary attributes are missing from the -// op, a failure will be returned. -// TODO(lyandy): Support session handle and guaranteed consts. -LogicalResult SetMetadataProtoFromClusterFuncOp( - tf_device::ClusterFuncOp op, int num_replicas, int num_cores_per_replica, - std::optional&& xla_device_assignment, - tensorflow::tpu::TPUCompileMetadataProto* metadata) { - if (auto options_attr = - op->getAttrOfType("tpu_compile_options_proto")) { - if (!metadata->mutable_compile_options()->ParseFromArray( - options_attr.data(), options_attr.size())) { - return failure(); - } - } - metadata->set_num_replicas(num_replicas); - metadata->set_num_cores_per_replica(num_cores_per_replica); - - if (failed(SetMetadataProtoStepMarkerLocation(op, metadata))) - return failure(); - - if (xla_device_assignment.has_value()) - *metadata->mutable_device_assignment() = - std::move(xla_device_assignment.value()); - auto use_spmd_attr = op->getAttrOfType(kUseXlaSpmdAttr); - if (!use_spmd_attr) - return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr)); - metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue()); - - if (failed(SetMetadataProtoArgs(op, metadata))) return failure(); - - return SetMetadataProtoRetvals(op, metadata); -} - // Create a `tf._TPUCompileMlir` that contains a MLIR module that is // functionally equivalent to the function referenced by cluster_func. Operation* BuildCompileOp( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc index 42b0516c7cb038..c044eff3e15c32 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc @@ -100,8 +100,10 @@ FailureOr RenameStablehloFunctions( MLIRContext *context, SymbolTableCollection &symbol_tables, ModuleOp tf_module, ModuleOp stablehlo_module) { SymbolTable &tf_symbol_table = symbol_tables.getSymbolTable(tf_module); - SymbolTable &stablehlo_symbol_table = - symbol_tables.getSymbolTable(stablehlo_module); + // `stablehlo_module` is deleted right after the deserialization, so no need + // to store its `SymbolTable` to `SymbolTableCollection`. 
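The distinction the new comment draws is that SymbolTableCollection::getSymbolTable builds and caches a SymbolTable keyed by the operation, which only pays off for modules that are looked up repeatedly, while a locally constructed SymbolTable is simply dropped with its scope. A small sketch of the uncached form; LookupOnce is a hypothetical helper, not part of this change:

```cpp
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"            // from @llvm-project
#include "mlir/IR/SymbolTable.h"           // from @llvm-project

// Looks up a function by name without caching the table anywhere; suitable
// when the module is about to be discarded.
mlir::func::FuncOp LookupOnce(mlir::ModuleOp module, llvm::StringRef name) {
  mlir::SymbolTable table(module);  // built locally, dropped at scope exit
  return table.lookup<mlir::func::FuncOp>(name);
}
```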
+ SymbolTable stablehlo_symbol_table(stablehlo_module); + Builder builder(context); StringAttr main_func_name; for (auto func : stablehlo_module.getOps()) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite_v2.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite_v2.cc index 830dd1cb124705..f8752e316233dd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite_v2.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite_v2.cc @@ -251,10 +251,8 @@ mlir::LogicalResult RemapOutputsFromLogicalDevices( mlir::tf_device::ParallelExecuteOp old_parallel_execute, int cluster_idx, mlir::tf_device::ParallelExecuteOp new_parallel_execute, mlir::OpBuilder* builder) { - for (const auto& result_and_index : + for (auto [output_index, old_parallel_execute_output] : llvm::enumerate(old_parallel_execute.getResults())) { - const auto output_index = result_and_index.index(); - const auto old_parallel_execute_output = result_and_index.value(); const auto output_from_logical_device = new_parallel_execute.GetRegionOutputs(cluster_idx)[output_index]; old_parallel_execute_output.replaceAllUsesWith(output_from_logical_device); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/BUILD b/tensorflow/compiler/mlir/tensorflow/translate/BUILD new file mode 100644 index 00000000000000..46af8590c8108e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/translate/BUILD @@ -0,0 +1,339 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +cc_library( + name = "import_model", + srcs = [ + "import_model.cc", + ], + hdrs = [ + "export_graphdef.h", + "import_model.h", + ], + deps = [ + ":mlir_roundtrip_flags", + ":upgrade_graph", + "//tensorflow/cc/saved_model:bundle_v2", + "//tensorflow/cc/saved_model:constants", + "//tensorflow/cc/saved_model:loader_lite", + "//tensorflow/cc/saved_model:loader_util", + "//tensorflow/compiler/jit:shape_inference_helpers", + "//tensorflow/compiler/mlir:op_or_arg_name_mapper", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:convert_attr", + "//tensorflow/compiler/mlir/tensorflow:convert_tensor", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:mangling_util", + "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:translate_utils", + "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", + "//tensorflow/compiler/mlir/tensorflow/transforms:initialize_variables_in_session_init", + "//tensorflow/compiler/mlir/tensorflow/transforms:lift_variables_lib", + "//tensorflow/compiler/mlir/tensorflow/transforms:mark_initialized_variables_lib", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + 
"//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler/utils:transitive_fanin", + "//tensorflow/core/platform:crash_analysis", + "//tensorflow/core/platform:types", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@local_xla//xla:status_macros", + "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/service:hlo_parser", + ], +) + +tf_cc_test( + name = "tf_mlir_translate_registration_test", + size = "small", + srcs = ["tf_mlir_translate_registration_test.cc"], + deps = [ + ":translate_registration", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:TranslateLib", + ], +) + +cc_library( + name = "export_tf_dialect_op", + srcs = [ + "export_tf_dialect_op.cc", + ], + hdrs = [ + "export_tf_dialect_op.h", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:export_utils", + "//tensorflow/compiler/mlir/utils:string_container_utils", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:IR", + "@local_xla//xla:status_macros", + ], +) + +cc_library( + name = "translate_tf_dialect_op", + srcs = ["translate_tf_dialect_op.cc"], + deps = [ + ":export_tf_dialect_op", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TranslateLib", + ], + alwayslink = 1, +) + +cc_library( + name = "mlir_roundtrip_pass", + srcs = ["mlir_roundtrip_pass.cc"], + hdrs = ["mlir_roundtrip_pass.h"], + deps = [ + ":export_graphdef", + ":import_model", + ":mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@local_xla//xla:status_macros", + ], +) + +cc_library( + name = "mlir_roundtrip_pass_registration", + srcs = ["mlir_roundtrip_pass_registration.cc"], + deps = [ + ":mlir_roundtrip_pass", + ], + alwayslink = 1, +) + +cc_library( + name = "mlir_roundtrip_flags", + srcs = ["mlir_roundtrip_flags.cc"], + hdrs = ["mlir_roundtrip_flags.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:types", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@local_xla//xla:status_macros", + ], +) + +cc_library( + name = "mlir_import_options", + hdrs = 
["mlir_import_options.h"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "translate_lib", + srcs = ["tf_mlir_translate.cc"], + hdrs = ["tf_mlir_translate.h"], + visibility = ["//visibility:public"], + deps = [ + ":import_model", + ":mlir_roundtrip_flags", + "//tensorflow/cc/saved_model:bundle_v2", + "//tensorflow/cc/saved_model:loader_lite", + "//tensorflow/cc/saved_model:reader", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:import_utils", + "//tensorflow/compiler/mlir/tensorflow:mangling_util", + "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler/utils:transitive_fanin", + "//tensorflow/core/util/tensor_bundle:byteswaptensor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + ], +) + +cc_library( + name = "translate_cl_options", + srcs = [ + "tf_mlir_translate_cl.cc", + ], + hdrs = [ + "tf_mlir_translate_cl.h", + ], + deps = [ + "@llvm-project//llvm:Support", + ], + alwayslink = 1, +) + +cc_library( + name = "export_graphdef", + srcs = [ + "export_graphdef.cc", + ], + hdrs = [ + "export_graphdef.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":export_tf_dialect_op", + ":mlir_roundtrip_flags", + "//tensorflow/compiler/mlir:op_or_arg_name_mapper", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:export_utils", + "//tensorflow/compiler/mlir/tensorflow:translate_utils", + "//tensorflow/compiler/mlir/tensorflow:verify_suitable_for_graph_export", + "//tensorflow/compiler/mlir/utils:name_utils", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/graph/regularization:util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@local_xla//xla:status_macros", + ], +) + +cc_library( + name = "translate_registration", + srcs = [ + "tf_mlir_translate_registration.cc", + ], + deps = [ + ":export_graphdef", + ":mlir_roundtrip_flags", + ":translate_cl_options", + ":translate_lib", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TranslateLib", + "@local_xla//xla/client:client_library", + "@local_xla//xla/client:compile_only_client", + "@local_xla//xla/service/cpu:cpu_compiler", + 
"@local_xla//xla/service/cpu:cpu_transfer_manager", + "@local_xla//xla/stream_executor", + "@local_xla//xla/stream_executor/host:host_platform", + "@local_xla//xla/stream_executor/host:host_platform_id", + ], + alwayslink = 1, +) + +cc_library( + name = "split_into_island_per_op_pass", + srcs = ["split_into_island_per_op_pass.cc"], + hdrs = [ + "split_into_island_per_op_pass.h", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_executor_inc_gen", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "upgrade_graph", + srcs = ["upgrade_graph.cc"], + hdrs = ["upgrade_graph.h"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/tf2xla:functionalize_control_flow", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core/common_runtime:device", + "//tensorflow/core/common_runtime:device_factory", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:grappler_item_builder", + "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/optimizers:meta_optimizer", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@llvm-project//llvm:Support", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc index 260caf3494be9c..4a19c06154b6d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc @@ -102,7 +102,7 @@ mlir::LogicalResult EvaluateOperation( RETURN_FAILURE_IF_ERROR(status); } - VLOG(1) << "Start to evaluate node: " << node_def->DebugString(); + VLOG(1) << "Start to evaluate node: " << *node_def; // Adds inputs to the TF operation. for (const auto operand : operands) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc new file mode 100644 index 00000000000000..7a6da9fcbd04d2 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc @@ -0,0 +1,63 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h" + +#include <string> + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { + +std::string GetDeviceAttrAsResourceInstanceStr(mlir::Operation* op) { + auto device_attr = op->getAttrOfType<StringAttr>("device"); + // Treat missing device attribute like unspecified (= empty string) attribute. + // Note that different op instances with the same string (including empty + // string) are seen as dependent (same resource instance). + if (!device_attr) return ""; + return device_attr.str(); +} + +void MarkResourceAsReadAndWrite( + Value value, + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>& + effects) { + if (value.getType().cast<TensorType>().getElementType().isa<ResourceType>()) { + effects.emplace_back(MemoryEffects::Read::get(), value, + ResourceEffects::Variable::get()); + effects.emplace_back(MemoryEffects::Write::get(), value, + ResourceEffects::Variable::get()); + } +} + +void MarkResourceAsReadOnly( + Value value, + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>& + effects) { + if (value.getType().cast<TensorType>().getElementType().isa<ResourceType>()) { + effects.emplace_back(MemoryEffects::Read::get(), value, + ResourceEffects::Variable::get()); + } +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h new file mode 100644 index 00000000000000..c55ad530f15962 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h @@ -0,0 +1,44 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SIDE_EFFECT_ANALYSIS_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SIDE_EFFECT_ANALYSIS_UTIL_H_ + +#include <string> + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { + +std::string GetDeviceAttrAsResourceInstanceStr(Operation* op); + +void MarkResourceAsReadAndWrite( + Value value, + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>& + effect); + +void MarkResourceAsReadOnly( + Value value, + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>& + effect); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SIDE_EFFECT_ANALYSIS_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index 208311345c3f8e..c6ff5f5c93c6ef 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -189,10 +190,22 @@ std::string GetTPUCompilationDevice(ParsedDevice system_device) { } // Find the host CPU device for a given TPU device with `DEVICE_CPU` as its -// type and `id` 0. -std::string GetCPUHostDeviceForTPUDevice(ParsedDevice tpu_device) { +// type. If multiple local CPU devices are not enabled, always assign id 0. +// Otherwise, use the same id as the TPU device. +StatusOr<std::string> GetCPUHostDeviceForTPUDevice(ParsedDevice tpu_device, + ParsedDevices devices) { tpu_device.type = DEVICE_CPU; - tpu_device.id = 0; + bool enable_multiple_local_cpu_devices = + tensorflow::GetMlirCommonFlags() + ->tf_mlir_enable_multiple_local_cpu_devices; + if (!enable_multiple_local_cpu_devices) { + tpu_device.id = 0; + } + if (FindMatchingDevices(devices, tpu_device).empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Can't find device: ", DeviceNameUtils::ParsedNameToString(tpu_device), + " in the devices list.")); + } return DeviceNameUtils::ParsedNameToString(tpu_device); } @@ -203,7 +216,8 @@ std::string GetCPUHostDeviceForTPUDevice(ParsedDevice tpu_device) { // number of TPU devices available, and `num_cores_per_replica` must be 1.
StatusOr GetFullMeshTPUExecutionDeviceAssignment( int num_replicas, int num_cores_per_replica, - llvm::ArrayRef> tpu_devices) { + llvm::ArrayRef> tpu_devices, + ParsedDevices devices) { const int num_tasks = tpu_devices.size(); const int num_tpus_per_task = tpu_devices[0].size(); const int num_tpu_devices = num_tasks * num_tpus_per_task; @@ -226,7 +240,7 @@ StatusOr GetFullMeshTPUExecutionDeviceAssignment( const auto& tpu_device = tpu_devices[task][device]; devices_and_hosts.push_back({TPUDeviceAndHost( /*device=*/tensorflow::DeviceNameUtils::ParsedNameToString(tpu_device), - /*host=*/GetCPUHostDeviceForTPUDevice(tpu_device))}); + /*host=*/*GetCPUHostDeviceForTPUDevice(tpu_device, devices))}); } return devices_and_hosts; @@ -365,7 +379,7 @@ StatusOr> GetGeneralTPUExecutionDeviceAssignment( int num_replicas, int num_cores_per_replica, llvm::ArrayRef> tpu_devices, - llvm::StringRef topology_attr, + ParsedDevices devices, llvm::StringRef topology_attr, llvm::ArrayRef device_assignment_attr) { const int num_tasks = tpu_devices.size(); const int num_tpus_per_task = tpu_devices[0].size(); @@ -431,7 +445,7 @@ GetGeneralTPUExecutionDeviceAssignment( auto& device_and_host = devices_and_hosts[replica][logical_core]; const auto& tpu_device = tpu_devices[task][device]; device_and_host.device = DeviceNameUtils::ParsedNameToString(tpu_device); - device_and_host.host = GetCPUHostDeviceForTPUDevice(tpu_device); + device_and_host.host = *GetCPUHostDeviceForTPUDevice(tpu_device, devices); } } @@ -626,9 +640,10 @@ StatusOr GetTPUCompilationAndExecutionDevices( absl::StrCat("'", kDeviceAssignmentAttr, "' must not be set when '", kTopologyAttr, "' is not set")); - TF_ASSIGN_OR_RETURN(auto execution_devices, - GetFullMeshTPUExecutionDeviceAssignment( - num_replicas, num_cores_per_replica, tpu_devices)); + TF_ASSIGN_OR_RETURN( + auto execution_devices, + GetFullMeshTPUExecutionDeviceAssignment( + num_replicas, num_cores_per_replica, tpu_devices, devices)); return TPUDeviceAssignment(compilation_device, std::move(execution_devices)); } @@ -636,7 +651,7 @@ StatusOr GetTPUCompilationAndExecutionDevices( TF_ASSIGN_OR_RETURN(auto devices_and_ids, GetGeneralTPUExecutionDeviceAssignment( num_replicas, num_cores_per_replica, tpu_devices, - topology_attr, device_assignment_attr)); + devices, topology_attr, device_assignment_attr)); return TPUDeviceAssignment(compilation_device, std::move(devices_and_ids.first), std::move(devices_and_ids.second)); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index fb88bc8bc44530..2c749b549cdc86 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -118,8 +118,10 @@ TEST_P(ParameterizedMetadataTest, BadMetadata) { ASSERT_TRUE(DeviceNamesToParsedNames( {"/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", + "/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:1/device:TPU_SYSTEM:0", - "/job:worker/replica:0/task:1/device:TPU:0"}, + "/job:worker/replica:0/task:1/device:TPU:0", + "/job:worker/replica:0/task:1/device:CPU:0"}, &devices)); std::string compilation_device; llvm::SmallVector, 8> execution_devices; @@ -863,6 +865,7 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) { builder.getStrArrayAttr(llvm::ArrayRef( {"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", 
"/job:localhost/replica:0/task:0/device:TPU:0", + "/job:localhost/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"}))); llvm::SmallVector result_types; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index c943f0c9ec3aa1..58adaa41349b14 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -639,10 +639,8 @@ mlir::LogicalResult RemapOutputsFromLogicalDevices( mlir::tf_device::ParallelExecuteOp old_parallel_execute, int cluster_idx, mlir::tf_device::ParallelExecuteOp new_parallel_execute, mlir::OpBuilder* builder) { - for (const auto& result_and_index : + for (auto [output_index, old_parallel_execute_output] : llvm::enumerate(old_parallel_execute.getResults())) { - const auto output_index = result_and_index.index(); - const auto old_parallel_execute_output = result_and_index.value(); if (output_index < num_results_pre_cluster) { // Replace the use of those results of old parallel_execute op from host // with corresponding results of new parallel_execute op diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index d0653b9677d5c7..693d1f37766d81 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -3,13 +3,17 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], + default_visibility = [ + "//tensorflow/compiler/mlir/tf2xla/api:__subpackages__", + "//tensorflow/compiler/mlir/tf2xla/internal:__subpackages__", + ], ) cc_library( name = "compile_mlir_util_no_tf_dialect_passes", srcs = ["compile_mlir_util.cc"], hdrs = ["compile_mlir_util.h"], + visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/tensorflow", @@ -46,6 +50,7 @@ cc_library( "//tensorflow/core/platform:logging", "//tensorflow/core/platform:status", "//tensorflow/core/tpu:tpu_defs", + "@com_google_absl//absl/base:core_headers", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -94,7 +99,6 @@ cc_library( srcs = ["compile_tf_graph.cc"], hdrs = ["compile_tf_graph.h"], deps = [ - ":compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", @@ -105,6 +109,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:set_tpu_infeed_layout", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/tf2xla:layout_util", + "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -113,13 +118,21 @@ cc_library( "//tensorflow/core/tpu/kernels:tpu_compile_op_support", "//tensorflow/core/tpu/kernels:tpu_compile_proto_cc", "//tensorflow/core/tpu/kernels:tpu_util", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:status", 
"@local_tsl//tsl/platform:statusor", + "@local_xla//xla:shape_util", + "@local_xla//xla:status_macros", "@local_xla//xla/client:compile_only_client", + "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", "@local_xla//xla/pjrt:compile_options_proto_cc", ], @@ -167,6 +180,9 @@ cc_library( name = "cluster_tf", srcs = ["cluster_tf.cc"], hdrs = ["cluster_tf.h"], + visibility = [ + "//tensorflow/compiler/tf2xla:__pkg__", + ], deps = [ ":tf_dialect_to_executor", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", @@ -225,6 +241,7 @@ cc_library( name = "tf_dialect_to_executor", srcs = ["tf_dialect_to_executor.cc"], hdrs = ["tf_dialect_to_executor.h"], + visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc index 2f8469ee3f6f69..bb27edab8aa88a 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc @@ -103,6 +103,7 @@ tensorflow::Status RunTFXLABridge( } PassManager bridge(module.getContext()); + bridge.enableVerifier(); ::tensorflow::applyTensorflowAndCLOptions(bridge); // Populate a passmanager with the list of passes that implement the bridge. diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h index 12e2212ba81445..3f6e446ca28fd9 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/base/attributes.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -63,8 +64,7 @@ namespace tensorflow { // result shapes. // custom_legalization_passes: passes to run before the default TF legalization // passes for backend-specific ops. -// -// TODO(hinsu): Migrate options to a separate struct. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") Status ConvertMLIRToXlaComputation( mlir::ModuleOp module_op, llvm::StringRef device_type, xla::XlaComputation* xla_computation, bool use_tuple_args, @@ -98,6 +98,7 @@ Status ConvertMLIRToXlaComputation( // true, includes legalization and MHLO lowering passes. // allow_partial_conversion: when this is true, allow operations that can't be // legalized. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") void CreateConvertMlirToXlaHloPipeline( mlir::OpPassManager& pm, llvm::StringRef device_type, bool enable_op_fallback, @@ -112,12 +113,14 @@ struct TensorOrResourceShape { }; // Refine MLIR types based on new shape information. +ABSL_DEPRECATED("Not meant to be used directly and should be a util.") Status RefineShapes(llvm::ArrayRef arg_shapes, mlir::ModuleOp module); // Lower TF to MHLO and insert HLO into the XlaBuilder. xla_params are HLO-level // inputs to module_op that have already been added to the XlaBuilder. returns // are the returned XlaOps. 
+ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder, llvm::ArrayRef xla_params, std::vector& returns, @@ -129,6 +132,7 @@ Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder, // Apply shape, description, and resource information to inputs and outputs // in the XlaCompilationResult. This should be called after // compilation_result->computation was set. +ABSL_DEPRECATED("Not meant to be used directly and should be a util.") Status PopulateResultIOInfo( mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, bool use_tuple_args, bool use_resource_updates_for_aliases, @@ -142,6 +146,7 @@ Status PopulateResultIOInfo( // // If enable_op_fallback is set to false, graph is legalized only if the graph // analysis for the graph is successful. Otherwise, an error is returned. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") StatusOr CompileMlirToXlaHlo( mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback, @@ -157,6 +162,7 @@ StatusOr CompileMlirToXlaHlo( // // If lower_to_xla_hlo is true then compiles down into XLA HLO, generates all // accompanying metadata and stores them in CompilationResult. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") StatusOr CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback, @@ -172,6 +178,7 @@ StatusOr CompileSerializedMlirToXlaHlo( // metadata and stores them in CompilationResult. This will rewrite arguments // and run the TensorFlow standard pipeline prior to invoking // `CompileMlirToXlaHlo`. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") Status CompileGraphToXlaHlo( mlir::ModuleOp module_op, llvm::ArrayRef args, llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback, @@ -183,6 +190,8 @@ Status CompileGraphToXlaHlo( // Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata // and stores them in CompilationResult. +ABSL_DEPRECATED( + "Use v1/compile_tf_graph.h::CompileTensorflowGraphToHloinstead.") Status CompileGraphToXlaHlo( const Graph& graph, llvm::ArrayRef args, llvm::ArrayRef control_rets, llvm::StringRef device_type, @@ -197,6 +206,8 @@ Status CompileGraphToXlaHlo( // XlaBuilder. This function adds HLO to a larger HLO computation, so // HLO-level inputs are supplied, and HLO-level outputs are produced. // xla_params is the HLO-level inputs and returns is the HLO-level outputs. +ABSL_DEPRECATED( + "Use v1/compile_tf_graph.h::CompileTensorflowGraphToHloinstead.") Status BuildHloFromGraph( const Graph& graph, xla::XlaBuilder& builder, mlir::MLIRContext& mlir_context, llvm::ArrayRef xla_params, diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc index ace94d1e17303d..003732ffb22f5a 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc @@ -15,12 +15,21 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h" +#include #include #include #include #include +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/DenseMap.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -31,14 +40,30 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/tf2xla/layout_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/compile_only_client.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/mlir_hlo/mhlo/IR/register.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/lib/monitoring/sampler.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/profile_utils/cpu_utils.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/tpu/tpu_compile.h" +#include "tsl/lib/monitoring/sampler.h" +#include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -165,10 +190,8 @@ Status PrepareAndExportToLibrary(mlir::ModuleOp module, flib_def); } -} // namespace - -tsl::Status CompileTensorflowGraphToHlo( - const std::variant& computation, +tsl::Status CompileTFFunctionWithoutMlir( + FunctionToHloArgs function_computation, const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_funcs, @@ -177,45 +200,40 @@ tsl::Status CompileTensorflowGraphToHlo( std::vector>* per_core_arg_shapes, xla::CompileOnlyClient* client, XlaCompiler::CompilationResult* compilation_result) { - LOG_FIRST_N(INFO, 1) << "Compiling MLIR computation to XLA HLO using the " - "old (non-MLIR) tf2xla bridge"; - - *compilation_result = {}; - bool has_mlir = computation.index() == 0; - - std::string mlir_string = has_mlir ? 
"has_mlir" : "has_function_to_hlo"; - const std::string kBridgePhase2Config = - absl::StrCat("graph_old_bridge_", mlir_string); - CompilationTimer timer; - - if (!has_mlir) { - FunctionToHloArgs function_computation = std::get<1>(computation); - Status comp_status = CompileTFFunctionToHlo( - *function_computation.flib_def, function_computation.graph_def_version, - shape_determination_funcs, arg_shapes, - function_computation.guaranteed_constants, - *function_computation.function, metadata, client, arg_core_mapping, - per_core_arg_shapes, use_tuple_args, compilation_result); - if (comp_status.ok()) { - phase2_bridge_compilation_status->GetCell(kOldBridgeNoMlirSuccess) - ->IncrementBy(1); - } else { - phase2_bridge_compilation_status->GetCell(kOldBridgeNoMlirFailure) - ->IncrementBy(1); - } - - phase2_bridge_compilation_time->GetCell(kBridgePhase2Config) - ->Add(timer.ElapsedCyclesInMilliseconds()); - return comp_status; + Status comp_status = CompileTFFunctionToHlo( + *function_computation.flib_def, function_computation.graph_def_version, + shape_determination_funcs, arg_shapes, + function_computation.guaranteed_constants, *function_computation.function, + metadata, client, arg_core_mapping, per_core_arg_shapes, use_tuple_args, + compilation_result); + if (comp_status.ok()) { + phase2_bridge_compilation_status->GetCell(kOldBridgeNoMlirSuccess) + ->IncrementBy(1); + } else { + phase2_bridge_compilation_status->GetCell(kOldBridgeNoMlirFailure) + ->IncrementBy(1); } + return comp_status; +} + +tsl::Status CompileMLIRTFFunction( + tpu::MlirToHloArgs mlir_computation, + const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, + const XlaShapeLayoutHelpers::ShapeDeterminationFns + shape_determination_funcs, + const std::vector& arg_shapes, + std::vector* arg_core_mapping, + std::vector>* per_core_arg_shapes, + xla::CompileOnlyClient* client, + XlaCompiler::CompilationResult* compilation_result) { mlir::DialectRegistry registry; mlir::RegisterAllTensorFlowDialects(registry); mlir::mhlo::registerAllMhloDialects(registry); mlir::MLIRContext context(registry); mlir::OwningOpRef mlir_module; - TF_RETURN_IF_ERROR(DeserializeMlirModule(std::get<0>(computation).mlir_module, + TF_RETURN_IF_ERROR(DeserializeMlirModule(mlir_computation.mlir_module, &context, &mlir_module)); if (!mlir::SetTPUInfeedLayout(mlir_module)) return errors::Internal("Failed to set layouts attribute"); @@ -256,11 +274,51 @@ tsl::Status CompileTensorflowGraphToHlo( consts, func, metadata, client, arg_core_mapping, per_core_arg_shapes, use_tuple_args, compilation_result)); + return PopulateInputOutputAliasing(main_fn, compilation_result, + use_tuple_args); +} + +} // namespace + +tsl::Status CompileTensorflowGraphToHlo( + const std::variant& computation, + const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, + const XlaShapeLayoutHelpers::ShapeDeterminationFns + shape_determination_funcs, + const std::vector& arg_shapes, + std::vector* arg_core_mapping, + std::vector>* per_core_arg_shapes, + xla::CompileOnlyClient* client, + XlaCompiler::CompilationResult* compilation_result) { + LOG_FIRST_N(INFO, 1) << "Compiling MLIR computation to XLA HLO using the " + "old (non-MLIR) tf2xla bridge"; + + CompilationTimer timer; + *compilation_result = {}; + bool has_mlir = computation.index() == 0; + + std::string mlir_string = has_mlir ? 
"has_mlir" : "has_function_to_hlo"; + const std::string kBridgePhase2Config = + absl::StrCat("graph_old_bridge_", mlir_string); + + if (has_mlir) { + TF_RETURN_IF_ERROR(CompileMLIRTFFunction( + std::get<0>(computation), metadata, use_tuple_args, + shape_determination_funcs, arg_shapes, arg_core_mapping, + per_core_arg_shapes, client, compilation_result)); + + } else { + FunctionToHloArgs function_computation = std::get<1>(computation); + TF_RETURN_IF_ERROR(CompileTFFunctionWithoutMlir( + function_computation, metadata, use_tuple_args, + shape_determination_funcs, arg_shapes, arg_core_mapping, + per_core_arg_shapes, client, compilation_result)); + } + phase2_bridge_compilation_time->GetCell(kBridgePhase2Config) ->Add(timer.ElapsedCyclesInMilliseconds()); - return PopulateInputOutputAliasing(main_fn, compilation_result, - use_tuple_args); + return tsl::OkStatus(); } }; // namespace v1 diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc index 236282f625e20a..9d0b884ebbe85d 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc @@ -127,6 +127,7 @@ tensorflow::Status ExportFromTensorflowDialectToExecutor( ModuleOp module, llvm::StringRef module_name) { PassManager tf_to_executor(module.getContext()); ::tensorflow::applyTensorflowAndCLOptions(tf_to_executor); + tf_to_executor.enableVerifier(); AddTfDialectToExecutorPasses(tf_to_executor); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD index 70a84bccff586a..73880851e7abc1 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD @@ -2,25 +2,11 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +# Please reach out to tf-bridge-team@ before using the TF2XLA bridge. package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ ":__subpackages__", - ":tf2xla_users", - ], -) - -# Please reach out to tf-bridge-team@ before using the TF2XLA bridge. -package_group( - name = "tf2xla_users", - packages = [ - "//tensorflow/compiler/mlir/quantization/stablehlo/...", - "//learning/serving/contrib/tfrt/mlir/saved_model_analysis", - "//tensorflow/compiler/mlir/tfrt", - "//tensorflow/compiler/tf2xla", - "//tensorflow/compiler/mlir", - # Legacy due to where the bridge currently runs. This should go away. 
- "//tensorflow/compiler/mlir/tensorflow/transforms", ], ) @@ -28,6 +14,12 @@ cc_library( name = "legalize_tf", srcs = ["legalize_tf.cc"], hdrs = ["legalize_tf.h"], + visibility = [ + "//learning/brain/google/xla:__pkg__", + "//learning/brain/mlir/bridge:__pkg__", + "//tensorflow/compiler/mlir/quantization/stablehlo:__pkg__", + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:__pkg__", + ], deps = [ ":device_type_proto_cc", "//tensorflow/compiler/jit:flags_headers", @@ -99,12 +91,22 @@ tf_proto_library( name = "device_type_proto", srcs = ["device_type.proto"], cc_api_version = 2, + visibility = [ + "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__", + ], ) cc_library( name = "cluster_tf", srcs = ["cluster_tf.cc"], hdrs = ["cluster_tf.h"], + visibility = [ + "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__", + "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", + "//tensorflow/compiler/mlir/tfrt:__pkg__", + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:__pkg__", + "//tensorflow/compiler/tf2xla:__pkg__", + ], deps = [ ":device_type_proto_cc", ":tf_dialect_to_executor", @@ -165,6 +167,12 @@ cc_library( name = "tf_dialect_to_executor", srcs = ["tf_dialect_to_executor.cc"], hdrs = ["tf_dialect_to_executor.h"], + visibility = [ + "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__", + "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", + "//tensorflow/compiler/mlir/tfrt:__pkg__", + "//tensorflow/compiler/tf2xla:__pkg__", + ], deps = [ "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc index 24de1be6fe97dc..289d4d0faec78e 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc @@ -73,6 +73,7 @@ tensorflow::Status RunTFXLABridge( } PassManager bridge(module.getContext()); + bridge.enableVerifier(); ::tensorflow::applyTensorflowAndCLOptions(bridge); // Populate a passmanager with the list of passes that implement the bridge. @@ -142,6 +143,10 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, } void CreateClusteringPipeline(OpPassManager &pm, llvm::StringRef module_name) { + // Since the internal bridge clustering passes are shared among TF1/TF2 + // TF2-only passes should go here. However, this should be very rare and + // new passes generally should go into the internal + // AddBridgeClusteringPipelinePasses. 
pm.addPass(mlir::TFTPU::CreateTPUValidateInputsPass()); pm.addNestedPass( mlir::TF::CreateCanonicalizeCompileAndReplicateAttributesPass()); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc index 69f1c0e20a5e1b..455a59d6607c49 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc @@ -126,6 +126,7 @@ tensorflow::Status ExportFromTensorflowDialectToExecutor( ModuleOp module, llvm::StringRef module_name) { PassManager tf_to_executor(module.getContext()); ::tensorflow::applyTensorflowAndCLOptions(tf_to_executor); + tf_to_executor.enableVerifier(); AddTfDialectToExecutorPasses(tf_to_executor); if (VLOG_IS_ON(1) || diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc index 1cad3d1d5cc615..a0261b398fcc8f 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc @@ -34,6 +34,8 @@ namespace internal { using mlir::OpPassManager; using mlir::func::FuncOp; +// LINT.IfChange(tpu_bridge_passes) + // Adds Bridge clustering pipeline passes to the given pass_manager. Does not // run them. void AddBridgeClusteringPipelinePasses(OpPassManager& pm, @@ -80,7 +82,6 @@ void AddBridgeClusteringPipelinePasses(OpPassManager& pm, // Run TPU cluster cleanup attributes so ops with no outside compiled // attribute have no host device attribute. pm.addPass(mlir::TFTPU::CreateTPUClusterCleanupAttributesPass()); - pm.addPass(mlir::TFDevice::CreateOutsideCompiledToHostLaunchPass()); pm.addNestedPass(mlir::TFDevice::CreateDeviceAttributeToLaunchPass()); // Running canonicalizer before decomposing resource ops in cluster helps the // latter pass to converge faster as it does not have to spend time folding @@ -97,10 +98,6 @@ void AddBridgeClusteringPipelinePasses(OpPassManager& pm, func_pm.addPass(mlir::TFTPU::CreateTPUHostComputationExpansionPass()); func_pm.addPass(mlir::TFTPU::CreateTPUUpdateEmbeddingEnqueueOpInputsPass()); } - // TODO(b/173622615): This should incrementally be moved down as - // more passes support this representation and then can be removed once - // all passes support it. - pm.addPass(mlir::TFDevice::CreateHostLaunchToOutsideCompiledPass()); // TODO(b/173622615): Once OutsideCompilation is represented by launch op and // the remaining passes including Inliner support it, remove this @@ -109,9 +106,6 @@ void AddBridgeClusteringPipelinePasses(OpPassManager& pm, // will be removed from launch causing an error. pm.addNestedPass(mlir::TFDevice::CreateLaunchToDeviceAttributePass()); - // TODO(b/173622615): This can be removed once more passes support outside - // compilation represented by op and conversion back to attribute is removed. - pm.addPass(mlir::TFDevice::CreateOutsideCompiledToHostLaunchPass()); // Note that the region-based control-flow produced here still contains // function call ops which get inlined by the subsequent inliner pass. pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); @@ -138,15 +132,12 @@ void AddBridgeClusteringPipelinePasses(OpPassManager& pm, pm.addPass(mlir::TFDevice::CreateMergeControlFlowPass()); } - // TODO(b/173622615): This should incrementally be moved down as - // more passes support this representation and then can be removed once - // all passes support it. 
- pm.addPass(mlir::TFDevice::CreateHostLaunchToOutsideCompiledPass()); - - pm.addPass(mlir::TFDevice::CreateMarkOpsForOutsideCompilationPass()); + pm.addPass( + tensorflow::tf2xla::internal::CreateMarkOpsForOutsideCompilationPass()); pm.addPass(tensorflow::tf2xla::internal:: CreateExtractHeadTailOutsideCompilationPass()); - pm.addPass(mlir::TFDevice::CreateExtractOutsideCompilationPass()); + pm.addPass( + tensorflow::tf2xla::internal::CreateExtractOutsideCompilationPass()); pm.addNestedPass( mlir::TFDevice::CreateVerifyNoOutsideCompilationMarkersPass()); @@ -167,18 +158,21 @@ void AddBridgeClusteringPipelinePasses(OpPassManager& pm, pm.addNestedPass( tensorflow::tf2xla::internal::CreateVerifyClusteringPass()); } +// LINT.ThenChange(:non_tpu_bridge_passes) void NoCanonicalization(OpPassManager& pm) {} +// LINT.IfChange(non_tpu_bridge_passes) void AddNonTPUBridgeClusteringPipelinePasses(OpPassManager& pm) { // The following ops must be preserved regardless of reachability. Ideally, // all graphs should have control dependencies to enforce this. VLOG(2) << "Create TF XLA Bridge pipeline"; + pm.addPass(mlir::TFDevice::CreateXlaValidateInputsPass()); pm.addNestedPass( mlir::TF::CreateCanonicalizeCompileAndReplicateAttributesPass()); - // This pass expectes unified compilation markers. - pm.addPass(mlir::TFDevice::CreateXlaValidateInputsPass()); - const llvm::SmallVector ops_to_preserve = {}; + const llvm::SmallVector ops_to_preserve = { + "tf.TPUReplicateMetadata", "tf.TPUCompilationResult", + "tf.TPUReplicatedOutput"}; pm.addNestedPass( mlir::tf_executor::CreateTFExecutorGraphPruningPass(ops_to_preserve)); // It is assumed at this stage there are no V1 control flow ops as Graph // inference. pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + // The following passes are added to match the TPU pipeline and are + // expected to be no-ops. + pm.addNestedPass(mlir::TFTPU::CreateTPUPartitionedOpConversionPass()); + pm.addNestedPass( + mlir::TFTPU::CreateTPUReorderReplicateAndPartitionedInputsPass()); + pm.addNestedPass(mlir::TF::CreateDecomposeReduceDatasetPass()); + pm.addPass(mlir::TFDevice::CreateEmbeddingPipeliningPass()); + pm.addPass(mlir::TFDevice::CreateEmbeddingSequencingPass()); // Encapsulate PartitionedCall ops within a cluster so that the composite // resource ops can be decomposed. - pm.addPass(mlir::TFDevice::CreateXlaClusterFormationPass()); + pm.addPass(tensorflow::tf2xla::internal::CreateXlaClusterFormationPass()); // Running canonicalizer before decomposing resource ops in cluster helps the // latter pass to converge faster as it does not have to spend time folding // away dead ops. @@ -223,10 +225,12 @@ void AddNonTPUBridgeClusteringPipelinePasses(OpPassManager& pm) { // for generic pipeline is landed. if (tensorflow::GetMlirCommonFlags() ->tf_mlir_enable_generic_outside_compilation) { - pm.addPass(mlir::TFDevice::CreateMarkOpsForOutsideCompilationPass()); + pm.addPass( + tensorflow::tf2xla::internal::CreateMarkOpsForOutsideCompilationPass()); pm.addPass(tensorflow::tf2xla::internal:: CreateExtractHeadTailOutsideCompilationPass()); - pm.addPass(mlir::TFDevice::CreateExtractOutsideCompilationPass()); + pm.addPass( + tensorflow::tf2xla::internal::CreateExtractOutsideCompilationPass()); } // Outline clusters into cluster functions.
pm.addPass(mlir::TFDevice::CreateClusterOutliningPass()); @@ -234,6 +238,7 @@ void AddNonTPUBridgeClusteringPipelinePasses(OpPassManager& pm) { pm.addNestedPass( tensorflow::tf2xla::internal::CreateVerifyClusteringPass()); } +// LINT.ThenChange(:tpu_bridge_passes) }; // namespace internal }; // namespace tf2xla diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc index 91b80fa485a83f..d3201bffa137a0 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc @@ -28,14 +28,14 @@ TEST(ClusteringBridgePassesTest, AddsBridgePasses) { OpPassManager pass_manager; AddBridgeClusteringPipelinePasses(pass_manager); - EXPECT_EQ(pass_manager.size(), 47); + EXPECT_EQ(pass_manager.size(), 43); } TEST(ClusteringBridgePassesTest, AddsNonTPUBridgePasses) { OpPassManager pass_manager; AddNonTPUBridgeClusteringPipelinePasses(pass_manager); - EXPECT_EQ(pass_manager.size(), 15); + EXPECT_EQ(pass_manager.size(), 20); } }; // namespace internal diff --git a/tensorflow/compiler/mlir/tf2xla/internal/hlo_post_processing/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/hlo_post_processing/BUILD new file mode 100644 index 00000000000000..2c9500af0052ae --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/hlo_post_processing/BUILD @@ -0,0 +1,7 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/compiler/mlir/tf2xla/internal:__subpackages__", + ], + licenses = ["notice"], +) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index 0e25e62b150047..a391e189e6215c 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -17,9 +17,6 @@ package( cc_library( name = "clustering_passes", - srcs = [ - "verify_clustering_pass.cc", - ], hdrs = [ "clustering_passes.h", ], @@ -27,14 +24,31 @@ cc_library( "clustering_passes.h.inc", ], deps = [ - ":clustering_passes_inc_gen", ":extract_head_tail_outside_compilation", + ":extract_outside_compilation", + ":mark_ops_for_outside_compilation", ":tpu_cluster_formation", + ":verify_clustering_pass", + ":xla_cluster_formation", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "verify_clustering_pass", + srcs = [ + "verify_clustering_pass.cc", + ], + deps = [ + ":clustering_passes_inc_gen", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:string_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", + "//tensorflow/compiler/mlir/tf2xla/internal/utils:dialect_detection_utils", "//tensorflow/core:framework", "//tensorflow/core/transforms/toposort:Pass", "@com_google_absl//absl/container:flat_hash_map", @@ -56,7 +70,7 @@ gentbl_cc_library( ( [ "-gen-pass-decls", - "-name=TFXLABridge", + "-name=TFXLABridgeClustering", ], "clustering_passes.h.inc", ), @@ -74,7 +88,6 @@ tf_cc_test( deps = [ ":clustering_passes", "//tensorflow/compiler/mlir/tf2xla/transforms:test_utils", - "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -131,8 
+144,8 @@ cc_library( ) cc_library( - name = "extract_head_tail_outside_compilation", - srcs = ["extract_head_tail_outside_compilation.cc"], + name = "extract_outside_compilation", + srcs = ["extract_outside_compilation.cc"], textual_hdrs = [ "clustering_passes.h.inc", ], @@ -141,13 +154,16 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:device_util", + "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", "//tensorflow/compiler/mlir/tensorflow:string_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", + "//tensorflow/compiler/mlir/tensorflow/transforms:shape_inference_pass", "//tensorflow/core:framework", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", @@ -161,43 +177,144 @@ cc_library( ) cc_library( - name = "dialect_to_executor_passes", - srcs = [ - "dialect_to_executor_passes.h", - ], + name = "extract_head_tail_outside_compilation", + srcs = ["extract_head_tail_outside_compilation.cc"], textual_hdrs = [ - "dialect_to_executor_passes.h.inc", + "clustering_passes.h.inc", ], deps = [ - ":dialect_to_executor_passes_inc_gen", + ":clustering_passes_inc_gen", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:device_util", + "//tensorflow/compiler/mlir/tensorflow:string_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", "//tensorflow/core:framework", - "//tensorflow/core/transforms/toposort:Pass", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "mlir_to_graph_passes", + hdrs = [ + "mlir_to_graph_passes.h", + ], + textual_hdrs = [ + "mlir_to_graph_passes.h.inc", + ], + deps = [ + ":verify_input_dialect_to_executor_pass", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:Pass", ], ) gentbl_cc_library( - name = "dialect_to_executor_passes_inc_gen", + name = "mlir_to_graph_passes_inc_gen", compatible_with = get_compatible_with_portable(), tbl_outs = [ ( [ "-gen-pass-decls", - "-name=TFXLABridge", + "-name=TFXLABridgeMlirToGraph", ], - "dialect_to_executor_passes.h.inc", + "mlir_to_graph_passes.h.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "dialect_to_executor_passes.td", + td_file = "mlir_to_graph_passes.td", deps = [ "@llvm-project//mlir:PassBaseTdFiles", ], ) + +cc_library( + name = "verify_input_dialect_to_executor_pass", + srcs = [ + "verify_input_dialect_to_executor_pass.cc", + ], + deps = [ + ":mlir_to_graph_passes_inc_gen", + "//tensorflow/compiler/mlir/tf2xla/internal/utils:dialect_detection_utils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + 
"@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "xla_cluster_formation", + srcs = ["xla_cluster_formation.cc"], + textual_hdrs = [ + "clustering_passes.h.inc", + ], + deps = [ + ":clustering_passes_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:call_graph_util", + "//tensorflow/compiler/mlir/tensorflow:cluster_util", + "//tensorflow/compiler/mlir/tensorflow:string_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:portable_gif_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "mark_ops_for_outside_compilation", + srcs = ["mark_ops_for_outside_compilation.cc"], + textual_hdrs = [ + "clustering_passes.h.inc", + ], + deps = [ + ":clustering_passes_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:string_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", + "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen", + "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", + "//tensorflow/compiler/mlir/tf2xla/transforms:legalization_op_config", + "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h index 79721a0da640ae..8062ac32b70bb0 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h @@ -37,10 +37,27 @@ CreateTPUClusterFormationPass(bool strict_clusters = false); std::unique_ptr> CreateExtractHeadTailOutsideCompilationPass(); +// Creates a pass that extract outside compilation (Host ops inside cevice +// cluster) ops to a separate parallel_execute region to run on CPU. +std::unique_ptr> +CreateExtractOutsideCompilationPass(); + +// Create a pass that encapsulates StatefulPartitionedCallOp within a cluster. +std::unique_ptr> +CreateXlaClusterFormationPass(); + +// Creates a pass that marks unsupported ops in device cluster for outside +// compilation. 
+std::unique_ptr> +CreateMarkOpsForOutsideCompilationPass(); + #define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_MARKOPSFOROUTSIDECOMPILATIONPASS #define GEN_PASS_DECL_TPUCLUSTERFORMATIONPASS #define GEN_PASS_DECL_TPUEXTRACTHEADTAILOUTSIDECOMPILATIONPASS +#define GEN_PASS_DECL_TPUEXTRACTOUTSIDECOMPILATIONPASS #define GEN_PASS_DECL_VERIFYCLUSTERINGPASS +#define GEN_PASS_DECL_XLACLUSTERFORMATIONPASS #include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" } // namespace internal diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td index 4fc8af15ffa4fc..8dafe11afea4e3 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td @@ -153,5 +153,149 @@ def ExtractHeadTailOutsideCompilationPass : Pass<"tf-extract-head-tail-outside-c let constructor = "tensorflow::tf2xla::internal::CreateExtractHeadTailOutsideCompilationPass()"; } +def ExtractOutsideCompilationPass : Pass<"tf-extract-outside-compilation", "ModuleOp"> { + let summary = "Extracts device outside compilation computation to a separate tf_device.parallel_execute region."; + let description = [{ + This pass extracts a CPU computation cluster with `_xla_outside_compilation` + annotation, which denotes ops that should be run on CPU/host, from a device cluster. + Each outside compilation cluster is moved to + a tf_device.parallel_execute region. The device cluster is also moved to a + tf_device.parallel_execute region. Communication ops between device and host are + added to pass inputs/outputs to/from the outside compiled region. + + For example, the following tf_device.cluster with an op marked for `xla_outside_compilation`: + + ```mlir + func @outside_compilation() -> tensor { + %0 = "tf_device.cluster"() ( { + %1 = "tf.Const"() {_xla_outside_compilation = "0", value = dense<1.0> : tensor} : () -> (tensor) + %2 = "tf.Identity"(%1) {_xla_outside_compilation = "0"} : (tensor) -> (tensor) + %3 = "tf.AddV2"(%1, %2) : (tensor, tensor) -> (tensor) + tf_device.return %3 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor + } + ``` + + will become a tf_device.parallel_execute op with a CPU/host region and + a tf_device.cluster with communication ops to send data to/from device/host: + + ```mlir + func @outside_compilation() -> tensor { + %0 = "tf_device.parallel_execute"() ( { + "tf_device.launch"() ( { + %1 = "tf._XlaCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf_type.string> + %2 = "tf._XlaRecvAtHost"(%1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_args"} : (tensor<3x!tf_type.string>) -> tensor + %3 = "tf.Identity"(%2) : (tensor) -> tensor + "tf._XlaSendFromHost"(%3, %1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_retvals"} : (tensor, tensor<3x!tf_type.string>) -> () + tf_device.return + }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> () + tf_device.return + }, { + %1 = "tf_device.cluster"() ( { + %2 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %3 = "tf._XlaHostComputeMlir"(%2) {recv_key = "host_compute_channel_0_0_retvals", send_key = "host_compute_channel_0_0_args", tpu_core = 0 : i64} : (tensor) -> tensor + %4 = "tf.AddV2"(%2, %3) : (tensor, tensor) -> tensor + tf_device.return %4 : tensor + }) {device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> 
tensor + tf_device.return %1 : tensor + }) : () -> tensor + return %0 : tensor + } + ``` + }]; + + let constructor = "tensorflow::tf2xla::internal::CreateExtractOutsideCompilationPass()"; +} + +def XlaClusterFormationPass : Pass<"tf-xla-cluster-formation", "ModuleOp"> { + let summary = "Encapsulate partitioned calls within a Cluster op"; + let description = [{ + This pass clusters `tf.PartitionedCall` and `tf.StatefulPartitionedCall` + with `_xla_compile_device_type` attribute into a `tf_device.cluster`. + Notice this pass will only rewrite the outermost call if there are nested + calls to avoid nested `tf.XlaLaunch` operations from being created later. + + For example, the following code + + ```mlir + func.func @main() -> tensor { + %0 = "tf.StatefulPartitionedCall"() {_xla_compile_device_type = "CPU", f = @stateful_pcall_func} : () -> (tensor) + func.return %0 : tensor + } + + func.func @stateful_pcall_func() -> tensor { + %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + func.return %0 : tensor + } + ``` + + will be transformed into, + + ```mlir + func.func @main() -> tensor { + %0 = "tf_device.cluster"() ({ + %1 = "tf.StatefulPartitionedCall"() {_xla_compile_device_type = "CPU", f = @stateful_pcall_func} : () -> tensor + tf_device.return %1 : tensor + }) : () -> tensor + func.return %0 : tensor + } + + func.func @stateful_pcall_func() -> tensor { + %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + func.return %0 : tensor + } + + ``` + }]; + let constructor = "tensorflow::tf2xla::internal::CreateXlaClusterFormationPass()"; + let dependentDialects = ["mlir::tf_device::TensorFlowDeviceDialect"]; +} + +def MarkOpsForOutsideCompilationPass : Pass<"tf-mark-ops-for-outside-compilation", "ModuleOp"> { + let summary = "Marks ops in device cluster for outside compilation if they are unsupported on device."; + + let description = [{ + This pass marks unsupported ops in a device cluster with + `_xla_outside_compilation` attribute so the operations will run on the host + instead of the device. Unsupported ops are ops that can not be code + generated to run on the device for the cluster including: + + 1. String operations on TPUs. + 2. Operations that don't have a kernel defined for the device. + + This pass is conservative in that it will mark all ops for outside compilation + that can not be compiled for the device. Exceptions for this are added for ops + that will be rewritten or decomposed before compiling on device. 
+ + + For example, tf_device.cluster op with an unsupported op, tf.UnsupportedOp: + + ```mlir + func @unsupported_op() -> tensor { + %0 = "tf_device.cluster"() ( { + %1 = "tf.UnsupportedOp"() : () -> tensor + %2 = "tf.Identity"(%1) : (tensor) -> tensor + tf_device.return %2 : tensor + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor + } + ``` + + will mark tf.UnsupportedOp with `_xla_outside_compilation` attribute: + + ```mlir + func @unsupported_op() -> tensor { + %0 = "tf_device.cluster"() ( { + %1 = "tf.UnsupportedOp"() {_xla_outside_compilation = "auto0"} : () -> tensor + %2 = "tf.Identity"(%1) : (tensor) -> tensor + tf_device.return %2 : tensor + }) {allow_soft_placement = true, device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor + return %0 : tensor + } + ``` + }]; + let constructor = "tensorflow::tf2xla::internal::CreateMarkOpsForOutsideCompilationPass()"; +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_outside_compilation.cc similarity index 86% rename from tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc rename to tensorflow/compiler/mlir/tf2xla/internal/passes/extract_outside_compilation.cc index ccc72962fd141d..6bc3468a2729e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_outside_compilation.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include #include -#include #include #include #include #include #include +#include "absl/algorithm/container.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" @@ -34,33 +34,65 @@ limitations under the License. 
#include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/TypeRange.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/string_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" -namespace mlir { -namespace TFDevice { +namespace tensorflow { +namespace tf2xla { +namespace internal { namespace { +using llvm::ArrayRef; +using llvm::SmallVector; +using mlir::Block; +using mlir::BlockArgument; +using mlir::DenseIntElementsAttr; +using mlir::IRMapping; +using mlir::Location; +using mlir::LogicalResult; +using mlir::ModuleOp; +using mlir::OpBuilder; +using mlir::Operation; +using mlir::OperationPass; +using mlir::OpOperand; +using mlir::OpResult; +using mlir::OwningOpRef; +using mlir::RankedTensorType; +using mlir::StringAttr; +using mlir::StringRef; +using mlir::SymbolTable; +using mlir::Type; +using mlir::TypeRange; +using mlir::Value; +using mlir::ValueRange; +using mlir::WalkResult; +using mlir::func::FuncOp; +using mlir::func::ReturnOp; + constexpr char kDeviceAttr[] = "device"; constexpr char kHostFunctionAttr[] = "host_func"; constexpr char kXlaMapOutsideCompilationAttr[] = "_xla_map_outside_compilation"; @@ -68,7 +100,7 @@ constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; constexpr char kNoReplicationCluster[] = "__no_replication_cluster"; #define GEN_PASS_DEF_EXTRACTOUTSIDECOMPILATIONPASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" struct ExtractOutsideCompilation : public impl::ExtractOutsideCompilationPassBase< @@ -79,9 +111,9 @@ struct ExtractOutsideCompilation // Build a function containing `ops` with `inputs` and `outputs` using // `builder`. The `ops` are cloned and modified to use the function arguments // as inputs. 
-func::FuncOp BuildFunction(llvm::ArrayRef ops, - llvm::ArrayRef inputs, - llvm::ArrayRef outputs, OpBuilder* builder) { +FuncOp BuildFunction(llvm::ArrayRef ops, + llvm::ArrayRef inputs, + llvm::ArrayRef outputs, OpBuilder* builder) { llvm::SmallVector operand_types; operand_types.reserve(inputs.size()); for (Value v : inputs) operand_types.emplace_back(v.getType()); @@ -91,8 +123,8 @@ func::FuncOp BuildFunction(llvm::ArrayRef ops, auto func_type = builder->getFunctionType(operand_types, output_types); - func::FuncOp outlined_func = - func::FuncOp::create(ops.front()->getLoc(), kHostFunctionAttr, func_type); + FuncOp outlined_func = + FuncOp::create(ops.front()->getLoc(), kHostFunctionAttr, func_type); // Create function body. Block* outlined_func_block = outlined_func.addEntryBlock(); @@ -111,13 +143,13 @@ func::FuncOp BuildFunction(llvm::ArrayRef ops, results_after_mapping.push_back(mapping.lookupOrDefault(result)); } - builder->create(ops.front()->getLoc(), results_after_mapping); + builder->create(ops.front()->getLoc(), results_after_mapping); return outlined_func; } // Encapsulates `func` in a module and serializes that module. // `serialized_func_module` is set to the serialized module. -void EncapsulateFuncAndSerialize(func::FuncOp func, +void EncapsulateFuncAndSerialize(FuncOp func, std::string* serialized_func_module) { // Create a new module to hold func and all referenced functions. OwningOpRef module_for_func = @@ -175,14 +207,14 @@ Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc, llvm::StringRef communication_key) { if (device_ordinal) return ApplyXlaHostTransferAttr( - builder.create( + builder.create( loc, inputs, /*dynamic_key=*/compilation_key, device_ordinal, builder.getStringAttr(communication_key), device_type_attr), builder); return ApplyXlaHostTransferAttr( - builder.create( + builder.create( loc, inputs, /*dynamic_key=*/compilation_key, builder.getStringAttr(communication_key), @@ -200,13 +232,13 @@ Operation* CreateRecvAtHostOp(OpBuilder& builder, Location loc, llvm::StringRef communication_key) { if (device_ordinal) return ApplyXlaHostTransferAttr( - builder.create( + builder.create( loc, output_types, /*dynamic_key=*/compilation_key, device_ordinal, builder.getStringAttr(communication_key), device_type_attr), builder); return ApplyXlaHostTransferAttr( - builder.create( + builder.create( loc, output_types, /*dynamic_key=*/compilation_key, builder.getStringAttr(communication_key), /*device_ordinal=*/builder.getI64IntegerAttr(default_device_ordinal), @@ -216,10 +248,10 @@ Operation* CreateRecvAtHostOp(OpBuilder& builder, Location loc, // Clones an IfRegionOp 'if_region' and attributes and creates then/else regions // with yield op and an empty block. -TF::IfRegionOp CloneEmptyIfWithPredicate(TF::IfRegionOp if_region, - OpBuilder& builder) { +mlir::TF::IfRegionOp CloneEmptyIfWithPredicate(mlir::TF::IfRegionOp if_region, + OpBuilder& builder) { // Mark op as stateful due to side-effecting communication ops added later. 
- auto host_side_if = builder.create( + auto host_side_if = builder.create( if_region.getLoc(), llvm::SmallVector{}, if_region.getCond(), /*is_stateless=*/false, if_region.get_thenFuncNameAttr(), if_region.get_elseFuncNameAttr()); @@ -228,23 +260,23 @@ TF::IfRegionOp CloneEmptyIfWithPredicate(TF::IfRegionOp if_region, auto& then_branch = host_side_if.getThenBranch(); then_branch.push_back(new Block); builder.setInsertionPointToEnd(&then_branch.front()); - builder.create(if_region.getLoc(), - /*operands=*/ArrayRef{}); + builder.create(if_region.getLoc(), + /*operands=*/ArrayRef{}); // Create empty else branch region. auto& else_branch = host_side_if.getElseBranch(); else_branch.push_back(new Block); builder.setInsertionPointToEnd(&else_branch.front()); - builder.create(if_region.getLoc(), - /*operands=*/ArrayRef{}); + builder.create(if_region.getLoc(), + /*operands=*/ArrayRef{}); return host_side_if; } // Creates a WhileRegionOp cond and body regions with yield op and // an empty body. -TF::WhileRegionOp CloneEmptyWhile(uint64_t parallel_iterations, Location loc, - OpBuilder& builder) { +mlir::TF::WhileRegionOp CloneEmptyWhile(uint64_t parallel_iterations, + Location loc, OpBuilder& builder) { // Mark op as stateful due to side-effecting communication ops added later. - auto host_side_while = builder.create( + auto host_side_while = builder.create( loc, /*output=*/ArrayRef{}, /*input=*/ArrayRef{}, parallel_iterations, /*is_stateless=*/false, /*shape_invariant=*/false); @@ -252,7 +284,7 @@ TF::WhileRegionOp CloneEmptyWhile(uint64_t parallel_iterations, Location loc, auto& body = host_side_while.getBody(); body.push_back(new Block); builder.setInsertionPointToEnd(&body.front()); - builder.create(loc, /*operands=*/ArrayRef{}); + builder.create(loc, /*operands=*/ArrayRef{}); return host_side_while; } @@ -261,16 +293,16 @@ TF::WhileRegionOp CloneEmptyWhile(uint64_t parallel_iterations, Location loc, // _XlaSendFromHost but the _XlaCompileMlir has not yet been created for device // cluster that contains the outside compiled ops. This placeholder should be // replaced by the TPU cluster _XlaCompileMlir in a subsequent pass. -TF::_XlaCompileMlirPlaceholderProgramKeyOp CreateCompilationKeyPlaceholder( - Location loc, OpBuilder& builder) { +mlir::TF::_XlaCompileMlirPlaceholderProgramKeyOp +CreateCompilationKeyPlaceholder(Location loc, OpBuilder& builder) { auto result_type = - RankedTensorType::get({3}, builder.getType()); - return builder.create( + RankedTensorType::get({3}, builder.getType()); + return builder.create( loc, /*program=*/result_type, llvm::ArrayRef{}); } // Creates a `tf_device.launch` to wrap cluster ops. -tf_device::LaunchOp CreateLaunchOpForOutsideCluster( +mlir::tf_device::LaunchOp CreateLaunchOpForOutsideCluster( OpBuilder& builder, Operation* loc_op, llvm::StringRef host_device, llvm::SmallVector& return_value_from_host) { llvm::SmallVector host_result_types; @@ -281,20 +313,21 @@ tf_device::LaunchOp CreateLaunchOpForOutsideCluster( // An empty string placeholder is used for the device as that will be later // populated with the device of the associated Device op. // For TPU case, it is TPUReplicateMetadata op. 
- auto launch_op = builder.create( + auto launch_op = builder.create( loc_op->getLoc(), builder.getStringAttr(host_device), /*result_types=*/host_result_types); launch_op.getBody().push_back(new Block); builder.setInsertionPointToEnd(&launch_op.GetBody()); - builder.create(loc_op->getLoc(), return_value_from_host); + builder.create(loc_op->getLoc(), + return_value_from_host); return launch_op; } // Returns true if `op` has non-static shaped outputs. bool HasDynamicOutputs(Operation* op) { for (Value v : op->getResults()) { - if (TF::CanBeRefined(v.getType())) return true; + if (mlir::TF::CanBeRefined(v.getType())) return true; } return false; } @@ -307,7 +340,7 @@ bool HasDynamicOutputs(const llvm::SmallSetVector& cluster_ops) { if (cluster_ops.count(use.getOwner())) { continue; } - if (TF::CanBeRefined(use.get().getType())) return true; + if (mlir::TF::CanBeRefined(use.get().getType())) return true; } } return false; @@ -317,7 +350,7 @@ bool HasDynamicExternalValues(Operation* op) { return op ->walk([](Operation* walked_op) { for (Value v : walked_op->getOperands()) { - if (TF::CanBeRefined(v.getType())) { + if (mlir::TF::CanBeRefined(v.getType())) { return WalkResult::interrupt(); } } @@ -330,14 +363,14 @@ bool HasDynamicExternalValues(Operation* op) { // communicated from device->host. This is for the case when all operands have a // static shape. llvm::SmallSetVector GetStaticExternalOperands( - tf_device::ClusterOp device_cluster, + mlir::tf_device::ClusterOp device_cluster, const llvm::SmallSetVector& cluster_ops) { llvm::SmallSetVector external_values; for (Operation* op : cluster_ops) { op->walk([&](Operation* walked_op) { - if (llvm::isa( - walked_op)) + if (llvm::isa(walked_op)) return WalkResult::advance(); for (Value v : walked_op->getOperands()) { if (!tensorflow::TypeValidForXLA(v.getType())) continue; @@ -347,8 +380,8 @@ llvm::SmallSetVector GetStaticExternalOperands( !HasOutsideCompilationAncestor(defining_op) && // Ignore operands that have already been received by a previously // created cluster. - !llvm::isa( - defining_op)) { + !llvm::isa(defining_op)) { external_values.insert(v); } continue; @@ -385,7 +418,7 @@ llvm::SmallSetVector GetAllExternalOperands( // Returns a SmallSetVector containing all of the operands that need to be // communicated from device->host. llvm::SmallSetVector GetExternalOperands( - tf_device::ClusterOp device_cluster, + mlir::tf_device::ClusterOp device_cluster, const llvm::SmallSetVector& cluster_ops) { // If there are any dynamic outputs, get all of the operands which are defined // external to `cluster_ops`. @@ -418,9 +451,11 @@ void GetExternalOutputs(const llvm::SmallSetVector& cluster_ops, if (!user_set.insert(user).second) continue; for (Value v : user->getOperands()) { if (tensorflow::TypeValidForXLA(v.getType()) && - v.getDefiningOp() == op && !isa(user)) + v.getDefiningOp() == op && + !llvm::isa(user)) external_outputs.insert(v); - if (v.getDefiningOp() == op && isa(user)) + if (v.getDefiningOp() == op && + llvm::isa(user)) tmp_host_outputs.push_back(v); } } @@ -464,7 +499,7 @@ LogicalResult GetShardShapedType(Operation* context_op, shape.push_back(in_shape[i]); } shard_type = RankedTensorType::Builder(ranked_type).setShape(shape); - return success(); + return mlir::success(); } // Output `sharding`, which is the sharding of `val`. `context_op` is used for @@ -483,7 +518,7 @@ LogicalResult GetShardingOfValue(Operation* context_op, Value val, << "A map_outside_compilation op's input should have an explicit " "sharding. 
There is no _XlaSharding attribute on the input op."; sharding = sharding_attr.str(); - return success(); + return mlir::success(); } // Create an `_XlaHostComputeMlir` for the map_outside_compilation case. Inputs @@ -508,7 +543,7 @@ LogicalResult CreateHostComputeMap( Type shard_type; if (failed(GetShardShapedType(original_op, num_cores_per_replica, output.getType(), shard_type))) - return failure(); + return mlir::failure(); shard_output_types.push_back(shard_type); full_output_types.push_back(output.getType()); } @@ -522,10 +557,10 @@ LogicalResult CreateHostComputeMap( Type shard_type; if (failed(GetShardShapedType(original_op, num_cores_per_replica, in.getType(), shard_type))) - return failure(); + return mlir::failure(); std::string in_sharding; if (failed(GetShardingOfValue(original_op, in, in_sharding))) - return failure(); + return mlir::failure(); if (common_split_sharding.empty()) { common_split_sharding = std::move(in_sharding); } else { @@ -534,14 +569,14 @@ LogicalResult CreateHostComputeMap( << "All inputs and outputs of map_outside_compilation should " "have the same sharding."; } - auto in_manual = builder.create( + auto in_manual = builder.create( loc, shard_type, in, common_split_sharding, /*dim=*/-1, /*unspecified_dims=*/builder.getI64ArrayAttr({})); manual_inputs.push_back(in_manual); } // Create the _XlaHostComputeMlirOp - auto host_compute = builder.create( + auto host_compute = builder.create( loc, shard_output_types, manual_inputs, /*send_key=*/builder.getStringAttr(args_communication_key), /*recv_key=*/builder.getStringAttr(retvals_communication_key), @@ -556,7 +591,7 @@ LogicalResult CreateHostComputeMap( if (!full_type_ranked) return original_op->emitOpError() << "map_outside_compilation must have ranked outputs"; - auto out_full = builder.create( + auto out_full = builder.create( loc, full_type, out, common_split_sharding, full_type_ranked.getShape(), /*dim=*/-1, /*unspecified_dims=*/builder.getI64ArrayAttr({})); @@ -564,7 +599,7 @@ LogicalResult CreateHostComputeMap( full_outputs.push_back(out_full); } - return success(); + return mlir::success(); } // Create the _XlaHostComputeMlir with `inputs` and `outputs` for the ordinary @@ -581,7 +616,7 @@ void CreateHostComputeNotMap(OpBuilder& builder, Location loc, llvm::SmallVector device_output_types; for (const auto& output : outputs) device_output_types.push_back(output.getType()); - auto host_compute = builder.create( + auto host_compute = builder.create( loc, device_output_types, inputs, builder.getStringAttr(args_communication_key), builder.getStringAttr(retvals_communication_key), @@ -612,7 +647,7 @@ LogicalResult CreateHostCompute( args_communication_key, retvals_communication_key, serialized_func_module, full_outputs, host_compute_out_ops); - return success(); + return mlir::success(); } } @@ -630,17 +665,18 @@ bool ShouldCloseCluster(llvm::ArrayRef outputs) { bool has_dynamic_output = false; bool has_nonxla_output = false; for (Value v : outputs) { - if (TF::CanBeRefined(v.getType())) { + if (mlir::TF::CanBeRefined(v.getType())) { has_dynamic_output = true; for (Operation* user : v.getUsers()) { if (!HasOutsideCompilationAncestor(user) && - !isa(user)) + !llvm::isa(user)) return true; } } if (!tensorflow::TypeValidForXLA(v.getType())) for (const Operation* user : v.getUsers()) - if (!isa(user)) has_nonxla_output = true; + if (!llvm::isa(user)) + has_nonxla_output = true; } return !has_nonxla_output && !has_dynamic_output; @@ -656,7 +692,7 @@ void ReplaceExternalOperandUsage(ArrayRef external_operands, 
Operation* insertion_point, Block* original_op_block) { auto replace_operand_usage = [&](OpOperand& operand) { - if (TF::CanBeRefined(operand.get().getType()) || + if (mlir::TF::CanBeRefined(operand.get().getType()) || HasDynamicOutputs(operand.getOwner())) { return insertion_point->getParentRegion()->isAncestor( operand.getOwner()->getParentRegion()); @@ -675,7 +711,7 @@ void ReplaceExternalOperandUsage(ArrayRef external_operands, bool HasDynamicOutputs(llvm::ArrayRef outputs) { for (Value v : outputs) { - if (TF::CanBeRefined(v.getType())) { + if (mlir::TF::CanBeRefined(v.getType())) { return true; } } @@ -723,7 +759,7 @@ std::pair MakeCommunicationKeys( // Use a unique name when sending just the IfRegion predicate. This is // for readable and to match the key in the TF2XLA bridge. - if (clustered_ops.size() == 1 && llvm::isa(op) && + if (clustered_ops.size() == 1 && llvm::isa(op) && external_operands.size() == 1) { args_communication_key = llvm::formatv("if_predicate_channel_{0}", (communication_key_index)) @@ -786,20 +822,21 @@ void CloneFirstHost(llvm::SmallVector& core_to_mapping, builder.setInsertionPoint(core_to_host_insertion_point[core]); Operation* clone = builder.clone(*op, core_to_mapping[core]); core_to_mapping[core].map(op, clone); - if (auto recv_at_host = llvm::dyn_cast(clone)) { + if (auto recv_at_host = + llvm::dyn_cast(clone)) { recv_at_host.setDeviceOrdinal(core); clone->setOperand(0, core_to_compilation_key[core]); } else if (auto send_from_host = - llvm::dyn_cast(clone)) { + llvm::dyn_cast(clone)) { send_from_host.setDeviceOrdinal(core); clone->setOperand(1, core_to_compilation_key[core]); } else if (auto recv_at_host = - llvm::dyn_cast(clone)) { + llvm::dyn_cast(clone)) { recv_at_host.setOperand(0, core_to_compilation_key[core]); builder.setInsertionPoint(recv_at_host); recv_at_host.setOperand(1, core_to_device_ordinal[core]); } else if (auto send_from_host = - llvm::dyn_cast(clone)) { + llvm::dyn_cast(clone)) { send_from_host.setOperand(1, core_to_compilation_key[core]); builder.setInsertionPoint(send_from_host); send_from_host.setOperand(2, core_to_device_ordinal[core]); @@ -830,8 +867,8 @@ LogicalResult MoveToHostSingleCluster( std::string serialized_func_module; if (HasDynamicOutputs(external_outputs)) { - func::FuncOp shape_op = BuildFunction(clustered_ops, external_operands, - external_outputs, &builder); + FuncOp shape_op = BuildFunction(clustered_ops, external_operands, + external_outputs, &builder); EncapsulateFuncAndSerialize(shape_op, &serialized_func_module); } @@ -843,7 +880,7 @@ LogicalResult MoveToHostSingleCluster( args_communication_key, retvals_communication_key, serialized_func_module, is_map_oc, num_cores_per_replica, common_split_sharding, host_compute_outputs, host_compute_out_ops))) - return failure(); + return mlir::failure(); // Insert ops on the host side computation to receive data from device. // host0_ops are the ops that will make up the first host process. 
In the @@ -881,7 +918,7 @@ LogicalResult MoveToHostSingleCluster( ++communication_key_index; } - return success(); + return mlir::success(); } // Update is_map_oc the true if op has attribute _xla_map_outside_compilation @@ -903,7 +940,7 @@ LogicalResult UpdateIsMapOutsideCompilation(Operation& op, bool control_above, return op.emitOpError() << "map_outside_compilation inside control flow " "is not implemented."; } - return success(); + return mlir::success(); } // Move outside compiled ops in `src` to `insertion_point` in host @@ -920,7 +957,7 @@ LogicalResult UpdateIsMapOutsideCompilation(Operation& op, bool control_above, // program. Currently only map_outside_compilation-only or ordinary // outside_compilation only is supported. LogicalResult MoveToHostMultiCluster( - tf_device::ClusterOp device_cluster, Block* src, + mlir::tf_device::ClusterOp device_cluster, Block* src, ArrayRef core_to_host_insertion_point, ArrayRef core_to_compilation_key, ArrayRef core_to_device_ordinal, int default_device_ordinal, @@ -938,8 +975,8 @@ LogicalResult MoveToHostMultiCluster( // single op except in the case where some of the input/output shapes are // non-static. llvm::SmallSetVector clustered_ops; - auto device_type_attr = - device_cluster->getAttrOfType(TF::kCompileDeviceTypeAttr); + auto device_type_attr = device_cluster->getAttrOfType( + mlir::TF::kCompileDeviceTypeAttr); for (Operation& op : llvm::make_early_inc_range(*src)) { if (HasOutsideCompilationAncestorExclusive(&op) || @@ -947,7 +984,7 @@ LogicalResult MoveToHostMultiCluster( continue; if (failed(UpdateIsMapOutsideCompilation(op, control_above, is_map_oc))) - return failure(); + return mlir::failure(); llvm::SmallSetVector external_outputs; llvm::SmallVector host_outputs; @@ -971,7 +1008,7 @@ LogicalResult MoveToHostMultiCluster( core_to_device_ordinal, default_device_ordinal, device_type_attr, *is_map_oc, num_cores_per_replica, common_split_sharding, communication_key_index))) - return failure(); + return mlir::failure(); clustered_ops.clear(); } @@ -999,18 +1036,18 @@ LogicalResult MoveToHostMultiCluster( core_to_device_ordinal, default_device_ordinal, device_type_attr, *is_map_oc, num_cores_per_replica, common_split_sharding, communication_key_index))) - return failure(); + return mlir::failure(); clustered_ops.clear(); } } - return success(); + return mlir::success(); } void GetReturnValueFromDevice( - tf_device::ClusterOp device_cluster, + mlir::tf_device::ClusterOp device_cluster, const llvm::SmallVector& return_value_from_host, llvm::SmallVector& return_value_from_device) { - if (auto return_op = llvm::dyn_cast_or_null( + if (auto return_op = llvm::dyn_cast_or_null( device_cluster.GetBody().getTerminator())) { for (auto v : return_op.getOperands()) { if (absl::c_count(return_value_from_host, v) == 0) { @@ -1028,14 +1065,14 @@ void GetReturnValueFromDevice( // launch in tf_device.parallel_execute. Uses `compilation_key, // `device_ordinal` and `communication_key_index` when creating communication // ops. 
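Aside on the communication keys referenced above: each device/host transfer is named from a running `communication_key_index` via `llvm::formatv`, with the single-operand IfRegion predicate getting the dedicated readable name shown in the hunk above. A small sketch of that naming (the generic `_args`/`_retvals` scheme here is an illustrative assumption; only the predicate prefix is taken from the diff):

```cpp
#include <string>
#include <utility>
#include "llvm/Support/FormatVariadic.h"

// Returns {args_key, retvals_key} for the `index`-th host/device transfer.
std::pair<std::string, std::string> MakeChannelKeys(int index,
                                                    bool is_if_predicate) {
  if (is_if_predicate) {
    // A lone IfRegion predicate gets a dedicated, readable channel name.
    std::string key = llvm::formatv("if_predicate_channel_{0}", index).str();
    return {key, key};
  }
  return {llvm::formatv("host_compute_channel_{0}_args", index).str(),
          llvm::formatv("host_compute_channel_{0}_retvals", index).str()};
}
```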
-LogicalResult DecomposeControlFlow(tf_device::ClusterOp device_cluster, +LogicalResult DecomposeControlFlow(mlir::tf_device::ClusterOp device_cluster, ArrayRef core_to_compilation_key, ArrayRef core_to_device_ordinal, int default_device_ordinal, int& communication_key_index, std::optional& is_map_oc) { auto result = device_cluster.GetBody().walk([&](Operation* op) { - if (auto if_op = llvm::dyn_cast(op)) { + if (auto if_op = llvm::dyn_cast(op)) { if (!HasOutsideCompilationNested(op)) return WalkResult::advance(); OpBuilder builder(if_op); auto host_if = CloneEmptyIfWithPredicate(if_op, builder); @@ -1057,7 +1094,7 @@ LogicalResult DecomposeControlFlow(tf_device::ClusterOp device_cluster, if_op->setAttr("is_stateless", builder.getBoolAttr(false)); MarkOutsideCompiled(host_if.getOperation()); } - if (auto while_op = llvm::dyn_cast(op)) { + if (auto while_op = llvm::dyn_cast(op)) { if (!HasOutsideCompilationNested(op)) return WalkResult::advance(); OpBuilder builder(while_op); auto host_while = CloneEmptyWhile(while_op.getParallelIterations(), @@ -1071,8 +1108,8 @@ LogicalResult DecomposeControlFlow(tf_device::ClusterOp device_cluster, auto condition = while_op.getCond().front().getTerminator()->getOperand(0); builder.setInsertionPoint(while_op.getCond().front().getTerminator()); - builder.create(while_op.getLoc(), condition, - condition_send_recv_key); + builder.create(while_op.getLoc(), condition, + condition_send_recv_key); // device_ordinal0 is the ordinal of TPU_REPLICATED_CORE_0 and is only // used in the replicated case. Value device_ordinal0 = nullptr; @@ -1082,10 +1119,11 @@ LogicalResult DecomposeControlFlow(tf_device::ClusterOp device_cluster, auto recv_condition_at_host = CreateRecvAtHostOp( builder, while_op.getLoc(), TypeRange{condition.getType()}, core_to_compilation_key[0], device_ordinal0, default_device_ordinal, - device_cluster->getAttrOfType(TF::kCompileDeviceTypeAttr), + device_cluster->getAttrOfType( + mlir::TF::kCompileDeviceTypeAttr), condition_send_recv_key); - builder.create(while_op.getLoc(), - recv_condition_at_host->getResults()); + builder.create(while_op.getLoc(), + recv_condition_at_host->getResults()); if (failed(MoveToHostMultiCluster( device_cluster, &while_op.getCond().front(), @@ -1106,14 +1144,14 @@ LogicalResult DecomposeControlFlow(tf_device::ClusterOp device_cluster, } return WalkResult::advance(); }); - if (result.wasInterrupted()) return failure(); - return success(); + if (result.wasInterrupted()) return mlir::failure(); + return mlir::success(); } // Removes outside compilation from all ops inside `host_launch_op`. Should // only be run after all outside compiled ops have been moved to // `host_launch_op`. -void RemoveOutsideCompilation(tf_device::LaunchOp host_launch_op) { +void RemoveOutsideCompilation(mlir::tf_device::LaunchOp host_launch_op) { host_launch_op.GetBody().walk([&](Operation* op) { if (op->hasAttr(kXlaOutsideCompilationAttr)) { op->removeAttr( @@ -1129,14 +1167,16 @@ void RemoveOutsideCompilation(tf_device::LaunchOp host_launch_op) { // if it is non replicated cluster and there is a device attr with some // non-empty device, then that device's ordinal (0 out of TPU:0 and // 1 out of TPU:1) is extracted and the default ordinal is set to this value. 
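The comment above describes deriving the default ordinal from a device string such as `TPU:0` or `TPU:1`. As a standalone illustration of that idea (plain C++, not the pass's implementation), the ordinal is just the numeric suffix after the last colon:

```cpp
// Sketch: pull the trailing ordinal out of a device string, e.g.
// "/job:worker/replica:0/task:0/device:TPU:1" -> 1.
#include <optional>
#include <string>

std::optional<int> ExtractDeviceOrdinal(const std::string& device) {
  auto pos = device.rfind(':');
  if (pos == std::string::npos || pos + 1 >= device.size()) return std::nullopt;
  int ordinal = 0;
  for (size_t i = pos + 1; i < device.size(); ++i) {
    char c = device[i];
    if (c < '0' || c > '9') return std::nullopt;  // non-numeric suffix
    ordinal = ordinal * 10 + (c - '0');
  }
  return ordinal;
}
```

For example, `ExtractDeviceOrdinal("/device:TPU:1")` yields 1, matching the "1 out of TPU:1" behavior described above.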
-LogicalResult GetDefaultDeviceOrdinal(tf_device::ClusterOp device_cluster, +LogicalResult GetDefaultDeviceOrdinal(mlir::tf_device::ClusterOp device_cluster, int& default_ordinal) { - bool has_replication = device_cluster->hasAttr(TF::kReplicationInfoAttr); + bool has_replication = + device_cluster->hasAttr(mlir::TF::kReplicationInfoAttr); std::string replication_info; if (has_replication) { replication_info = - device_cluster->getAttrOfType(TF::kReplicationInfoAttr) + device_cluster + ->getAttrOfType(mlir::TF::kReplicationInfoAttr) .str(); } if (replication_info == kNoReplicationCluster || replication_info.empty()) { @@ -1156,7 +1196,7 @@ LogicalResult GetDefaultDeviceOrdinal(tf_device::ClusterOp device_cluster, << " could not find ordinal for the given device"; } } - return success(); + return mlir::success(); } // The results of parallel executes is the combination of return values from @@ -1177,7 +1217,7 @@ llvm::SmallVector GetParallelExecuteResultsTypes( // Remap the device cluster results with parallel execute op results llvm::SmallVector GetRemappedTpuClusterResults( - tf_device::ClusterOp device_cluster, + mlir::tf_device::ClusterOp device_cluster, const llvm::SmallVector& return_value_from_host, const llvm::SmallVector& return_value_from_device) { llvm::SmallVector remapped_device_cluster_results; @@ -1187,7 +1227,7 @@ llvm::SmallVector GetRemappedTpuClusterResults( return_value_from_host.size() + return_value_from_device.size()); llvm::SmallDenseMap> return_operand_map; - auto return_op = llvm::dyn_cast( + auto return_op = llvm::dyn_cast( device_cluster.GetBody().getTerminator()); for (OpOperand& operand : return_op->getOpOperands()) { @@ -1221,8 +1261,8 @@ llvm::SmallVector GetRemappedTpuClusterResults( // Remap cluster results with parallel_execute results if user is outside of // parallel_execute. void RemapDeviceClusterResultsWithParallelExecuteResults( - tf_device::ClusterOp device_cluster, - tf_device::ParallelExecuteOp parallel_execute_op, + mlir::tf_device::ClusterOp device_cluster, + mlir::tf_device::ParallelExecuteOp parallel_execute_op, const llvm::SmallVector& return_value_from_host, const llvm::SmallVector& return_value_from_device) { llvm::SmallVector remapped_device_cluster_results = @@ -1261,7 +1301,7 @@ llvm::SmallVector GetNewDeviceTypes( } // Move ops in old device cluster to new device cluster -void MoveOldTpuClusterToNewTpuCluster(tf_device::ClusterOp device_cluster, +void MoveOldTpuClusterToNewTpuCluster(mlir::tf_device::ClusterOp device_cluster, Operation* after_op_r) { for (Operation& op : llvm::make_early_inc_range(device_cluster.GetBody())) { if (&op != device_cluster.GetBody().getTerminator()) { @@ -1271,7 +1311,7 @@ void MoveOldTpuClusterToNewTpuCluster(tf_device::ClusterOp device_cluster, } // Move ops in the tmp host launch op to new host launch op -void MoveTmpLaunchOpToNewLaunchOp(tf_device::LaunchOp tmp_host_launch_op, +void MoveTmpLaunchOpToNewLaunchOp(mlir::tf_device::LaunchOp tmp_host_launch_op, Operation* after_op_host_cluster) { for (Operation& op : llvm::make_early_inc_range(tmp_host_launch_op.GetBody())) { @@ -1285,10 +1325,10 @@ void MoveTmpLaunchOpToNewLaunchOp(tf_device::LaunchOp tmp_host_launch_op, // outside compiled ops, we can create the actual parallel_execute regions. // Still, one region is for the host computation for outside compilation and // the other one is for the original Device cluster computation. 
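As noted in the hunk above, the results of the final `tf_device.parallel_execute` are the combination of the host and device return values. A minimal sketch of assembling its result type list (assuming host results precede the device-cluster results, which matches the device block being the last region):

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"

// Concatenate host-launch result types with device-cluster result types.
llvm::SmallVector<mlir::Type, 4> CombineParallelExecuteResultTypes(
    llvm::ArrayRef<mlir::Value> returns_from_host,
    llvm::ArrayRef<mlir::Value> returns_from_device) {
  llvm::SmallVector<mlir::Type, 4> types;
  types.reserve(returns_from_host.size() + returns_from_device.size());
  for (mlir::Value v : returns_from_host) types.push_back(v.getType());
  for (mlir::Value v : returns_from_device) types.push_back(v.getType());
  return types;
}
```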
-tf_device::ParallelExecuteOp CreateFinalParallelExecuteOp( +mlir::tf_device::ParallelExecuteOp CreateFinalParallelExecuteOp( OpBuilder& builder, int num_regions, ArrayRef core_to_host, - tf_device::ClusterOp device_cluster, - ArrayRef core_to_tmp_host_launch, + mlir::tf_device::ClusterOp device_cluster, + ArrayRef core_to_tmp_host_launch, ArrayRef return_value_from_host, ArrayRef return_value_from_device) { llvm::SmallVector parallel_execute_result_types = @@ -1296,9 +1336,9 @@ tf_device::ParallelExecuteOp CreateFinalParallelExecuteOp( return_value_from_device); builder.setInsertionPoint(device_cluster); - auto parallel_execute_op = builder.create( + auto parallel_execute_op = builder.create( device_cluster.getLoc(), num_regions, parallel_execute_result_types); - SmallVector core_to_host_launch; + SmallVector core_to_host_launch; for (int core = 0; core < core_to_tmp_host_launch.size(); ++core) { Block& host_computation_block = parallel_execute_op.GetRegionBlockWithIndex(core); @@ -1313,14 +1353,14 @@ tf_device::ParallelExecuteOp CreateFinalParallelExecuteOp( llvm::SmallVector host_results; host_results.insert(host_results.end(), return_value_from_host.begin(), return_value_from_host.end()); - tf_device::LaunchOp host_launch_op = CreateLaunchOpForOutsideCluster( + mlir::tf_device::LaunchOp host_launch_op = CreateLaunchOpForOutsideCluster( builder, device_cluster, core_to_host[core], host_results); core_to_host_launch.push_back(host_launch_op); // Create a return op for host computation block builder.setInsertionPointToEnd(&host_computation_block); - builder.create(device_cluster.getLoc(), - host_launch_op->getResults()); + builder.create(device_cluster.getLoc(), + host_launch_op->getResults()); } // Move the launch body to last parallel_execute block. @@ -1337,7 +1377,7 @@ tf_device::ParallelExecuteOp CreateFinalParallelExecuteOp( // Create a empty device cluster op with same attribute but different return // type - auto new_device_cluster = builder.create( + auto new_device_cluster = builder.create( device_cluster.getLoc(), device_result_types, /*operands=*/llvm::ArrayRef{}, device_cluster->getAttrs()); @@ -1345,14 +1385,14 @@ tf_device::ParallelExecuteOp CreateFinalParallelExecuteOp( builder.setInsertionPointToEnd(&new_device_cluster.GetBody()); // Create return op for device computation region in the paralle_execute op - Operation* after_op_r = builder.create( + Operation* after_op_r = builder.create( new_device_cluster.getLoc(), device_results); builder.setInsertionPointToEnd(¶llel_execute_device_block); // Create return op for the new device cluster op - builder.create(device_cluster.getLoc(), - new_device_cluster.getResults()); + builder.create(device_cluster.getLoc(), + new_device_cluster.getResults()); MoveOldTpuClusterToNewTpuCluster(device_cluster, after_op_r); @@ -1371,8 +1411,8 @@ tf_device::ParallelExecuteOp CreateFinalParallelExecuteOp( // a region for `device_cluster` computation by extracting outside compiled ops // to host computation. LogicalResult CreateParallelExecuteForOutsideCompilation( - tf_device::ClusterOp device_cluster, - llvm::SmallVector& ops, + mlir::tf_device::ClusterOp device_cluster, + llvm::SmallVector& ops, std::optional& is_map_oc, ArrayRef core_to_host, bool has_tpu_device) { OpBuilder builder(device_cluster); @@ -1385,10 +1425,11 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( // `map_outside_compilation` case `num_host_regions == num_cores_per_replica`. 
const int num_host_regions = core_to_host.size(); const int num_regions = 1 + num_host_regions; - auto tmp_parallel_execute_op = builder.create( - device_cluster.getLoc(), num_regions, llvm::ArrayRef{}); + auto tmp_parallel_execute_op = + builder.create( + device_cluster.getLoc(), num_regions, llvm::ArrayRef{}); SmallVector core_to_host_insertion_point; - SmallVector core_to_tmp_launch; + SmallVector core_to_tmp_launch; SmallVector compilation_key_ops; SmallVector core_to_compilation_key; SmallVector core_to_device_ordinal_op; @@ -1399,13 +1440,14 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( builder.setInsertionPointToEnd(&tmp_host_computation_block); // Create a single tmp launch op for all outside compiled ops. llvm::SmallVector tmp_host_results; - tf_device::LaunchOp tmp_host_launch_op = CreateLaunchOpForOutsideCluster( - builder, device_cluster, core_to_host[core], tmp_host_results); + mlir::tf_device::LaunchOp tmp_host_launch_op = + CreateLaunchOpForOutsideCluster(builder, device_cluster, + core_to_host[core], tmp_host_results); core_to_tmp_launch.push_back(tmp_host_launch_op); // Create a tmp return op for tmp host computation block builder.setInsertionPointToEnd(&tmp_host_computation_block); - builder.create(device_cluster.getLoc(), - llvm::ArrayRef{}); + builder.create(device_cluster.getLoc(), + llvm::ArrayRef{}); core_to_host_insertion_point.push_back( tmp_host_launch_op.GetBody().getTerminator()); @@ -1418,16 +1460,17 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( compilation_key_op = CreateCompilationKeyPlaceholder(device_cluster.getLoc(), builder); compilation_key = - llvm::dyn_cast( + llvm::dyn_cast( compilation_key_op) .getProgram(); if (has_tpu_device) { - device_ordinal_op = builder.create( - device_cluster.getLoc(), - RankedTensorType::get({}, builder.getI64Type()), - builder.getI64IntegerAttr(core)); + device_ordinal_op = + builder.create( + device_cluster.getLoc(), + RankedTensorType::get({}, builder.getI64Type()), + builder.getI64IntegerAttr(core)); } else { - device_ordinal_op = builder.create( + device_ordinal_op = builder.create( device_cluster.getLoc(), DenseIntElementsAttr::get( RankedTensorType::get({}, builder.getI64Type()), @@ -1436,7 +1479,7 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( compilation_key_ops.push_back(compilation_key_op); core_to_compilation_key.push_back(compilation_key); core_to_device_ordinal_op.push_back(device_ordinal_op); - if (device_cluster->getParentOfType()) + if (device_cluster->getParentOfType()) core_to_device_ordinal.push_back( core_to_device_ordinal_op[core]->getResults()[0]); } @@ -1444,7 +1487,7 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( builder.setInsertionPoint(tmp_parallel_execute_op); int default_device_ordinal = 0; if (failed(GetDefaultDeviceOrdinal(device_cluster, default_device_ordinal))) { - return failure(); + return mlir::failure(); } // communication_key_index is part of the message identifier and is // incremented for each _XlaHostComputeMlir. @@ -1455,7 +1498,7 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( if (failed(DecomposeControlFlow( device_cluster, core_to_compilation_key, core_to_device_ordinal, default_device_ordinal, communication_key_index, is_map_oc))) - return failure(); + return mlir::failure(); // Move all outside compiled ops including control flow to tmp host launch. // Also set the values returned from the host when ops are moved. 
@@ -1465,7 +1508,7 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( core_to_device_ordinal, default_device_ordinal, /*control_above=*/false, is_map_oc, communication_key_index, &returns_from_host))) - return failure(); + return mlir::failure(); llvm::SmallVector returns_from_device; GetReturnValueFromDevice(device_cluster, returns_from_host, @@ -1477,10 +1520,10 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( if (communication_key_index == 0 || core_to_device_ordinal.empty()) for (auto op : core_to_device_ordinal_op) op->erase(); - for (tf_device::LaunchOp tmp_host_launch_op : core_to_tmp_launch) + for (mlir::tf_device::LaunchOp tmp_host_launch_op : core_to_tmp_launch) RemoveOutsideCompilation(tmp_host_launch_op); - tf_device::ParallelExecuteOp parallel_execute_op = + mlir::tf_device::ParallelExecuteOp parallel_execute_op = CreateFinalParallelExecuteOp(builder, num_regions, core_to_host, device_cluster, core_to_tmp_launch, returns_from_host, returns_from_device); @@ -1494,12 +1537,12 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( device_cluster.erase(); - return success(); + return mlir::success(); } // Check that cluster results are valid. An result is invalid when it does not // have a valid XLA type. -LogicalResult CheckClusterResults(tf_device::ClusterOp cluster) { +LogicalResult CheckClusterResults(mlir::tf_device::ClusterOp cluster) { for (OpResult result : cluster.getResults()) { if (!tensorflow::TypeValidForXLA(result.getType())) { return cluster.emitError() @@ -1508,14 +1551,14 @@ LogicalResult CheckClusterResults(tf_device::ClusterOp cluster) { << result.getType(); } } - return success(); + return mlir::success(); } // Check that op marked for outside compilation has an ancestor also marked for // outside compilation. LogicalResult CheckAncestorNotOutsideComp(Operation* op) { if (!op->getAttrOfType(kXlaOutsideCompilationAttr)) - return success(); + return mlir::success(); Operation* iter_op = op; while (auto* parent_op = iter_op->getParentOp()) { if (parent_op->getAttrOfType(kXlaOutsideCompilationAttr)) { @@ -1526,7 +1569,7 @@ LogicalResult CheckAncestorNotOutsideComp(Operation* op) { } iter_op = parent_op; } - return success(); + return mlir::success(); } // Check the validity of the module, pre-pass. @@ -1535,18 +1578,18 @@ LogicalResult CheckPreconditions(ModuleOp module) { if (failed(CheckAncestorNotOutsideComp(op))) return WalkResult::interrupt(); return WalkResult::advance(); }); - if (walk_result.wasInterrupted()) return failure(); - return success(); + if (walk_result.wasInterrupted()) return mlir::failure(); + return mlir::success(); } // Check the validity of the module, post-pass. 
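Both the pre- and post-condition checks below follow the same walk-and-interrupt validation idiom: walk the module, emit an error and interrupt on the first violation, then convert the interrupt into a failure. A generic sketch of that idiom (the attribute check here is only an example):

```cpp
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LogicalResult.h"

// Fails if any op in `module` carries the attribute `attr_name`.
mlir::LogicalResult CheckNoUnexpectedAttr(mlir::ModuleOp module,
                                          llvm::StringRef attr_name) {
  auto walk_result = module.walk([&](mlir::Operation* op) {
    if (op->hasAttr(attr_name)) {
      op->emitOpError() << "unexpected attribute '" << attr_name << "'";
      return mlir::WalkResult::interrupt();
    }
    return mlir::WalkResult::advance();
  });
  return walk_result.wasInterrupted() ? mlir::failure() : mlir::success();
}
```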
LogicalResult CheckPostconditions(ModuleOp module) { - auto walk_result = module.walk([&](tf_device::ClusterOp cluster) { + auto walk_result = module.walk([&](mlir::tf_device::ClusterOp cluster) { if (failed(CheckClusterResults(cluster))) return WalkResult::interrupt(); return WalkResult::advance(); }); - if (walk_result.wasInterrupted()) return failure(); - return success(); + if (walk_result.wasInterrupted()) return mlir::failure(); + return mlir::success(); } void ExtractOutsideCompilation::runOnOperation() { @@ -1558,10 +1601,11 @@ void ExtractOutsideCompilation::runOnOperation() { if (failed(tensorflow::GetDevicesFromOp(module, &devices))) return signalPassFailure(); - llvm::SmallVector tmp_parallel_execute_ops; + llvm::SmallVector + tmp_parallel_execute_ops; std::optional is_map_oc; - module.walk([&](tf_device::ClusterOp device_cluster) { + module.walk([&](mlir::tf_device::ClusterOp device_cluster) { if (HasOutsideCompilationNested(device_cluster.getOperation())) { SmallVector core_to_host; if (failed(tensorflow::GetDeviceToHostMap(device_cluster, core_to_host))) @@ -1594,5 +1638,6 @@ std::unique_ptr> CreateExtractOutsideCompilationPass() { return std::make_unique(); } -} // namespace TFDevice -} // namespace mlir +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc similarity index 70% rename from tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc rename to tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc index ca68e36e581443..6a38f620377cf2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,30 +19,61 @@ limitations under the License. 
#include #include +#include "absl/log/log.h" +#include "absl/strings/str_join.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Rewrite/PatternApplicator.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/string_util.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "tensorflow/core/lib/monitoring/gauge.h" -namespace mlir { -namespace TFDevice { +namespace tensorflow { +namespace tf2xla { +namespace internal { namespace { +using mlir::Block; +using mlir::BoolAttr; +using mlir::Dialect; +using mlir::LogicalResult; +using mlir::MLIRContext; +using mlir::ModuleOp; +using mlir::Operation; +using mlir::OperationName; +using mlir::OperationPass; +using mlir::Pattern; +using mlir::PatternApplicator; +using mlir::RewritePatternSet; +using mlir::StringAttr; +using mlir::TensorType; +using mlir::Type; +using mlir::Value; +using mlir::WalkResult; + constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; constexpr char kAllowSoftPlacementAttr[] = "allow_soft_placement"; @@ -52,7 +83,7 @@ auto* auto_outside_compilation_gauge = "Tracks if auto outside compilation is enabled"); #define GEN_PASS_DEF_MARKOPSFOROUTSIDECOMPILATIONPASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" struct MarkOpsForOutsideCompilation : public impl::MarkOpsForOutsideCompilationPassBase< @@ -79,16 +110,17 @@ void AddCanonicalizationPatterns(MLIRContext* context, void AddSupportedOpsUsingFolding(MLIRContext* context, llvm::DenseSet* supported_ops) { llvm::SmallDenseSet allowlist_ops = { - OperationName(TF::BroadcastArgsOp::getOperationName(), context), - OperationName(TF::BroadcastGradientArgsOp::getOperationName(), context), - OperationName(TF::ConcatOffsetOp::getOperationName(), context), - OperationName(TF::EmptyOp::getOperationName(), context), - OperationName(TF::ListDiffOp::getOperationName(), context), - OperationName(TF::RankOp::getOperationName(), context), - OperationName(TF::RangeOp::getOperationName(), context), - OperationName(TF::ShapeOp::getOperationName(), 
context), - OperationName(TF::ShapeNOp::getOperationName(), context), - OperationName(TF::SizeOp::getOperationName(), context), + OperationName(mlir::TF::BroadcastArgsOp::getOperationName(), context), + OperationName(mlir::TF::BroadcastGradientArgsOp::getOperationName(), + context), + OperationName(mlir::TF::ConcatOffsetOp::getOperationName(), context), + OperationName(mlir::TF::EmptyOp::getOperationName(), context), + OperationName(mlir::TF::ListDiffOp::getOperationName(), context), + OperationName(mlir::TF::RankOp::getOperationName(), context), + OperationName(mlir::TF::RangeOp::getOperationName(), context), + OperationName(mlir::TF::ShapeOp::getOperationName(), context), + OperationName(mlir::TF::ShapeNOp::getOperationName(), context), + OperationName(mlir::TF::SizeOp::getOperationName(), context), }; supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end()); @@ -102,14 +134,16 @@ void AddSupportedOpsUsingFolding(MLIRContext* context, void AddOldBridgeOnlyOps(MLIRContext* context, llvm::DenseSet* supported_ops) { llvm::SmallDenseSet allowlist_ops = { - OperationName(TF::DynamicPartitionOp::getOperationName(), context), - OperationName(TF::OutfeedEnqueueOp::getOperationName(), context), - OperationName(TF::WhereOp::getOperationName(), context), - OperationName(TF::UniqueOp::getOperationName(), context), - OperationName(TF::XlaSetDynamicDimensionSizeOp::getOperationName(), + OperationName(mlir::TF::DynamicPartitionOp::getOperationName(), context), + OperationName(mlir::TF::OutfeedEnqueueOp::getOperationName(), context), + OperationName(mlir::TF::WhereOp::getOperationName(), context), + OperationName(mlir::TF::UniqueOp::getOperationName(), context), + OperationName(mlir::TF::XlaSetDynamicDimensionSizeOp::getOperationName(), + context), + OperationName(mlir::TF::XlaSpmdFullToShardShapeOp::getOperationName(), + context), + OperationName(mlir::TF::XlaSpmdShardToFullShapeOp::getOperationName(), context), - OperationName(TF::XlaSpmdFullToShardShapeOp::getOperationName(), context), - OperationName(TF::XlaSpmdShardToFullShapeOp::getOperationName(), context), }; supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end()); @@ -120,46 +154,46 @@ void AddOldBridgeOnlyOps(MLIRContext* context, void AddSupportedFunctionalOps(MLIRContext* context, llvm::DenseSet* supported_ops) { supported_ops->insert( - OperationName(TF::CaseRegionOp::getOperationName(), context)); - supported_ops->insert( - OperationName(TF::IfRegionOp::getOperationName(), context)); - supported_ops->insert( - OperationName(TF::InplaceAddOp::getOperationName(), context)); - supported_ops->insert( - OperationName(TF::WhileRegionOp::getOperationName(), context)); - supported_ops->insert( - OperationName(TF::XlaCallModuleOp::getOperationName(), context)); + OperationName(mlir::TF::CaseRegionOp::getOperationName(), context)); supported_ops->insert( - OperationName(TF::XlaReduceOp::getOperationName(), context)); + OperationName(mlir::TF::IfRegionOp::getOperationName(), context)); supported_ops->insert( - OperationName(TF::XlaReduceWindowOp::getOperationName(), context)); + OperationName(mlir::TF::InplaceAddOp::getOperationName(), context)); supported_ops->insert( - OperationName(TF::XlaRngBitGeneratorOp::getOperationName(), context)); + OperationName(mlir::TF::WhileRegionOp::getOperationName(), context)); supported_ops->insert( - OperationName(TF::XlaScatterOp::getOperationName(), context)); + OperationName(mlir::TF::XlaCallModuleOp::getOperationName(), context)); supported_ops->insert( - 
OperationName(TF::XlaSelectAndScatterOp::getOperationName(), context)); + OperationName(mlir::TF::XlaReduceOp::getOperationName(), context)); supported_ops->insert( - OperationName(TF::SymbolicGradientOp::getOperationName(), context)); + OperationName(mlir::TF::XlaReduceWindowOp::getOperationName(), context)); + supported_ops->insert(OperationName( + mlir::TF::XlaRngBitGeneratorOp::getOperationName(), context)); supported_ops->insert( - OperationName(TF::XlaVariadicReduceOp::getOperationName(), context)); + OperationName(mlir::TF::XlaScatterOp::getOperationName(), context)); + supported_ops->insert(OperationName( + mlir::TF::XlaSelectAndScatterOp::getOperationName(), context)); supported_ops->insert( - OperationName(TF::XlaVariadicReduceV2Op::getOperationName(), context)); + OperationName(mlir::TF::SymbolicGradientOp::getOperationName(), context)); + supported_ops->insert(OperationName( + mlir::TF::XlaVariadicReduceOp::getOperationName(), context)); + supported_ops->insert(OperationName( + mlir::TF::XlaVariadicReduceV2Op::getOperationName(), context)); supported_ops->insert( - OperationName(TF::XlaVariadicSortOp::getOperationName(), context)); + OperationName(mlir::TF::XlaVariadicSortOp::getOperationName(), context)); supported_ops->insert( - OperationName(TF::XlaReplicaIdOp::getOperationName(), context)); + OperationName(mlir::TF::XlaReplicaIdOp::getOperationName(), context)); supported_ops->insert( - OperationName(TF::YieldOp::getOperationName(), context)); + OperationName(mlir::TF::YieldOp::getOperationName(), context)); } // These embedding ops are rewritten when running TPUCompileOp. void AddRewrittenEmbeddingOps(MLIRContext* context, llvm::DenseSet* supported_ops) { supported_ops->insert(OperationName( - TF::RecvTPUEmbeddingActivationsOp::getOperationName(), context)); + mlir::TF::RecvTPUEmbeddingActivationsOp::getOperationName(), context)); supported_ops->insert(OperationName( - TF::SendTPUEmbeddingGradientsOp::getOperationName(), context)); + mlir::TF::SendTPUEmbeddingGradientsOp::getOperationName(), context)); } // Stack, TensorList and TensorArray ops are rewritten during the second phase @@ -171,32 +205,32 @@ void AddRewrittenCompositeOps(MLIRContext* context, #define GET_OPERATION_NAME(op) OperationName(op::getOperationName(), context) llvm::SmallDenseSet allowlist_ops = { // Stack ops. - GET_OPERATION_NAME(TF::StackV2Op), - GET_OPERATION_NAME(TF::StackPushV2Op), - GET_OPERATION_NAME(TF::StackPopV2Op), + GET_OPERATION_NAME(mlir::TF::StackV2Op), + GET_OPERATION_NAME(mlir::TF::StackPushV2Op), + GET_OPERATION_NAME(mlir::TF::StackPopV2Op), // Tensor Array ops. - GET_OPERATION_NAME(TF::TensorArrayV3Op), - GET_OPERATION_NAME(TF::TensorArrayReadV3Op), - GET_OPERATION_NAME(TF::TensorArrayWriteV3Op), - GET_OPERATION_NAME(TF::TensorArrayConcatV3Op), - GET_OPERATION_NAME(TF::TensorArraySplitV3Op), - GET_OPERATION_NAME(TF::TensorArraySizeV3Op), - GET_OPERATION_NAME(TF::TensorArrayGradV3Op), - GET_OPERATION_NAME(TF::TensorArrayGatherV3Op), - GET_OPERATION_NAME(TF::TensorArrayScatterV3Op), + GET_OPERATION_NAME(mlir::TF::TensorArrayV3Op), + GET_OPERATION_NAME(mlir::TF::TensorArrayReadV3Op), + GET_OPERATION_NAME(mlir::TF::TensorArrayWriteV3Op), + GET_OPERATION_NAME(mlir::TF::TensorArrayConcatV3Op), + GET_OPERATION_NAME(mlir::TF::TensorArraySplitV3Op), + GET_OPERATION_NAME(mlir::TF::TensorArraySizeV3Op), + GET_OPERATION_NAME(mlir::TF::TensorArrayGradV3Op), + GET_OPERATION_NAME(mlir::TF::TensorArrayGatherV3Op), + GET_OPERATION_NAME(mlir::TF::TensorArrayScatterV3Op), // Tensor List Ops. 
- GET_OPERATION_NAME(TF::EmptyTensorListOp), - GET_OPERATION_NAME(TF::TensorListReserveOp), - GET_OPERATION_NAME(TF::TensorListFromTensorOp), - GET_OPERATION_NAME(TF::TensorListPushBackOp), - GET_OPERATION_NAME(TF::TensorListPopBackOp), - GET_OPERATION_NAME(TF::TensorListGetItemOp), - GET_OPERATION_NAME(TF::TensorListSetItemOp), - GET_OPERATION_NAME(TF::TensorListLengthOp), - GET_OPERATION_NAME(TF::TensorListElementShapeOp), - GET_OPERATION_NAME(TF::TensorListGatherOp), - GET_OPERATION_NAME(TF::TensorListScatterIntoExistingListOp), - GET_OPERATION_NAME(TF::TensorListStackOp), + GET_OPERATION_NAME(mlir::TF::EmptyTensorListOp), + GET_OPERATION_NAME(mlir::TF::TensorListReserveOp), + GET_OPERATION_NAME(mlir::TF::TensorListFromTensorOp), + GET_OPERATION_NAME(mlir::TF::TensorListPushBackOp), + GET_OPERATION_NAME(mlir::TF::TensorListPopBackOp), + GET_OPERATION_NAME(mlir::TF::TensorListGetItemOp), + GET_OPERATION_NAME(mlir::TF::TensorListSetItemOp), + GET_OPERATION_NAME(mlir::TF::TensorListLengthOp), + GET_OPERATION_NAME(mlir::TF::TensorListElementShapeOp), + GET_OPERATION_NAME(mlir::TF::TensorListGatherOp), + GET_OPERATION_NAME(mlir::TF::TensorListScatterIntoExistingListOp), + GET_OPERATION_NAME(mlir::TF::TensorListStackOp), }; #undef GET_OPERATION_NAME @@ -204,13 +238,13 @@ void AddRewrittenCompositeOps(MLIRContext* context, } bool IsStringType(Type type) { - if (type.isa()) return true; + if (type.isa()) return true; - auto sub_type = type.dyn_cast(); + auto sub_type = type.dyn_cast(); if (!sub_type) return false; bool has_string = llvm::any_of(sub_type.GetSubtypes(), [](TensorType type) { - return type.getElementType().isa(); + return type.getElementType().isa(); }); return has_string; } @@ -241,11 +275,10 @@ bool MatchesPattern(Operation& op, bool IsSupportedOp(Operation& op, const llvm::DenseSet& supported_ops, const Dialect* tf_dialect) { - if (op.getDialect() != tf_dialect) - return true; + if (op.getDialect() != tf_dialect) return true; // Assert has a legalization that later removes it so we don't want to outside // compile it ever for performance reasons. - if (llvm::isa(op)) return true; + if (llvm::isa(op)) return true; if (HasStringOperand(op)) return false; if (HasStringResult(op)) return false; @@ -253,25 +286,11 @@ bool IsSupportedOp(Operation& op, auto abstractOp = op.getRegisteredInfo(); if (!abstractOp) return false; - return mhlo::HasTf2XlaFallback(abstractOp->getTypeID()); -} - -// Checks all regions of `op` for captured string operands. -bool HasCapturedStringOperand(Operation* op) { - bool string_operand = false; - for (auto& region : op->getRegions()) { - mlir::visitUsedValuesDefinedAbove( - region, region, [&](mlir::OpOperand* operand) { - if (getElementTypeOrSelf(operand->get()).isa()) - string_operand = true; - }); - if (string_operand) return string_operand; - } - return string_operand; + return mlir::mhlo::HasTf2XlaFallback(abstractOp->getTypeID()); } bool IsVariant(Value value) { - return getElementTypeOrSelf(value.getType()).isa(); + return getElementTypeOrSelf(value.getType()).isa(); } bool HasOutsideCompiledAncestor(Operation* op) { @@ -287,7 +306,7 @@ bool HasOutsideCompiledAncestor(Operation* op) { // If any tf.variants are inputs/outputs to the another outside compiled // Operation, `op`, mark them for outside compilation unless they are already // marks with outside compilation attribute. 
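The comment above describes propagating the outside-compilation marking along variant-typed def/use edges with a worklist, which `MarkVariantInputsOutputs` below implements with a `std::queue`. A simplified, self-contained sketch of that propagation (the attribute value "auto" and the callback parameter are illustrative; the real pass also checks terminator traits and outside-compiled ancestors):

```cpp
#include <queue>
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

constexpr char kOcAttr[] = "_xla_outside_compilation";

void MarkOp(mlir::Operation* op) {
  op->setAttr(kOcAttr, mlir::StringAttr::get(op->getContext(), "auto"));
}

// Starting from ops already marked for outside compilation, mark producers of
// variant operands and users of variant results, transitively.
void PropagateAlongVariants(std::queue<mlir::Operation*>& worklist,
                            llvm::function_ref<bool(mlir::Value)> is_variant) {
  while (!worklist.empty()) {
    mlir::Operation* op = worklist.front();
    worklist.pop();
    for (mlir::Value operand : op->getOperands()) {
      mlir::Operation* def = operand.getDefiningOp();
      if (!def || !is_variant(operand) || def->hasAttr(kOcAttr)) continue;
      MarkOp(def);
      worklist.push(def);
    }
    for (mlir::Value result : op->getResults()) {
      if (!is_variant(result)) continue;
      for (mlir::Operation* user : result.getUsers()) {
        if (user->hasAttr(kOcAttr)) continue;
        MarkOp(user);
        worklist.push(user);
      }
    }
  }
}
```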
-void MarkVariantInputsOutputs(tf_device::ClusterOp tpu_cluster) { +void MarkVariantInputsOutputs(mlir::tf_device::ClusterOp tpu_cluster) { std::queue outside_compiled_ops; tpu_cluster.walk([&](Operation* op) { if (op->hasAttrOfType(kXlaOutsideCompilationAttr)) @@ -316,7 +335,7 @@ void MarkVariantInputsOutputs(tf_device::ClusterOp tpu_cluster) { for (auto value : op->getResults()) { if (IsVariant(value)) { for (auto user : value.getUsers()) { - if (!user->hasTrait() && + if (!user->hasTrait() && !HasOutsideCompiledAncestor(user) && !user->getAttrOfType(kXlaOutsideCompilationAttr)) { user->setAttr(kXlaOutsideCompilationAttr, @@ -358,7 +377,7 @@ LogicalResult MarkUncompilableOps( if (outside_compiled_cluster_counter > 0) { auto_outside_compilation_gauge->GetCell()->Set(true); } - return success(); + return mlir::success(); } // Check for uncompilable ops that are in `tf_dialect` and are not already @@ -369,7 +388,7 @@ bool ContainsUncompilableOps(const Dialect* tf_dialect, Block* block, // Check if op or any parent is already marked for outside compilation. block->walk([&](Operation* op) { Operation* iter_op = op; - while (iter_op && !llvm::isa(iter_op)) { + while (iter_op && !llvm::isa(iter_op)) { if (iter_op->hasAttrOfType(kXlaOutsideCompilationAttr)) { return; } @@ -444,9 +463,9 @@ void MarkOpsForOutsideCompilation::runOnOperation() { return signalPassFailure(); } RewritePatternSet patterns(&getContext()); - mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns); - TF::PopulateTFLoweringBeforeHLOPatterns(module.getContext(), &patterns); - TF::PopulateLoweringQuantizedPatterns(module.getContext(), &patterns); + mlir::mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns); + mlir::TF::PopulateTFLoweringBeforeHLOPatterns(module.getContext(), &patterns); + mlir::TF::PopulateLoweringQuantizedPatterns(module.getContext(), &patterns); AddCanonicalizationPatterns(module.getContext(), &patterns); // `supported_ops` contains the name of all of the ops that can potentially be @@ -465,7 +484,7 @@ void MarkOpsForOutsideCompilation::runOnOperation() { AddRewrittenEmbeddingOps(module.getContext(), &supported_ops); AddRewrittenCompositeOps(module.getContext(), &supported_ops); - auto result = module.walk([&](tf_device::ClusterOp cluster) { + auto result = module.walk([&](mlir::tf_device::ClusterOp cluster) { // Only if `allow_soft_placement` attribute is true should we mark ops // for outside compilation. auto soft_placement_attr = @@ -498,5 +517,6 @@ CreateMarkOpsForOutsideCompilationPass() { return std::make_unique(); } -} // namespace TFDevice -} // namespace mlir +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/dialect_to_executor_passes.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.h similarity index 76% rename from tensorflow/compiler/mlir/tf2xla/internal/passes/dialect_to_executor_passes.h rename to tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.h index 74247860fcd36e..4e28930b3c1f8e 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/dialect_to_executor_passes.h +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.h @@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_DIALECT_TO_EXECUTOR_PASSES_H_ -#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_DIALECT_TO_EXECUTOR_PASSES_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_MLIR_TO_GRAPH_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_MLIR_TO_GRAPH_PASSES_H_ #include @@ -25,8 +25,11 @@ namespace internal { std::unique_ptr> CreateVerifyInputDialectToExecutorPass(); +#define GEN_PASS_REGISTRATION #define GEN_PASS_DECL_VERIFYINPUTDIALECTTOEXECUTORPASS +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.h.inc" + } // namespace internal } // namespace tf2xla } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_DIALECT_TO_EXECUTOR_PASSES_H_ +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_MLIR_TO_GRAPH_PASSES_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/dialect_to_executor_passes.td b/tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.td similarity index 90% rename from tensorflow/compiler/mlir/tf2xla/internal/passes/dialect_to_executor_passes.td rename to tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.td index 9c7891daa84c6b..a8796805753144 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/dialect_to_executor_passes.td +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.td @@ -1,4 +1,3 @@ - /* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,10 +11,10 @@ limitations under the License. ==============================================================================*/ include "mlir/Pass/PassBase.td" -def VerifyInputDialectToExecutor : Pass<"verify-input-dialect-to-executor-pass", "mlir::func::FuncOp"> { +def VerifyInputDialectToExecutorPass : Pass<"verify-input-dialect-to-executor-pass", "mlir::func::FuncOp"> { let summary = "Verify that TF dialect to executor converter receives the correct input."; let description = [{ Verifies the input before exporting to TF executor. This includes checking whether the Ops are in TF functional, have device attributes & there are no tf_device.cluster_func ops. }]; let constructor = "tensorflow::tf2xla::internal::CreateVerifyInputDialectToExecutorPass()"; -} +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass.cc index 235a7ca1ec5468..1cf9115d9572a2 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass.cc @@ -14,13 +14,15 @@ limitations under the License. 
==============================================================================*/ #include -#include #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.h" namespace tensorflow { namespace tf2xla { @@ -31,6 +33,9 @@ namespace { #define GEN_PASS_DEF_VERIFYCLUSTERINGPASS #include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" +using mlir::Operation; +using mlir::WalkResult; + class VerifyClusteringPass : public impl::VerifyClusteringPassBase { public: @@ -38,20 +43,26 @@ class VerifyClusteringPass }; void VerifyClusteringPass::runOnOperation() { - std::set valid_namespaces = {"tf", "func", "return", "tf_device", - "builtin"}; - mlir::Operation* func_op = getOperation(); + Operation* func_op = getOperation(); - auto walk_result = func_op->walk([&](mlir::Operation* op) { - if (valid_namespaces.find(op->getDialect()->getNamespace().str()) == - valid_namespaces.end()) { + auto walk_result = func_op->walk([&](Operation* op) { + if (!tensorflow::tf2xla::internal::IsInBridgeAcceptableDialects(op)) { std::string error = "op is in dialect " + op->getDialect()->getNamespace().str() + " not in tf functional dialect"; op->emitError() << error; + return WalkResult::interrupt(); + } + + if (op->hasAttr(mlir::TF::kXlaOutsideCompilationAttr)) { + std::string error = + "op has outside compilation attribute _xla_outside_compilation which " + "is not allowed after clustering"; + op->emitError() << error; return mlir::WalkResult::interrupt(); } - return mlir::WalkResult::advance(); + + return WalkResult::advance(); }); if (walk_result.wasInterrupted()) { diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir index 23e60242621f37..7ba98798c126df 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir @@ -13,4 +13,14 @@ func.func @testNotTfDialect(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32> func.func @testTFDialect(%arg0: tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!tf_type.string> { %0 = "tf.Identity"(%arg0) : (tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!tf_type.string> func.return %0 : tensor<4x2x!tf_type.string> -} \ No newline at end of file +} + + +// ----- + +func.func @testTFDialect(%arg0: tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!tf_type.string> { + // expected-error@below {{op has outside compilation attribute _xla_outside_compilation which is not allowed after clustering}} + %0 = "tf.Identity"(%arg0) {_xla_outside_compilation = "cluster1"}: (tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!tf_type.string> + func.return %0 : tensor<4x2x!tf_type.string> +} + diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass.cc new file mode 100644 index 00000000000000..53c1e5bab16ad0 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass.cc @@ -0,0 +1,84 @@ +/* 
Copyright 2023 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +namespace { +using mlir::Operation; +using mlir::OperationPass; +using mlir::WalkResult; +using mlir::func::FuncOp; + +#define GEN_PASS_DEF_VERIFYINPUTDIALECTTOEXECUTORPASS +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.h.inc" + +class VerifyInputDialectToExecutorPass + : public impl::VerifyInputDialectToExecutorPassBase< + VerifyInputDialectToExecutorPass> { + public: + void runOnOperation() override; +}; + +bool IsTfDeviceClusterFuncOp(Operation* op) { + std::string kClusterFuncOpName = "tf_device.cluster_func"; + return op->getName().getStringRef().str() == kClusterFuncOpName; +} + +void VerifyInputDialectToExecutorPass::runOnOperation() { + Operation* func_op = getOperation(); + + auto walk_result = func_op->walk([&](Operation* op) { + if (!tensorflow::tf2xla::internal::IsInBridgeAcceptableDialects(op)) { + std::string error = "op is in dialect " + + op->getDialect()->getNamespace().str() + + " which is not an accepted dialect"; + op->emitError() << error; + return WalkResult::interrupt(); + } + + if (IsTfDeviceClusterFuncOp(op)) { + std::string error = + "failed TF functional to executor validation, op " + "tf_device.cluster_func is not allowed"; + op->emitError() << error; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + + if (walk_result.wasInterrupted()) { + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> +CreateVerifyInputDialectToExecutorPass() { + return std::make_unique(); +} + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir new file mode 100644 index 00000000000000..5a6fda697d23fa --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir @@ -0,0 +1,34 @@ +// RUN: tf-opt -verify-input-dialect-to-executor-pass -split-input-file -verify-diagnostics %s | FileCheck %s +// Tests the VerifyClusteringPass Pass, ensures that an error is thrown when validation fails. 
+ +// ----- + +// CHECK-LABEL: func @testNoClusterFuncOpPasses +func.func @testNoClusterFuncOpPasses(%arg0: tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!tf_type.string> { + %0 = "tf.Identity"(%arg0) : (tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!tf_type.string> + func.return %0 : tensor<4x2x!tf_type.string> +} + +// ----- + +func.func @testClusterFuncOpFails(%arg0: tensor) -> tensor { + // expected-error@below {{failed TF functional to executor validation, op tf_device.cluster_func is not allowed}} + %cluster = "tf_device.cluster_func"(%arg0) {func = @_func} : (tensor) -> tensor + func.return %cluster : tensor +} + +// ----- + +// CHECK-LABEL: func @testTFDialect +func.func @testTFDialect(%arg0: tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!tf_type.string> { + %0 = "tf.Identity"(%arg0) : (tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!tf_type.string> + func.return %0 : tensor<4x2x!tf_type.string> +} + +// ----- + +func.func @testNotTfDialect(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + // expected-error@below {{op is in dialect chlo which is not an accepted dialect}} + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + func.return %0 : tensor<1x32x10x32xi32> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_cluster_formation.cc similarity index 62% rename from tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc rename to tensorflow/compiler/mlir/tf2xla/internal/passes/xla_cluster_formation.cc index f99b754074f568..cbedce815b8229 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_cluster_formation.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,23 +14,47 @@ limitations under the License. 
==============================================================================*/ #include -#include +#include +#include #include +#include "absl/strings/str_cat.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.h" #include "tensorflow/core/common_runtime/inline_function_utils.h" -namespace mlir { - -namespace { +namespace tensorflow { +namespace tf2xla { +namespace internal { + +using mlir::Block; +using mlir::CallInterfaceCallable; +using mlir::CallOpInterface; +using mlir::ModuleOp; +using mlir::OpBuilder; +using mlir::Operation; +using mlir::OperationPass; +using mlir::SymbolTable; +using mlir::SymbolTableCollection; +using mlir::SymbolUserOpInterface; +using mlir::func::FuncOp; #define GEN_PASS_DEF_XLACLUSTERFORMATIONPASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.h.inc" +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" constexpr char kAllowSoftPlacementAttr[] = "allow_soft_placement"; @@ -47,22 +71,22 @@ void CopyAttribute(const llvm::StringRef attr, Operation *src, } } -std::string getClusterOutlinedFunctionName(func::FuncOp func) { +std::string getClusterOutlinedFunctionName(FuncOp func) { return func.getSymName().str() + "_cluster_func"; } -void AddClusterAttributes(OpBuilder &builder, func::FuncOp entry_func, - tf_device::ClusterOp cluster) { - TF::CopyDeviceAndUnderscoredAttributes(entry_func, cluster); +void AddClusterAttributes(OpBuilder &builder, FuncOp entry_func, + mlir::tf_device::ClusterOp cluster) { + mlir::TF::CopyDeviceAndUnderscoredAttributes(entry_func, cluster); CopyAttribute(kAllowSoftPlacementAttr, entry_func, cluster); cluster->setAttr( - TF::kClusterOutlinedFunctionNameAttr, + mlir::TF::kClusterOutlinedFunctionNameAttr, builder.getStringAttr(getClusterOutlinedFunctionName(entry_func))); } // Wrap the body of `func` in a device cluster. `func` must have a single // region and a single block. -LogicalResult EncapsulateEntryFunctionBody(func::FuncOp entry_func) { +mlir::LogicalResult EncapsulateEntryFunctionBody(FuncOp entry_func) { // We've verified the input graph has single-entry and single-block entry // functions. This is just in case passes in the pipeline uninteionally break // the assumption, and not expected to happen in practice. 
@@ -70,7 +94,7 @@ LogicalResult EncapsulateEntryFunctionBody(func::FuncOp entry_func) { entry_func->emitError() << "TF2XLA MLIR CPU/GPU MLIR phase 1 bridge " "expects single region and single " "block in an entry function."; - return failure(); + return mlir::failure(); } std::vector ops_without_terminator; for (auto &op : entry_func.front().without_terminator()) { @@ -79,36 +103,39 @@ LogicalResult EncapsulateEntryFunctionBody(func::FuncOp entry_func) { Operation *original_return_op = entry_func.front().getTerminator(); OpBuilder builder(entry_func.getContext()); builder.setInsertionPointToEnd(&entry_func.front()); - auto cluster = builder.create( + auto cluster = builder.create( entry_func.getLoc(), entry_func.getResultTypes()); cluster.getBody().push_back(new Block); for (auto &op : ops_without_terminator) { op->moveBefore(&cluster.GetBody(), cluster.GetBody().end()); } builder.setInsertionPointToEnd(&cluster.GetBody()); - builder.create(original_return_op->getLoc(), - original_return_op->getResultTypes(), - original_return_op->getOperands()); + builder.create( + original_return_op->getLoc(), original_return_op->getResultTypes(), + original_return_op->getOperands()); original_return_op->erase(); builder.setInsertionPointToEnd(&entry_func.front()); - builder.create(entry_func->getLoc(), cluster->getResults()); + builder.create(entry_func->getLoc(), + cluster->getResults()); AddClusterAttributes(builder, entry_func, cluster); - return success(); + return mlir::success(); } -void EncapsulatePartitionedCall(Operation *call_op, StringAttr callee_name) { +void EncapsulatePartitionedCall(Operation *call_op, + mlir::StringAttr callee_name) { OpBuilder builder(call_op); - auto cluster = builder.create( + auto cluster = builder.create( call_op->getLoc(), call_op->getResultTypes()); cluster.getBody().push_back(new Block); call_op->replaceAllUsesWith(cluster.getResults()); call_op->moveBefore(&cluster.GetBody(), cluster.GetBody().end()); builder.setInsertionPointToEnd(&cluster.GetBody()); - builder.create(call_op->getLoc(), call_op->getResults()); + builder.create(call_op->getLoc(), + call_op->getResults()); // Propagate necessary attributes to the cluster so that when it's outlined, // the function will have correct attributes. - TF::CopyDeviceAndUnderscoredAttributes(call_op, cluster); - cluster->setAttr(TF::kClusterOutlinedFunctionNameAttr, callee_name); + mlir::TF::CopyDeviceAndUnderscoredAttributes(call_op, cluster); + cluster->setAttr(mlir::TF::kClusterOutlinedFunctionNameAttr, callee_name); cluster->setAttr(kAllowSoftPlacementAttr, builder.getBoolAttr(true)); } @@ -116,30 +143,31 @@ void EncapsulatePartitionedCall(Operation *call_op, StringAttr callee_name) { // `func` and is with compilation markers in a device cluster. For nested calls, // if the outermost one has the markers, encapsulates the outermost call and // returns. Otherwise, we'll keep going through inner calls until we found one. 
-LogicalResult EncapsulateFirstXlaCompilablePartitionedCalls( - func::FuncOp func, SymbolTableCollection &symbol_table_collection, +mlir::LogicalResult EncapsulateFirstXlaCompilablePartitionedCalls( + FuncOp func, SymbolTableCollection &symbol_table_collection, SymbolTable &symtab) { auto has_no_compile_device_type = [](SymbolUserOpInterface op) { - return !op->hasAttr(TF::kCompileDeviceTypeAttr); + return !op->hasAttr(mlir::TF::kCompileDeviceTypeAttr); }; mlir::OpBuilder builder(func.getContext()); auto noinline_attr_name = absl::StrCat("tf.", tensorflow::kNoInlineAttr); llvm::SmallVector noinline_pcall_ops, outermost_pcall_ops; - if (failed(GetOpsOfTypeUntilMiss( - func, symtab, /*predicate*/ has_no_compile_device_type, - /*hits*/ noinline_pcall_ops, - /*first_misses*/ outermost_pcall_ops))) { - return failure(); + if (mlir::failed( + mlir::GetOpsOfTypeUntilMiss( + func, symtab, /*predicate*/ has_no_compile_device_type, + /*hits*/ noinline_pcall_ops, + /*first_misses*/ outermost_pcall_ops))) { + return mlir::failure(); } // Cluster outermost partitioned calls with _xla_compile_device_type // attribute. for (auto &pcall_op : outermost_pcall_ops) { auto call = llvm::cast(pcall_op.getOperation()); CallInterfaceCallable callable = call.getCallableForCallee(); - auto sym = callable.get(); + auto sym = callable.get(); EncapsulatePartitionedCall(pcall_op, sym.getRootReference()); } // Partitioned calls are executed asynchronous. The calls outside of @@ -147,20 +175,20 @@ LogicalResult EncapsulateFirstXlaCompilablePartitionedCalls( // performance. for (auto &pcall_op : noinline_pcall_ops) { auto call = llvm::cast(pcall_op.getOperation()); - auto callee = llvm::cast( - call.resolveCallable(&symbol_table_collection)); + auto callee = + llvm::cast(call.resolveCallable(&symbol_table_collection)); callee->setAttr(noinline_attr_name, builder.getBoolAttr(true)); } - return success(); + return mlir::success(); } void XlaClusterFormationPass::runOnOperation() { ModuleOp module = getOperation(); SymbolTableCollection symbol_table_collection; SymbolTable symtab = symbol_table_collection.getSymbolTable(module); - llvm::SmallVector entry_funcs = GetEntryFunctions(module); + llvm::SmallVector entry_funcs = GetEntryFunctions(module); for (auto &entry_func : entry_funcs) { - if (entry_func->hasAttr(TF::kCompileDeviceTypeAttr)) { + if (entry_func->hasAttr(mlir::TF::kCompileDeviceTypeAttr)) { if (EncapsulateEntryFunctionBody(entry_func).failed()) { return signalPassFailure(); } @@ -172,12 +200,10 @@ void XlaClusterFormationPass::runOnOperation() { } } -} // namespace - -namespace TFDevice { std::unique_ptr> CreateXlaClusterFormationPass() { return std::make_unique(); } -} // namespace TFDevice -} // namespace mlir +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD new file mode 100644 index 00000000000000..a67178be9d770a --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD @@ -0,0 +1,45 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/compiler/mlir/tf2xla/internal:__subpackages__", + ], + licenses = ["notice"], +) + +cc_library( + name = "dialect_detection_utils", + srcs = [ + "dialect_detection_utils.cc", + ], + hdrs = [ + "dialect_detection_utils.h", + 
], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:framework", + "//tensorflow/core/transforms/toposort:Pass", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + +tf_cc_test( + name = "dialect_detection_utils_test", + srcs = ["dialect_detection_utils_test.cc"], + deps = [ + ":dialect_detection_utils", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@stablehlo//:chlo_ops", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.cc b/tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.cc new file mode 100644 index 00000000000000..fe37304826416f --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.cc @@ -0,0 +1,45 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.h" + +#include +#include + +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +bool IsInBridgeAcceptableDialects(mlir::Operation* op) { + const std::set kBuiltinNamespaces = {"func", "return", + "builtin"}; + const std::set kBridgeAcceptableNamespaces = {"tf", "tf_device"}; + bool isInDefaulNamespaces = + kBuiltinNamespaces.find(op->getDialect()->getNamespace().str()) != + kBuiltinNamespaces.end(); + bool isInBridgeAcceptableNamespaces = + kBridgeAcceptableNamespaces.find( + op->getDialect()->getNamespace().str()) != + kBridgeAcceptableNamespaces.end(); + return isInDefaulNamespaces || isInBridgeAcceptableNamespaces; +} + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor.cc b/tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.h similarity index 53% rename from tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor.cc rename to tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.h index dd78c065371d6b..6dd9851f7507bf 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.h @@ -1,44 +1,33 @@ /* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_UTILS_DIALECT_DETECTION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_UTILS_DIALECT_DETECTION_UTILS_H_ + +#include "mlir/IR/Operation.h" // from @llvm-project namespace tensorflow { namespace tf2xla { namespace internal { -namespace { - -#define GEN_PASS_DEF_VERIFYINPUTDIALECTTOEXECUTORPASS -#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" - -class VerifyInputDialectToexecutorPass - : public impl::VerifyInputDialectToexecutorPassBase< - VerifyInputDialectToexecutorPass> { - public: - void runOnOperation() override; -}; - -void VerifyInputDialectToexecutorPass::runOnOperation() {} - -} // namespace - -std::unique_ptr> -CreateVerifyInputDialectToExecutorPass() { - return std::make_unique(); -} +// Returns true if the op has a valid namespace during clustering & tf dialect +// to executor components of the Bridge. +bool IsInBridgeAcceptableDialects(mlir::Operation* op); } // namespace internal } // namespace tf2xla } // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_UTILS_DIALECT_DETECTION_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils_test.cc new file mode 100644 index 00000000000000..b6a56d70290ceb --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.h" + +#include +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +namespace { + +using mlir::MLIRContext; +using mlir::OpBuilder; +using mlir::Operation; +using mlir::OperationState; +using mlir::UnknownLoc; +using mlir::chlo::ChloDialect; +using mlir::TF::TensorFlowDialect; +using tensorflow::tf2xla::internal::IsInBridgeAcceptableDialects; + +class SharedUtilsTest : public ::testing::Test {}; + +TEST_F(SharedUtilsTest, IsInFunctionalDialectPasses) { + MLIRContext context; + context.loadDialect(); + OpBuilder opBuilder(&context); + OperationState state(UnknownLoc::get(opBuilder.getContext()), + /*OperationName=*/"tf.Const"); + mlir::Operation* op = Operation::create(state); + + bool result = IsInBridgeAcceptableDialects(op); + + EXPECT_TRUE(result); + op->destroy(); +} + +TEST_F(SharedUtilsTest, IsInFunctionalDialectFails) { + MLIRContext context; + context.loadDialect(); + OpBuilder opBuilder(&context); + OperationState state(UnknownLoc::get(opBuilder.getContext()), + /*OperationName=*/"chlo.broadcast_add"); + Operation* op = Operation::create(state); + + bool result = IsInBridgeAcceptableDialects(op); + + EXPECT_FALSE(result); + op->destroy(); +} + +} // namespace +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-include-tf2xla-fallback.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-include-tf2xla-fallback.mlir index f6e3ca10f5a279..56620e66870520 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-include-tf2xla-fallback.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-include-tf2xla-fallback.mlir @@ -51,7 +51,7 @@ func.func @batchmatmulv2(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> // SUPPORTED_FALLBACK_DEVICE: mhlo.dot_general // SUPPORTED_FALLBACK_DEVICE: mhlo.transpose - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, grad_x = false, grad_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> func.return %0 : tensor<3x4x4xf32> } diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir index a732c6d61281ca..b8552d1b6bdd10 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir @@ -524,7 +524,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK: mhlo.reduce // CHECK: mhlo.dot_general // CHECK: mhlo.transpose - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, 
tensor<3x2x4xf32>) -> tensor<3x4x4xf32> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, grad_x = false, grad_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> func.return %0 : tensor<3x4x4xf32> } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index ed0429ad242c94..0f0b1182e50bb7 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -466,6 +466,7 @@ cc_library( hdrs = ["legalization_op_config.h"], visibility = [ "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", + "//tensorflow/compiler/mlir/tf2xla/internal:__subpackages__", ], deps = [ "//tensorflow/compiler/mlir/tensorflow", diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc index 6e5f8285a0a928..979ec3f97e629e 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc @@ -336,6 +336,7 @@ bool IsOpTypeAllowedTf2XlaFallback(const TypeID& type_id) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -481,6 +482,7 @@ bool IsOpTypeAllowedTf2XlaPreferred(const TypeID& type_id) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index ade2b5faa73c8a..7084f98b28568e 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -131,8 +131,8 @@ TEST_F(LegalizationOpConfigTest, CountLoweringsSet) { // from MLIR to TF2XLA), these numbers should change. Or if TF Dialect adds // a new op, we should expect these to change too. EXPECT_EQ(mlir_lowering_count, 67); - EXPECT_EQ(tf2xla_fallback_count, 315); - EXPECT_EQ(non_categorized_count, 422); + EXPECT_EQ(tf2xla_fallback_count, 316); + EXPECT_EQ(non_categorized_count, 421); } // Just a counter test to see which ops have duplicate lowerings. This isn't a diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc index 70a32c7a270049..763e94734f6d01 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc @@ -501,8 +501,8 @@ Value CreateSubTuple(OpBuilder& builder, Value value, size_t end) { // return the first element. Otherwise, `mhlo.get_tuple_element` users are // simply updated with `replacement`, and all other users are updated with a // slice of `replacement`. 
-void ReplaceWithTupleResult(OpBuilder& builder, ArrayRef values, - ArrayRef replacements, bool flatten_tuple) { +void ReplaceWithTupleResult(OpBuilder& builder, ValueRange values, + ValueRange replacements, bool flatten_tuple) { if (flatten_tuple) { for (size_t result_index = 0; result_index < values.size(); result_index++) values[result_index].replaceAllUsesWith(replacements[result_index]); @@ -547,10 +547,8 @@ Value UpdateControlFlowBlockArgWithToken(OpBuilder& builder, Block& block, block.addArguments( types, SmallVector(types.size(), block.getParent()->getLoc())); - auto old_args = ArrayRef(block.getArguments().begin(), - block.getArguments().begin() + old_args_size); - auto new_args = ArrayRef(block.getArguments().begin() + old_args_size, - block.getArguments().end()); + ValueRange old_args = block.getArguments().take_front(old_args_size); + ValueRange new_args = block.getArguments().drop_front(old_args_size); assert(!new_args.empty()); ReplaceWithTupleResult(builder, old_args, new_args, /*flatten_tuple=*/true); diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td index 3aad616f162b17..0ee5d1dee5925d 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td @@ -392,7 +392,7 @@ foreach src = [TF_PreventGradientOp, TF_CheckNumericsOp] in def GetPrecisionConfig: NativeCodeCall< "GetPrecisionConfig(&$_builder)">; -def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b), +def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), (MHLO_DotOp (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index f803230ea4f504..be8298824029dd 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/mlprogram_util.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h" #include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "tensorflow/compiler/mlir/tosa/tf_passes.h" #include "tensorflow/compiler/mlir/tosa/tf_tfl_passes.h" @@ -56,7 +57,8 @@ int main(int argc, char **argv) { mlir::mhlo::registerLegalizeTfPasses(); mlir::mhlo::registerTfXlaPasses(); mlir::quant::stablehlo::registerBridgePasses(); - tensorflow::tf2xla::internal::registerTFXLABridgePasses(); + tensorflow::tf2xla::internal::registerTFXLABridgeClusteringPasses(); + tensorflow::tf2xla::internal::registerTFXLABridgeMlirToGraphPasses(); mlir::tosa::registerLegalizeTosaPasses(); mlir::tosa::registerTFtoTOSALegalizationPipeline(); mlir::tosa::registerTFLtoTOSALegalizationPipeline(); diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc index b5d2bf7d9933b9..fade4b23bf70ea 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc @@ -89,8 +89,7 @@ class TFRInlinerInterface : public DialectInlinerInterface { // Handle the given inlined terminator by replacing it with a new operation // as necessary. 
Required when the region has only one block. - void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { auto retValOp = dyn_cast(op); if (!retValOp) return; diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_types.h b/tensorflow/compiler/mlir/tfr/ir/tfr_types.h index c862f0f1b5f983..e0e24f4aca8f2d 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_types.h +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_types.h @@ -102,18 +102,21 @@ class TFRTypeImpl : public Type::TypeBase { class TFRTensorType : public detail::TFRTypeImpl { public: using TFRBase::TFRBase; + static constexpr StringLiteral name = "tfr.tensor"; static std::string getTypeName() { return "TFRTensorType"; } }; class TFRTensorListType : public detail::TFRTypeImpl { public: using TFRBase::TFRBase; + static constexpr StringLiteral name = "tfr.tensor_list"; static std::string getTypeName() { return "TFRTensorListType"; } }; class TFRAttrType : public Type::TypeBase { public: using Base::Base; + static constexpr StringLiteral name = "tfr.attr"; static std::string getTypeName() { return "TFRAttrType"; } }; diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index d41437a3fe796c..a73fc10ae083c4 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -1,5 +1,5 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.bzl", "if_google", "tf_cc_binary", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "if_google", "tf_cc_binary") # Note: keep the following lines separate due to the way copybara works load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") @@ -19,7 +19,6 @@ package_group( name = "friends", packages = [ "//tensorflow/compiler/...", - "//tensorflow/core/runtime_fallback/...", "//tensorflow/core/tfrt/experimental/data/...", "//tensorflow/core/tfrt/graph_executor/...", "//tensorflow/core/tfrt/ifrt/...", @@ -127,164 +126,6 @@ cc_library( ], ) -cc_library( - name = "tf_ifrt_passes", - srcs = [ - "transforms/ifrt/rewrite_cluster_to_ifrt_call.cc", - "transforms/ifrt/tf_ifrt_passes.cc", - ], - hdrs = [ - "transforms/ifrt/rewrite_cluster_to_ifrt_call.h", - "transforms/ifrt/tf_ifrt_passes.h", - ], - #compatible_with = get_compatible_with_portable(), # copybara: comment - deps = [ - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:bridge_logger", - "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", - "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", - "//tensorflow/core:framework", - "//tensorflow/core/platform:random", - "@com_google_absl//absl/base", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - ], -) - -cc_library( - name = "tf2hlo", - srcs = ["transforms/ifrt/tf2hlo.cc"], - hdrs = ["transforms/ifrt/tf2hlo.h"], - deps = [ - "//tensorflow/compiler/jit:xla_cpu_jit", - "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", - "//tensorflow/compiler/mlir/tf2xla/api/v2:legalize_tf", - "//tensorflow/compiler/tf2xla:layout_util", - 
"//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core:lib_headers_for_pybind", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", - "//tensorflow/core/tpu/kernels:tpu_compile_op_support", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - "@local_xla//xla:shape_util", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:client_library", - "@local_xla//xla/python/ifrt", - "@local_xla//xla/stream_executor", - "@local_xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - ], -) - -tf_cc_test( - name = "tf2hlo_test", - srcs = [ - "transforms/ifrt/tf2hlo_test.cc", - ], - data = [ - "//tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata", - ], - tags = ["no_oss"], - deps = [ - ":tf2hlo", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/tf2xla:xla_helpers", - "//tensorflow/core:framework", - "//tensorflow/core:test", - "//tensorflow/core/platform:resource_loader", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@local_tsl//tsl/platform:statusor", - "@local_xla//xla/python/ifrt", - "@local_xla//xla/python/ifrt:test_util", - "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", - ], -) - -cc_library( - name = "ifrt_backend_compiler", - srcs = ["transforms/ifrt/ifrt_backend_compiler.cc"], - hdrs = ["transforms/ifrt/ifrt_backend_compiler.h"], - deps = [ - ":backend_compiler", - ":tf_ifrt_passes", - ":tpu_passes", - "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:visitor", - "//tensorflow/compiler/mlir/tf2xla/api/v2:cluster_tf", - "//tensorflow/core/tfrt/ifrt:ifrt_executable_registry", - "//tensorflow/core/tfrt/ifrt:ifrt_model_context", - "//tensorflow/core/tfrt/ifrt:ifrt_serving_executable", - "//tensorflow/core/tfrt/runtime", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:traceme", - ], -) - -tf_cc_test( - name = "ifrt_backend_compiler_test", - srcs = [ - "transforms/ifrt/ifrt_backend_compiler_test.cc", - ], - data = [ - "//tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata", - ], - tags = ["no_oss"], - deps = [ - ":ifrt_backend_compiler", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/core:test", - "//tensorflow/core/platform:resource_loader", - "//tensorflow/core/tfrt/graph_executor:graph_execution_options", - "//tensorflow/core/tfrt/ifrt:ifrt_model_context", - "//tensorflow/core/tfrt/runtime", - "//tensorflow/core/tfrt/saved_model:saved_model_testutil", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:IR", - 
"@llvm-project//mlir:Parser", - "@local_tsl//tsl/platform:statusor", - "@local_xla//xla/python/ifrt", - "@local_xla//xla/python/ifrt:test_util", - "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", - "@tf_runtime//:hostcontext", - ], -) - cc_library( name = "corert_converter", srcs = [ @@ -628,7 +469,7 @@ tf_proto_library( cc_library( name = "passes", visibility = [ - ":__subpackages__", + "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private. ], deps = [ "//tensorflow/compiler/mlir/tfrt:tf_to_tfrt", @@ -656,7 +497,6 @@ cc_library( ":test_cost_analysis_pass", ":test_opkernels", ":test_tensor_array_side_effect_analysis", - ":tf_ifrt_passes", ":tf_to_tfrt", ":tpu_passes", ":transforms/gpu_passes", @@ -671,6 +511,7 @@ cc_library( "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_sync_opdefs", "//tensorflow/compiler/mlir/tfrt/ir/mlrt:mlrt_ops", "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_ops", + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:tf_ifrt_passes", "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:passes", "//tensorflow/core:tensorflow", "@llvm-project//mlir:AllPassesAndDialects", @@ -679,8 +520,6 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", - "@local_xla//xla/mlir_hlo:gml_st", - "@local_xla//xla/mlir_hlo:gml_st_passes", "@tf_runtime//:init_tfrt_dialects", "@tf_runtime//:print_stream_pass", ], diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD index e72c58bdd6b846..4b2b0576430bd1 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD @@ -70,7 +70,6 @@ td_library( ], includes = ["."], visibility = [ - # copybara:uncomment "//learning/brain/tfrt/mlir:__subpackages__", # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", ], deps = [ diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h index 0fb568b44dc8c9..644de2618d691c 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h @@ -37,6 +37,7 @@ class FutureType : public mlir::Type::TypeBase { public: using Base::Base; + static constexpr mlir::StringLiteral name = "mlrt.compiler.future"; }; // The MLIR type represents a C++ mlrt::Promise. @@ -44,6 +45,7 @@ class PromiseType : public mlir::Type::TypeBase { public: using Base::Base; + static constexpr mlir::StringLiteral name = "mlrt.compiler.promise"; }; // The MLIR type represents a C++ mlrt::AsyncHandle. 
@@ -51,6 +53,7 @@ class AsyncHandleType : public mlir::Type::TypeBase { public: using Base::Base; + static constexpr mlir::StringLiteral name = "mlrt.compiler.async_handle"; }; } // namespace compiler diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h index da91450aa19fc1..a542373eeccf6a 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h @@ -41,6 +41,7 @@ class TFTensorType : public mlir::Type::TypeBase { public: using Base::Base; + static constexpr mlir::StringLiteral name = "tensorflow.tf_mlrt.tf_tensor"; }; // The MLIR type represents a tensorflow::Device* @@ -48,6 +49,7 @@ class TFDeviceType : public mlir::Type::TypeBase { public: using Base::Base; + static constexpr mlir::StringLiteral name = "tensorflow.tf_mlirt.tf_device"; }; } // namespace tf_mlrt diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h index 0d2e941b5cfb0d..24fa464ff6ed31 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h @@ -40,6 +40,7 @@ class FallbackDialect : public Dialect { class TFTensorType : public Type::TypeBase { public: using Base::Base; + static constexpr StringLiteral name = "tfrt.tf_tensor"; }; // The MLIR type represents a tensorflow::Allocator. @@ -47,6 +48,7 @@ class TFAllocatorType : public Type::TypeBase { public: using Base::Base; + static constexpr StringLiteral name = "tfrt.tf_allocator"; }; } // namespace fallback diff --git a/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/BUILD b/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/BUILD index 53a5a8489895a9..1065a5fc1a682a 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/BUILD @@ -1,4 +1,6 @@ -load("@tf_runtime//tools:mlir_to_bef.bzl", "glob_tfrt_lit_tests") +load("//tensorflow:tensorflow.bzl", "tf_cc_shared_test") +load("@tf_runtime//tools:mlir_to_bef.bzl", "glob_tfrt_lit_tests", "mlir_to_bef") +# copybara:uncomment load("//third_party/tf_runtime_google/cpp_tests:gen_tests.bzl", "tfrt_cc_test_and_strict_benchmark") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -9,6 +11,7 @@ filegroup( srcs = [ "//tensorflow/compiler/mlir/tfrt:tfrt_fallback_translate", "//tensorflow/core/runtime_fallback:tf_bef_executor", + "//tensorflow/core/runtime_fallback/util:fallback_test_util", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", "@llvm-project//mlir:run_lit.sh", @@ -23,6 +26,9 @@ filegroup( # # copybara:uncomment driver = "//tensorflow/compiler/mlir:run_lit.sh", # exclude = [ # "compile.benchmark.large.mlir", +# "batch_function_fallback.mlir", +# "create_op.mlir", +# "custom_thread_pool.mlir", # ], # # copybara:uncomment flaky = ["compile.error.mlir"], # size_override = { @@ -47,3 +53,91 @@ filegroup( # tfrt_translate = "//tensorflow/compiler/mlir/tfrt:tfrt_fallback_translate", # ) # copybara:uncomment_end + +mlir_to_bef( + name = "batch_function_fallback.mlir", + tfrt_translate = "//tensorflow/compiler/mlir/tfrt:tfrt_fallback_translate", +) + +mlir_to_bef( + name = "create_op.mlir", + tfrt_translate = "//tensorflow/compiler/mlir/tfrt:tfrt_fallback_translate", +) + +mlir_to_bef( + name = "custom_thread_pool.mlir", + tfrt_translate = "//tensorflow/compiler/mlir/tfrt:tfrt_fallback_translate", +) + +# copybara:uncomment_begin(internal benchmarking) +# # C++ benchmarks 
for batch function runtime fallback. +# tfrt_cc_test_and_strict_benchmark( +# name = "batch_function_fallback_benchmark_test", +# srcs = ["batch_function_fallback_benchmark_test.cc"], +# data = ["batch_function_fallback.mlir.bef"], +# enable_xprof = True, +# includes = ["third_party/tf_runtime/include"], +# owners = ["tf-runtime-testing"], +# tags = [ +# "need_main", +# "no_gpu", +# ], +# deps = [ +# "//base", +# "//devtools/build/runtime:get_runfiles_dir", +# "@com_google_absl//absl/log:check", +# "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", +# "//tensorflow/core/platform:env", +# "//tensorflow/core/platform:resource_loader", +# "//tensorflow/core/platform:status", +# "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_op_handler", +# "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_tensor", +# "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink", +# "//tensorflow/core/runtime_fallback/util:fallback_test_util", +# "//tensorflow/core/runtime_fallback/util:tensor_util", +# "//tensorflow/core/tfrt/utils:fallback_tensor", +# "@eigen_archive//:eigen3", +# "@tf_runtime//:bef", +# "@tf_runtime//:befexecutor", +# "@tf_runtime//:core_runtime_alwayslink", +# "@tf_runtime//:hostcontext_alwayslink", +# "@tf_runtime//:mlirtobef", +# "@tf_runtime//:support", +# "@tf_runtime//:tensor", +# "@tf_runtime//backends/cpu:core_runtime_alwayslink", +# "@tf_runtime//backends/cpu:test_ops_alwayslink", +# ], +# ) +# copybara:uncomment_end + +tf_cc_shared_test( + name = "kernel_fallback_compat_test", + srcs = ["kernel_fallback_compat_test.cc"], + data = [ + "create_op.mlir.bef", + "custom_thread_pool.mlir.bef", + ], + tags = ["no_oss"], + deps = [ + "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", + "//tensorflow/core:all_kernels", + "//tensorflow/core:lib", + "//tensorflow/core/platform:resource_loader", + "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state", + "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink", + "//tensorflow/core/runtime_fallback/util:fallback_test_util", + "//tensorflow/core/tfrt/fallback:op_kernel_runner", + "//tensorflow/core/tfrt/runtime", + "//tensorflow/core/tfrt/utils:thread_pool", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@tf_runtime//:bef", + "@tf_runtime//:befexecutor", + "@tf_runtime//:core_runtime", + "@tf_runtime//:hostcontext", + "@tf_runtime//:init_tfrt_dialects", + "@tf_runtime//:support", + "@tf_runtime//:tracing", + ], +) diff --git a/tensorflow/core/runtime_fallback/test/testdata/batch_function_fallback.mlir b/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/batch_function_fallback.mlir similarity index 100% rename from tensorflow/core/runtime_fallback/test/testdata/batch_function_fallback.mlir rename to tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/batch_function_fallback.mlir diff --git a/tensorflow/core/runtime_fallback/test/batch_function_fallback_benchmark_test.cc b/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/batch_function_fallback_benchmark_test.cc similarity index 87% rename from tensorflow/core/runtime_fallback/test/batch_function_fallback_benchmark_test.cc rename to tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/batch_function_fallback_benchmark_test.cc index 11bc0b6ecbf4f5..1d9d8f1e488984 100644 --- a/tensorflow/core/runtime_fallback/test/batch_function_fallback_benchmark_test.cc +++ 
b/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/batch_function_fallback_benchmark_test.cc @@ -12,41 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include #include "base/logging.h" -#include "devtools/build/runtime/get_runfiles_dir.h" #include "testing/base/public/benchmark.h" -#include #include +#include "absl/log/check.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive -#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_op_handler.h" -#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h" #include "tensorflow/core/runtime_fallback/util/fallback_test_util.h" -#include "tensorflow/core/runtime_fallback/util/tensor_util.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" #include "tfrt/bef/bef_buffer.h" // from @tf_runtime #include "tfrt/bef_executor/bef_file.h" // from @tf_runtime #include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime -#include "tfrt/core_runtime/tensor_handle.h" // from @tf_runtime +#include "tfrt/host_context/async_value.h" // from @tf_runtime +#include "tfrt/host_context/async_value_ref.h" // from @tf_runtime #include "tfrt/host_context/chain.h" // from @tf_runtime #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime #include "tfrt/host_context/execution_context.h" // from @tf_runtime #include "tfrt/host_context/function.h" // from @tf_runtime +#include "tfrt/host_context/host_allocator.h" // from @tf_runtime #include "tfrt/host_context/host_context.h" // from @tf_runtime -#include "tfrt/support/aligned_buffer.h" // from @tf_runtime +#include "tfrt/host_context/resource_context.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime #include "tfrt/support/rc_array.h" // from @tf_runtime #include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime -#include "tfrt/tensor/tensor_metadata.h" // from @tf_runtime +#include "tfrt/tensor/tensor.h" // from @tf_runtime namespace tensorflow { -namespace tfd { namespace { // Creates a BEF file with a program that runs @@ -55,11 +52,11 @@ namespace { std::pair> CreateBefFile( tfrt::HostContext* host) { std::string file_path = GetDataDependencyFilepath( - "tensorflow/core/runtime_fallback/test/testdata/" + "tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/" "batch_function_fallback.mlir.bef"); std::string data; - TF_CHECK_OK(ReadFileToString(Env::Default(), file_path, &data)); + CHECK_OK(ReadFileToString(Env::Default(), file_path, &data)); tfrt::BefBuffer bef_buffer(data.begin(), data.end()); @@ -109,7 +106,7 @@ TEST(BatchFunctionTest, Basic) { auto arguments = CreateTestArguments(func, host); tfrt::ResourceContext resource_ctx; - auto exec_ctx = CreateFallbackTestExecutionContext(host, &resource_ctx); + auto exec_ctx = tfd::CreateFallbackTestExecutionContext(host, &resource_ctx); std::vector> results; results.resize(func->result_types().size()); @@ -141,7 +138,7 @@ void BM_BatchFunctionFallbackWithLargeAttributesAndManyInputsOutputs( auto arguments = CreateTestArguments(func, host); tfrt::ResourceContext resource_ctx; - auto exec_ctx = CreateFallbackTestExecutionContext(host, &resource_ctx); + auto exec_ctx = 
tfd::CreateFallbackTestExecutionContext(host, &resource_ctx); std::vector> results; results.resize(func->result_types().size()); @@ -157,5 +154,4 @@ void BM_BatchFunctionFallbackWithLargeAttributesAndManyInputsOutputs( BENCHMARK(BM_BatchFunctionFallbackWithLargeAttributesAndManyInputsOutputs); } // namespace -} // namespace tfd } // namespace tensorflow diff --git a/tensorflow/core/runtime_fallback/test/testdata/create_op.mlir b/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/create_op.mlir similarity index 100% rename from tensorflow/core/runtime_fallback/test/testdata/create_op.mlir rename to tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/create_op.mlir diff --git a/tensorflow/core/runtime_fallback/test/testdata/custom_thread_pool.mlir b/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/custom_thread_pool.mlir similarity index 100% rename from tensorflow/core/runtime_fallback/test/testdata/custom_thread_pool.mlir rename to tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/custom_thread_pool.mlir diff --git a/tensorflow/core/runtime_fallback/test/kernel_fallback_compat_test.cc b/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/kernel_fallback_compat_test.cc similarity index 87% rename from tensorflow/core/runtime_fallback/test/kernel_fallback_compat_test.cc rename to tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/kernel_fallback_compat_test.cc index 75fae5c26e71ac..7b1a51f4fc664a 100644 --- a/tensorflow/core/runtime_fallback/test/kernel_fallback_compat_test.cc +++ b/tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/kernel_fallback_compat_test.cc @@ -16,12 +16,12 @@ limitations under the License. #include #include -#include #include -#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/resource_loader.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" #include "tensorflow/core/runtime_fallback/util/fallback_test_util.h" #include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" @@ -30,14 +30,15 @@ limitations under the License. 
#include "tfrt/bef/bef_buffer.h" // from @tf_runtime #include "tfrt/bef_executor/bef_file.h" // from @tf_runtime #include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime +#include "tfrt/host_context/async_value.h" // from @tf_runtime #include "tfrt/host_context/chain.h" // from @tf_runtime #include "tfrt/host_context/function.h" // from @tf_runtime #include "tfrt/host_context/host_context.h" // from @tf_runtime -#include "tfrt/init_tfrt_dialects.h" // from @tf_runtime +#include "tfrt/host_context/resource_context.h" // from @tf_runtime +#include "tfrt/support/ref_count.h" // from @tf_runtime #include "tfrt/tracing/tracing.h" // from @tf_runtime namespace tensorflow { -namespace tfd { namespace { // Creates a BEF file with a program that runs tfrt_fallback.batch_function with @@ -48,9 +49,9 @@ namespace { std::pair> CreateBefFile( absl::string_view file_name, tfrt::HostContext* host) { std::string file_path = GetDataDependencyFilepath(absl::StrCat( - "tensorflow/core/runtime_fallback/test/testdata/", file_name)); + "tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/", file_name)); std::string data; - TF_CHECK_OK(ReadFileToString(Env::Default(), file_path, &data)); + CHECK_OK(ReadFileToString(Env::Default(), file_path, &data)); tfrt::BefBuffer bef_buffer(data.begin(), data.end()); @@ -69,7 +70,7 @@ TEST(KernelFallbackCompatTest, CreateOp) { auto& bef_file = pair.second; tfrt::ResourceContext resource_ctx; - auto exec_ctx = CreateFallbackTestExecutionContext(host, &resource_ctx); + auto exec_ctx = tfd::CreateFallbackTestExecutionContext(host, &resource_ctx); auto chain = tfrt::GetReadyChain(); @@ -86,7 +87,7 @@ TEST(KernelFallbackCompatTest, CreateOp) { auto* fallback_request_state = exec_ctx.request_ctx() - ->GetDataIfExists(); + ->GetDataIfExists(); ASSERT_TRUE(fallback_request_state != nullptr); @@ -120,8 +121,8 @@ TEST(KernelFallbackCompatTest, CustomThreadPool) { tensorflow::tfrt_stub::TfThreadPool thread_pool(/*name=*/"test", /*num_threads=*/1); - auto exec_ctx = - CreateFallbackTestExecutionContext(host, &resource_ctx, &thread_pool); + auto exec_ctx = tfd::CreateFallbackTestExecutionContext(host, &resource_ctx, + &thread_pool); auto chain = tfrt::GetReadyChain(); @@ -146,5 +147,4 @@ TEST(KernelFallbackCompatTest, CustomThreadPool) { } } // namespace -} // namespace tfd } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc index 1ae3e8f1c54d31..a07558bac45f77 100644 --- a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc +++ b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc @@ -33,8 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h" -#include "xla/mlir_hlo/gml_st/IR/gml_st_ops.h" -#include "xla/mlir_hlo/gml_st/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tfrt/init_tfrt_dialects.h" // from @tf_runtime @@ -46,7 +44,6 @@ int main(int argc, char **argv) { mlir::registerTensorFlowPasses(); - mlir::gml_st::registerGmlStPasses(); tensorflow::mlrt_compiler::RegisterMlrtPasses(); tensorflow::ifrt_serving::RegisterTfIfrtPasses(); @@ -54,7 +51,6 @@ int main(int argc, char **argv) { mlir::DialectRegistry registry; mlir::registerAllDialects(registry); mlir::RegisterAllTensorFlowDialects(registry); - registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD new file mode 100644 index 00000000000000..ec36fb683bc897 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -0,0 +1,179 @@ +load("//tensorflow:tensorflow.bzl", "if_google", "tf_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + packages = [ + "//tensorflow/compiler/mlir/tfrt/...", + "//tensorflow/core/tfrt/ifrt/...", + "//tensorflow/core/tfrt/saved_model/tests/...", + ] + if_google([ + "//learning/brain/tfrt/cpp_tests/...", + # Allow visibility from the mlir language server. + "//learning/brain/mlir/mlir_lsp_server/...", + ]), +) + +cc_library( + name = "tf_ifrt_passes", + srcs = [ + "rewrite_cluster_to_ifrt_call.cc", + "tf_ifrt_passes.cc", + ], + hdrs = [ + "rewrite_cluster_to_ifrt_call.h", + "tf_ifrt_passes.h", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:bridge_logger", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/core:framework", + "//tensorflow/core/platform:random", + "@com_google_absl//absl/base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "tf2hlo", + srcs = ["tf2hlo.cc"], + hdrs = ["tf2hlo.h"], + deps = [ + "//tensorflow/compiler/jit:xla_cpu_jit", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", + "//tensorflow/compiler/mlir/tf2xla/api/v2:legalize_tf", + "//tensorflow/compiler/tf2xla:layout_util", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:lib_headers_for_pybind", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/core/tpu/kernels:tpu_compile_op_support", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + 
"@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla:shape_util", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/client:client_library", + "@local_xla//xla/python/ifrt", + "@local_xla//xla/stream_executor", + "@local_xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + ], +) + +tf_cc_test( + name = "tf2hlo_test", + srcs = [ + "tf2hlo_test.cc", + ], + data = [ + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata", + ], + tags = ["no_oss"], + deps = [ + ":tf2hlo", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core/platform:resource_loader", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/python/ifrt", + "@local_xla//xla/python/ifrt:test_util", + "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", + ], +) + +cc_library( + name = "ifrt_backend_compiler", + srcs = ["ifrt_backend_compiler.cc"], + hdrs = ["ifrt_backend_compiler.h"], + deps = [ + ":tf_ifrt_passes", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:visitor", + "//tensorflow/compiler/mlir/tf2xla/api/v2:cluster_tf", + "//tensorflow/compiler/mlir/tfrt:backend_compiler", + "//tensorflow/compiler/mlir/tfrt:tpu_passes", + "//tensorflow/core/tfrt/ifrt:ifrt_executable_registry", + "//tensorflow/core/tfrt/ifrt:ifrt_model_context", + "//tensorflow/core/tfrt/ifrt:ifrt_serving_executable", + "//tensorflow/core/tfrt/runtime", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + +tf_cc_test( + name = "ifrt_backend_compiler_test", + srcs = [ + "ifrt_backend_compiler_test.cc", + ], + data = [ + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata", + ], + tags = ["no_oss"], + deps = [ + ":ifrt_backend_compiler", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:test", + "//tensorflow/core/platform:resource_loader", + "//tensorflow/core/tfrt/graph_executor:graph_execution_options", + "//tensorflow/core/tfrt/ifrt:ifrt_model_context", + "//tensorflow/core/tfrt/runtime", + "//tensorflow/core/tfrt/saved_model:saved_model_testutil", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/python/ifrt", + "@local_xla//xla/python/ifrt:test_util", + "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", + "@tf_runtime//:hostcontext", + ], +) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc index 16b1f0b7776160..978ffd25667b4c 100644 --- 
a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc @@ -144,9 +144,6 @@ absl::Status IfrtBackendCompiler::CompileTensorflow( tensorflow::DumpMlirOpToFile("ifrt_tpu_bct_conversion_before", module); } - // TODO(b/305734600): conditionally running backward compat pass on host with - // tpu only. - // // Run backward compat pass so that we can use bridge to do clustering. auto backward_compat_result = tensorflow::RunTPUBackwardCompatConversion(module, {}); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/tf2hlo_1in1out.mlir b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/tf2hlo_1in1out.mlir deleted file mode 100644 index 8bd488bae251f2..00000000000000 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/tf2hlo_1in1out.mlir +++ /dev/null @@ -1,5 +0,0 @@ -module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { - func.func @main(%arg0: tensor<1x3xi32>) -> (tensor<1x3xi32>) { - func.return %arg0: tensor<1x3xi32> - } -} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/tf2hlo_tuple.mlir b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/tf2hlo_tuple.mlir new file mode 100644 index 00000000000000..f1eb5659fb97b8 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/tf2hlo_tuple.mlir @@ -0,0 +1,6 @@ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> (tensor<1x1xf32>, tensor<1x3xf32>) { + %0 = "tf.MatMul"(%arg0, %arg1): (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> + func.return %0, %arg0: tensor<1x1xf32>, tensor<1x3xf32> + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc index bf661ab5be5630..246c920d64b964 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc @@ -21,10 +21,13 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project @@ -54,6 +57,8 @@ limitations under the License. namespace tensorflow { namespace ifrt_serving { +static constexpr absl::string_view kEntryFuncName = "main"; + absl::StatusOr> CompileTfToHlo( mlir::ModuleOp module, absl::Span inputs, absl::string_view entry_function_name, xla::ifrt::Compiler* ifrt_compiler, @@ -89,7 +94,21 @@ absl::StatusOr> CompileTfToHlo( // supported. metadata_arg1->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER); } - metadata.add_retvals(); + + auto entry_fn = module.lookupSymbol(kEntryFuncName); + if (!entry_fn) { + return absl::InternalError("Could not find entry function in MLIR Module."); + } + + if (inputs.size() != entry_fn.getNumArguments()) { + return absl::InternalError( + absl::StrCat("Number of inputs mismatched! 
Expect", + entry_fn.getNumArguments(), " got", inputs.size())); + } + + for (int i = 0; i < entry_fn.getNumResults(); i++) { + metadata.add_retvals(); + } bool use_tuple_args = false; std::vector arg_core_mapping; diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc index 8ad906ece7a2f4..ff2b4cebfb2530 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc @@ -72,12 +72,13 @@ TEST(Tf2HloTest, Basic) { TF_ASSERT_OK(result.status()); } -TEST(Tf2HloTest, 1in1out) { +// Multiple input and multiple out. +TEST(Tf2HloTest, Tuple) { // Create test input module constexpr absl::string_view kDataDirectory = "tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata"; std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( - absl::StrCat(kDataDirectory, "/tf2hlo_1in1out.mlir")); + absl::StrCat(kDataDirectory, "/tf2hlo_tuple.mlir")); mlir::DialectRegistry registry; mlir::registerAllDialects(registry); @@ -95,8 +96,10 @@ TEST(Tf2HloTest, 1in1out) { xla::ifrt::test_util::GetClient()); std::vector tensors; - tensorflow::Tensor x(DT_INT32, tensorflow::TensorShape({1, 3})); + tensorflow::Tensor x(DT_FLOAT, tensorflow::TensorShape({1, 3})); + tensorflow::Tensor y(DT_FLOAT, tensorflow::TensorShape({3, 1})); tensors.push_back(x); + tensors.push_back(y); auto result = CompileTfToHlo(mlir_module.get(), tensors, "main", client->GetDefaultCompiler(), tensorflow::IdentityShapeRepresentationFn()); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD index 03558438ac6f6b..90ab3af857c542 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD @@ -1,7 +1,6 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - # copybara:uncomment "//learning/brain/experimental/tfrt:__subpackages__", # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", "//tensorflow/compiler/mlir/tfrt:__subpackages__", "//tensorflow/core/tfrt:__subpackages__", diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc index 6e04fe1c1e23a1..6dc48e4d6d137f 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc @@ -68,7 +68,7 @@ StatusOr ConvertTfMlirToBytecode( TF_RETURN_IF_ERROR( ExportFunctionDefs(*copy, [flib_def](FunctionDef function_def) { VLOG(1) << "Exporting MLIR function as function_def: " - << function_def.DebugString(); + << function_def; // The TF MLIR compiler may change the function name. 
Then we // need to retrieve the original name from the diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc index 8c85f9f80ac912..1953ddd3d93997 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc @@ -464,6 +464,8 @@ absl::StatusOr EmitExecutable( return status; } + buffer.shrink_to_fit(); + return buffer; } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index e6ce181074de7f..d391f35e9adf77 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -142,7 +142,6 @@ tf_cc_binary( "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:MlirOptLib", "@local_xla//xla/mlir_hlo:all_passes", - "@local_xla//xla/mlir_hlo:gml_st", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", "@stablehlo//:register", ], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h index c1a6daff83008d..e64ef8e2900f47 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h @@ -42,12 +42,15 @@ class OpKernelContextType : public Type::TypeBase { public: using Base::Base; + static constexpr StringLiteral name = + "kernel_gen.tf_framework.op_kernel_context"; }; class JITCallableType : public Type::TypeBase { public: using Base::Base; + static constexpr StringLiteral name = "kernel_gen.tf_framework.jit_callable"; }; absl::StatusCode ConvertAttrToEnumValue(ErrorCode error_code); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_abi_knowledge.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_abi_knowledge.mlir index 8619344681beac..47b5a122ef0dd2 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_abi_knowledge.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_abi_knowledge.mlir @@ -25,44 +25,44 @@ module attributes {gpu.container_module} { // CHECK-LABEL: gpu.module @abs_kernel gpu.module @abs_kernel { // CHECK-LABEL: llvm.func @abs_kernel - // ABI-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr {llvm.align = 16 : index}, - // ABI-SAME: %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: !llvm.ptr, %[[ARG6:.*]]: !llvm.ptr {llvm.align = 16 : index, llvm.noalias}, + // ABI-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr {llvm.align = 16 : index}, + // ABI-SAME: %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: !llvm.ptr, %[[ARG6:.*]]: !llvm.ptr {llvm.align = 16 : index, llvm.noalias}, // ABI-SAME: %[[ARG7:.*]]: i64, %[[ARG8:.*]]: i64, %[[ARG9:.*]]: i64 - // SHAPE-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: !llvm.ptr, %[[ARG6:.*]]: !llvm.ptr, %[[ARG7:.*]]: i64, %[[ARG8:.*]]: i64, %[[ARG9:.*]]: i64 - llvm.func @abs_kernel(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64) attributes {gpu.kernel} { + // SHAPE-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: !llvm.ptr, %[[ARG6:.*]]: !llvm.ptr, %[[ARG7:.*]]: i64, %[[ARG8:.*]]: i64, %[[ARG9:.*]]: i64 + llvm.func @abs_kernel(%arg0: !llvm.ptr, 
%arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64) attributes {gpu.kernel} { // ABI: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) // ABI: %[[ONE:.*]] = llvm.mlir.constant(1 : index) // CHECK: llvm.mlir.undef - %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // ABI-NEXT: llvm.insertvalue %[[ARG1]] // SHAPE-NEXT: llvm.insertvalue %[[ARG0]] - %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[ARG1]] - %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // ABI-NEXT: llvm.insertvalue %[[ZERO]] // SHAPE-NEXT: llvm.insertvalue %[[ARG2]] - %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[ARG3]] - %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // ABI-NEXT: llvm.insertvalue %[[ONE]] // SHAPE-NEXT: llvm.insertvalue %[[ARG4]] - %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-NEXT: llvm.mlir.undef - %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // ABI-NEXT: llvm.insertvalue %[[ARG6]] // SHAPE-NEXT: llvm.insertvalue %[[ARG5]] - %7 = llvm.insertvalue %arg5, %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %7 = llvm.insertvalue %arg5, %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[ARG6]] - %8 = llvm.insertvalue %arg6, %7[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %8 = llvm.insertvalue %arg6, %7[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // ABI-NEXT: llvm.insertvalue %[[ZERO]] // SHAPE-NEXT: llvm.insertvalue %[[ARG7]] - %9 = llvm.insertvalue %arg7, %8[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %9 = llvm.insertvalue %arg7, %8[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // ABI-NEXT: llvm.insertvalue %[[ARG8]] // SHAPE-NEXT: llvm.insertvalue %[[ARG3]] - %10 = llvm.insertvalue %arg8, %9[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %10 = llvm.insertvalue %arg8, %9[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // ABI-NEXT: llvm.insertvalue %[[ONE]] // SHAPE-NEXT: llvm.insertvalue %[[ARG4]] - %11 = llvm.insertvalue %arg9, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %11 = llvm.insertvalue %arg9, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> llvm.return // CHECK-NEXT: llvm.return } @@ -93,63 +93,63 @@ module attributes {gpu.container_module} { // ABI-SAME: {llvm.align = 16 : index} 
// ABI-SAME: {llvm.align = 16 : index} // ABI-SAME: {llvm.align = 16 : index, llvm.noalias} - llvm.func @AddV2_kernel(%arg0: i64, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: !llvm.ptr, %arg7: !llvm.ptr, %arg8: i64, %arg9: i64, %arg10: i64, %arg11: !llvm.ptr, %arg12: !llvm.ptr, %arg13: i64, %arg14: i64, %arg15: i64) attributes {gpu.kernel} { + llvm.func @AddV2_kernel(%arg0: i64, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: !llvm.ptr, %arg7: !llvm.ptr, %arg8: i64, %arg9: i64, %arg10: i64, %arg11: !llvm.ptr, %arg12: !llvm.ptr, %arg13: i64, %arg14: i64, %arg15: i64) attributes {gpu.kernel} { // ABI: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 // ABI: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 - %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %1 = llvm.insertvalue %arg1, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %2 = llvm.insertvalue %arg2, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %3 = llvm.insertvalue %arg3, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %4 = llvm.insertvalue %arg4, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %5 = llvm.insertvalue %arg5, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR0:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR0]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C1]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[SHP:.*]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[STR:.*]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %7 = llvm.insertvalue %arg6, %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %8 = llvm.insertvalue %arg7, %7[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %9 = llvm.insertvalue %arg8, %8[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %10 = llvm.insertvalue %arg9, %9[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %11 = llvm.insertvalue %arg10, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR1:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, 
array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR1]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C1]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[SHP]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[STR]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %12 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %13 = llvm.insertvalue %arg11, %12[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %14 = llvm.insertvalue %arg12, %13[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %15 = llvm.insertvalue %arg13, %14[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %16 = llvm.insertvalue %arg14, %15[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %17 = llvm.insertvalue %arg15, %16[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR2:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR2]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C1]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[SHP]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[STR]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %1 = llvm.insertvalue %arg1, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %2 = llvm.insertvalue %arg2, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %3 = llvm.insertvalue %arg3, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x 
i64>)> + %4 = llvm.insertvalue %arg4, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %5 = llvm.insertvalue %arg5, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR0:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR0]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C1]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[SHP:.*]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[STR:.*]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %7 = llvm.insertvalue %arg6, %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %8 = llvm.insertvalue %arg7, %7[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %9 = llvm.insertvalue %arg8, %8[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %10 = llvm.insertvalue %arg9, %9[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %11 = llvm.insertvalue %arg10, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR1:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR1]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C1]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[SHP]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[STR]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, 
array<1 x i64>)> + %12 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %13 = llvm.insertvalue %arg11, %12[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %14 = llvm.insertvalue %arg12, %13[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %15 = llvm.insertvalue %arg13, %14[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %16 = llvm.insertvalue %arg14, %15[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %17 = llvm.insertvalue %arg15, %16[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR2:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR2]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C1]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[SHP]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[STR]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> llvm.return // CHECK-NEXT: llvm.return } @@ -181,80 +181,80 @@ module attributes {gpu.container_module} { // ABI-SAME: {llvm.align = 16 : index, llvm.noalias} // ABI-SAME: {llvm.align = 16 : index} // ABI-SAME: {llvm.align = 16 : index} - llvm.func @AddV2_kernel(%arg0: i64, %arg1: i64, %arg2: !llvm.ptr, %arg3: !llvm.ptr {llvm.align = 16 : index, llvm.noalias}, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: i64, %arg8: i64, %arg9: !llvm.ptr, %arg10: !llvm.ptr {llvm.align = 16 : index}, %arg11: i64, %arg12: i64, %arg13: i64, %arg14: i64, %arg15: i64, %arg16: !llvm.ptr, %arg17: !llvm.ptr {llvm.align = 16 : index}, %arg18: i64, %arg19: i64, %arg20: i64, %arg21: i64, %arg22: i64) attributes {gpu.kernel} { + llvm.func @AddV2_kernel(%arg0: i64, %arg1: i64, %arg2: !llvm.ptr, %arg3: !llvm.ptr {llvm.align = 16 : index, llvm.noalias}, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: i64, %arg8: i64, %arg9: !llvm.ptr, %arg10: !llvm.ptr {llvm.align = 16 : index}, %arg11: i64, %arg12: i64, %arg13: i64, %arg14: i64, %arg15: i64, %arg16: !llvm.ptr, %arg17: !llvm.ptr {llvm.align = 16 : index}, %arg18: i64, %arg19: i64, %arg20: i64, %arg21: i64, %arg22: i64) attributes {gpu.kernel} { // ABI: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 // ABI: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 - %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %1 = llvm.insertvalue %arg2, %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %2 = llvm.insertvalue %arg3, %1[1] : !llvm.struct<(ptr, ptr, i64, 
array<2 x i64>, array<2 x i64>)> - %3 = llvm.insertvalue %arg4, %2[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %4 = llvm.insertvalue %arg5, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %5 = llvm.insertvalue %arg7, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %6 = llvm.insertvalue %arg6, %5[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %7 = llvm.insertvalue %arg8, %6[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR0:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR0]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[SHP0:.*]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[STR0:.*]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[SHP1:.*]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[STR1:.*]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %8 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %9 = llvm.insertvalue %arg9, %8[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %10 = llvm.insertvalue %arg10, %9[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %11 = llvm.insertvalue %arg11, %10[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %12 = llvm.insertvalue %arg12, %11[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %13 = llvm.insertvalue %arg14, %12[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %14 = llvm.insertvalue %arg13, %13[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %15 = llvm.insertvalue %arg15, %14[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR0:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR0]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x 
i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NOT: llvm.insertvalue %[[C1]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[SHP0]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NOT: llvm.insertvalue %[[STR0]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE: llvm.insertvalue %[[SHP1]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NOT: llvm.insertvalue %[[STR1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %16 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %17 = llvm.insertvalue %arg16, %16[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %18 = llvm.insertvalue %arg17, %17[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %19 = llvm.insertvalue %arg18, %18[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %20 = llvm.insertvalue %arg19, %19[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %21 = llvm.insertvalue %arg21, %20[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %22 = llvm.insertvalue %arg20, %21[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %23 = llvm.insertvalue %arg22, %22[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR0:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR0]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // ABI-NOT: llvm.insertvalue %[[C1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, 
%{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[SHP0]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NOT: llvm.insertvalue %[[STR0]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE: llvm.insertvalue %[[SHP1]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // SHAPE-NOT: llvm.insertvalue %[[STR1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %1 = llvm.insertvalue %arg2, %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %2 = llvm.insertvalue %arg3, %1[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %3 = llvm.insertvalue %arg4, %2[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %4 = llvm.insertvalue %arg5, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %5 = llvm.insertvalue %arg7, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %6 = llvm.insertvalue %arg6, %5[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %7 = llvm.insertvalue %arg8, %6[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR0:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR0]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[SHP0:.*]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[STR0:.*]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[SHP1:.*]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[STR1:.*]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %8 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %9 = llvm.insertvalue %arg9, %8[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %10 = llvm.insertvalue %arg10, %9[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %11 = llvm.insertvalue 
%arg11, %10[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %12 = llvm.insertvalue %arg12, %11[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %13 = llvm.insertvalue %arg14, %12[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %14 = llvm.insertvalue %arg13, %13[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %15 = llvm.insertvalue %arg15, %14[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR0:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR0]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NOT: llvm.insertvalue %[[C1]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[SHP0]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NOT: llvm.insertvalue %[[STR0]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE: llvm.insertvalue %[[SHP1]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NOT: llvm.insertvalue %[[STR1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %16 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %17 = llvm.insertvalue %arg16, %16[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %18 = llvm.insertvalue %arg17, %17[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %19 = llvm.insertvalue %arg18, %18[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %20 = llvm.insertvalue %arg19, %19[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %21 = llvm.insertvalue %arg21, %20[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %22 = llvm.insertvalue %arg20, %21[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %23 = llvm.insertvalue %arg22, %22[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR0:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR0]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, 
%{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // ABI-NOT: llvm.insertvalue %[[C1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[SHP0]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NOT: llvm.insertvalue %[[STR0]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE: llvm.insertvalue %[[SHP1]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // SHAPE-NOT: llvm.insertvalue %[[STR1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> llvm.return // CHECK: llvm.return } @@ -289,63 +289,63 @@ module attributes {gpu.container_module} { // ABI-SAME: {llvm.align = 16 : index, llvm.noalias} // ABI-SAME: {llvm.align = 16 : index} // ABI-SAME: {llvm.align = 16 : index} - llvm.func @AddV2_kernel(%arg0: i64, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: !llvm.ptr, %arg7: !llvm.ptr, %arg8: i64, %arg9: i64, %arg10: i64, %arg11: !llvm.ptr, %arg12: !llvm.ptr, %arg13: i64, %arg14: i64, %arg15: i64) attributes {gpu.kernel} { + llvm.func @AddV2_kernel(%arg0: i64, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: !llvm.ptr, %arg7: !llvm.ptr, %arg8: i64, %arg9: i64, %arg10: i64, %arg11: !llvm.ptr, %arg12: !llvm.ptr, %arg13: i64, %arg14: i64, %arg15: i64) attributes {gpu.kernel} { // ABI: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 // ABI: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 - %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %1 = llvm.insertvalue %arg1, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %2 = llvm.insertvalue %arg2, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %3 = llvm.insertvalue %arg3, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %4 = llvm.insertvalue %arg4, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %5 = llvm.insertvalue %arg5, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR0:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR0]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C1]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 
x i64>)> - // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[SHP:.*]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[STR:.*]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %7 = llvm.insertvalue %arg6, %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %8 = llvm.insertvalue %arg7, %7[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %9 = llvm.insertvalue %arg8, %8[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %10 = llvm.insertvalue %arg9, %9[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %11 = llvm.insertvalue %arg10, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR1:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR1]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[SHP]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NOT: llvm.insertvalue %[[STR]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %12 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %13 = llvm.insertvalue %arg11, %12[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %14 = llvm.insertvalue %arg12, %13[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %15 = llvm.insertvalue %arg13, %14[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %16 = llvm.insertvalue %arg14, %15[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %17 = llvm.insertvalue %arg15, %16[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR2:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[PTR2]], 
%{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ABI-NEXT: llvm.insertvalue %[[C1]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NEXT: llvm.insertvalue %[[SHP]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // SHAPE-NOT: llvm.insertvalue %[[STR]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %1 = llvm.insertvalue %arg1, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %2 = llvm.insertvalue %arg2, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %3 = llvm.insertvalue %arg3, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %4 = llvm.insertvalue %arg4, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %5 = llvm.insertvalue %arg5, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR0:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR0]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C1]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[SHP:.*]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[STR:.*]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %7 = llvm.insertvalue %arg6, %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %8 = llvm.insertvalue %arg7, %7[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %9 = llvm.insertvalue %arg8, %8[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %10 = llvm.insertvalue %arg9, %9[3, 0] : !llvm.struct<(ptr, ptr, 
i64, array<1 x i64>, array<1 x i64>)> + %11 = llvm.insertvalue %arg10, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR1:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR1]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[SHP]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NOT: llvm.insertvalue %[[STR]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %12 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %13 = llvm.insertvalue %arg11, %12[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %14 = llvm.insertvalue %arg12, %13[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %15 = llvm.insertvalue %arg13, %14[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %16 = llvm.insertvalue %arg14, %15[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %17 = llvm.insertvalue %arg15, %16[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR2:.*]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[PTR2]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ABI-NEXT: llvm.insertvalue %[[C1]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NEXT: llvm.insertvalue %[[SHP]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // SHAPE-NOT: llvm.insertvalue %[[STR]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> llvm.return // CHECK: llvm.return } diff --git 
a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc index 681896f2a235a7..178e899cb33a72 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" -#include "xla/mlir_hlo/gml_st/IR/gml_st_ops.h" #include "xla/mlir_hlo/lhlo/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -37,8 +36,7 @@ int main(int argc, char **argv) { mlir::stablehlo::registerAllDialects(registry); mlir::RegisterAllTensorFlowDialects(registry); - registry.insert(); + registry.insert(); return failed( mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 7cf5ef8522bb23..7c2e9d45d12db9 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -1,13 +1,13 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load( - "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", -) +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", ) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load( + "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -155,7 +155,6 @@ cc_library( "@local_xla//xla:debug_options_flags", "@local_xla//xla:xla_proto_cc", "@local_xla//xla/mlir_hlo", - "@local_xla//xla/mlir_hlo:gml_st", "@local_xla//xla/mlir_hlo:lhlo", "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/mlir_hlo:type_conversion", @@ -218,7 +217,6 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", - "@local_xla//xla/mlir_hlo:gml_st", "@local_xla//xla/mlir_hlo:lhlo", "@local_xla//xla/mlir_hlo:transforms_passes", ], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc index 136b278e8c9dcf..b002effdfccf89 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project #include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project @@ -178,7 +179,7 @@ LogicalResult ConvertLaunchFuncOpToTfRuntimeCallPattern::matchAndRewrite( name_buffer.append("_blob"); Value module_blob = LLVM::createGlobalString(loc, rewriter, name_buffer.str(), binary_attr.getValue(), - LLVM::Linkage::Internal, true); + LLVM::Linkage::Internal); // Make sure the trailing zero is included in the constant. auto kernel_name = launch_op.getKernelName().getValue(); @@ -192,7 +193,7 @@ LogicalResult ConvertLaunchFuncOpToTfRuntimeCallPattern::matchAndRewrite( .toStringRef(kernel_name_global_name_buffer); auto kernel_name_global = LLVM::createGlobalString( loc, rewriter, kernel_name_global_name, kernel_name_buffer, - LLVM::Linkage::Internal, true); + LLVM::Linkage::Internal); // The TensorFlow OpKernelContext is the first argument of the surrounding // LLVMFunc. diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.cc index b1c909bb52364c..b3cb73b78baf20 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -61,7 +62,7 @@ Value CreateOrFindGlobalStringConstant(Location loc, StringRef global_name, ValueRange{c0, c0}); } return LLVM::createGlobalString(loc, *b, global_name, content, - LLVM::Linkage::Internal, true); + LLVM::Linkage::Internal); } } // namespace transforms diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD index 9eca865c4a91fd..d53604011273d8 100644 --- a/tensorflow/compiler/mlir/tosa/BUILD +++ b/tensorflow/compiler/mlir/tosa/BUILD @@ -186,8 +186,8 @@ cc_library( "transforms/convert_metadata.cc", "transforms/convert_tfl_uint8.cc", "transforms/legalize_tfl.cc", + "transforms/legalize_tfl_stateful.cc", "transforms/lower_complex_types.cc", - "transforms/lower_global_tensors.cc", "transforms/retain_call_once_funcs.cc", "transforms/strip_metadata.cc", "transforms/strip_quant_types.cc", @@ -213,7 +213,6 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", - "@llvm-project//mlir:MLProgramDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:ReconcileUnrealizedCasts", diff --git a/tensorflow/compiler/mlir/tosa/tests/lower_global_tensors.mlir b/tensorflow/compiler/mlir/tosa/tests/lower_global_tensors.mlir deleted file mode 100644 index 5b8bd2cc3c09a2..00000000000000 --- a/tensorflow/compiler/mlir/tosa/tests/lower_global_tensors.mlir +++ /dev/null @@ -1,145 +0,0 @@ -// RUN: tf-opt --split-input-file --pass-pipeline='builtin.module(tflite-lower-global-tensors)' %s | FileCheck %s - -module { - // CHECK: ml_program.global private mutable @Variable(dense<1.000000e+00> : tensor<16x16xf32>) - // CHECK-LABEL: func.func @state 
- func.func @state(%arg0: tensor<16x16xf32>) -> () { - "tfl.call_once"() {session_init_function = "StateInit"} : () -> () - return - } - - func.func private @StateInit() { - %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> - %1 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32> - "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () - return - } -} - -// ----- - -module { - // CHECK: ml_program.global private mutable @Variable(dense<1.000000e+00> : tensor<16x16xf32>) - - // CHECK-LABEL: func.func @assign - func.func @assign(%arg0: tensor<16x16xf32>) -> () { - "tfl.call_once"() {session_init_function = "AssignInit"} : () -> () - %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> - - // CHECK: ml_program.global_store @Variable = %arg0 - "tfl.assign_variable"(%0, %arg0) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () - return - } - - func.func private @AssignInit() { - %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> - %1 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32> - "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () - return - } -} - -// ----- - -module { - // CHECK: ml_program.global private mutable @Variable(dense<1.000000e+00> : tensor<16x16xf32>) - - // CHECK-LABEL: func.func @read - func.func @read(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) { - "tfl.call_once"() {session_init_function = "ReadInit"} : () -> () - - %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> - - // CHECK: %[[LOAD:.+]] = ml_program.global_load @Variable : tensor<16x16xf32> - %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32> - return %1 : tensor<16x16xf32> - } - - func.func private @ReadInit() { - %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> - %1 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32> - "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () - return - } -} - -// ----- - -module { - // CHECK: ml_program.global private mutable @Variable(dense<2.000000e+00> : tensor<16x16xf32>) - - // CHECK-LABEL: func.func @readAssign - func.func @readAssign(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) { - "tfl.call_once"() {session_init_function = "ReadAssignInit"} : () -> () - %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> - - // CHECK: %[[LOAD:.+]] = ml_program.global_load @Variable : tensor<16x16xf32> - %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32> - - // CHECK: %[[ADD:.+]] = tfl.add %[[LOAD]], %arg0 - %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<16x16xf32> - - // CHECK: ml_program.global_store @Variable = %[[ADD]] - "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () - return %2 : tensor<16x16xf32> - } - func.func private @ReadAssignInit() { - %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> - %1 = "tfl.pseudo_const"() {value = dense<2.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32> - "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, 
tensor<16x16xf32>) -> () - return - } -} - -// ----- - -module { - // CHECK: ml_program.global private mutable @Variable(dense<42> : tensor<2x3xi8>) - // CHECK-LABEL: func.func @readAssignQuant - func.func @readAssignQuant(%arg0: tensor<2x3x!quant.uniform>) -> (tensor<2x3x!quant.uniform>) { - "tfl.call_once"() {session_init_function = "ReadAssignInit"} : () -> () - %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> - - // CHECK: %[[ADDR:.+]] = ml_program.global_load @Variable : tensor<2x3xi8> - // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ADDR]] : tensor<2x3xi8> to tensor<2x3x!quant.uniform> - %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<2x3x!quant.uniform> - - // CHECK: %[[ADD:.+]] = tfl.add %[[CAST]], %arg0 {fused_activation_function = "NONE"} - %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<2x3x!quant.uniform> - - // CHECK: %[[CAST2:.+]] = builtin.unrealized_conversion_cast %[[ADD]] : tensor<2x3x!quant.uniform> to tensor<2x3xi8> - // CHECK: ml_program.global_store @Variable = %[[CAST2]] - "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<2x3x!quant.uniform>) -> () - return %2 : tensor<2x3x!quant.uniform> - } - func.func private @ReadAssignInit() { - %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> - %1 = "tfl.pseudo_const"() {qtype = tensor<2x3x!quant.uniform>, value = dense<42> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform> - "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<2x3x!quant.uniform>) -> () - return - } -} - -// ----- - -module { - // CHECK-label: @nostate - func.func @nostate(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) { - "tfl.call_once"() {session_init_function = "NoStateInit"} : () -> () - // CHECK: tfl.var_handle - %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> - - // CHECK: tfl.read_variable - %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32> - - %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<16x16xf32> - - // CHECK: tfl.assign_variable - "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () - return %2 : tensor<16x16xf32> - } - func.func private @NoStateInit() { - return - } -} - diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 84e61d0de2f7cb..c6c7e649e971a6 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -271,18 +271,20 @@ func.func @test_conv3d_bias(%arg0: tensor<10x3x64x64x12xf32>, %arg1: tensor<16x2 // CHECK-LABEL: test_conv3d_qi8( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4x8x21x17x!quant.uniform> // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3x3x17x34xf32>) -> tensor<1x4x8x11x34x!quant.uniform> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.0156862643> : tensor<1x1x1x1x1xf32>}> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.11982894> : tensor<1x1x1x1x1xf32>}> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<-4.000000e+00> : tensor<1x1x1x1x1xf32>}> -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<34xf32>}> -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>}> -// CHECK: %[[VAL_7:.*]] = tosa.cast 
%[[VAL_0]] -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_7]], %[[VAL_2]] {shift = 0 : i8} -// CHECK: %[[VAL_9:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] -// CHECK: %[[VAL_10:.*]] = tosa.conv3d %[[VAL_8]], %[[VAL_9]], %[[VAL_5]] {dilation = array, pad = array, stride = array} -// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_3]] {shift = 0 : i8} -// CHECK: %[[VAL_12:.*]] = tosa.add %[[VAL_11]], %[[VAL_4]] -// CHECK: %[[VAL_13:.*]] = tosa.cast %[[VAL_12]] +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.0156862643> : tensor<1x1x1x1x1xf32>} +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.11982894> : tensor<1x1x1x1x1xf32>} +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<-4> : tensor<1x1x1x1x1xi32>} +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<34xf32>} +// CHECK-DAG: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>} +// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_0]] +// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]] {shift = 0 : i8} +// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_7]] +// CHECK: %[[VAL_12:.*]] = tosa.conv3d %[[VAL_10]], %[[VAL_11]], %[[VAL_6]] {dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_12]], %[[VAL_4]] {shift = 0 : i8} +// CHECK: %[[VAL_14:.*]] = tosa.cast %[[VAL_13]] +// CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_14]], %[[VAL_5]] +// CHECK: %[[VAL_16:.*]] = tosa.cast %[[VAL_15]] +// CHECK: return %[[VAL_16]] func.func @test_conv3d_qi8(%arg0: tensor<1x4x8x21x17x!quant.uniform>, %arg1: tensor<2x3x3x17x34xf32>) -> (tensor<1x4x8x11x34x!quant.uniform>) { %0 = "tfl.dequantize"(%arg0) : (tensor<1x4x8x21x17x!quant.uniform>) -> tensor<1x4x8x21x17xf32> %2 = "tfl.no_value"() {value} : () -> none @@ -1853,12 +1855,12 @@ func.func @test_one_hot(%arg0: tensor<4x4xi32>, %arg1: tensor, %arg2: tenso // ----- // CHECK-LABEL: test_fakequant_with_min_max_args -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<16383.75> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<6.10360876E-5> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR2:.*]] = tosa.mul %arg0, %[[VAR0]] {shift = 0 : i8} -// CHECK-DAG: %[[VAR3:.*]] = tosa.cast %[[VAR2]] -// CHECK-DAG: %[[VAR4:.*]] = tosa.cast %[[VAR3]] -// CHECK-DAG: %[[VAR5:.*]] = tosa.mul %[[VAR4]], %[[VAR1]] {shift = 0 : i8} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<6.10360876E-5> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{value = dense<16383.75> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR3:.*]] = tosa.mul %arg0, %[[VAR2]] {shift = 0 : i8} +// CHECK-DAG: %[[VAR5:.*]] = tosa.cast %[[VAR3]] +// CHECK-DAG: %[[VAR6:.*]] = tosa.cast %[[VAR5]] +// CHECK-DAG: %[[VAR8:.*]] = tosa.mul %[[VAR6]], %[[VAR1]] {shift = 0 : i8} func.func @test_fakequant_with_min_max_args(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = "tfl.quantize"(%arg0) {qtype = tensor<13x21x3x!quant.uniform>} : (tensor<13x21x3xf32>) -> tensor<*x!quant.uniform> %1 = "tfl.dequantize"(%0) : (tensor<*x!quant.uniform>) -> tensor<13x21x3xf32> @@ -2662,7 +2664,7 @@ func.func @test_reverse_fail(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> // CHECK-LABEL: test_tfl_custom // CHECK-SAME: %[[ARG_0:.*]]: tensor<1x64x64x32xf32> -// CHECK: %[[VAL_0:.*]] = tosa.custom %[[ARG_0]] {config = "TFL", identifier = "MaxPoolingWithArgmax2D", implementation_attrs = "{{.*}}"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) +// 
CHECK: %[[VAL_0:.*]] = tosa.custom %[[ARG_0]] {domain_name = "TFL", implementation_attrs = "{{.*}}", operator_name = "MaxPoolingWithArgmax2D"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) func.func @test_tfl_custom(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) { // custom op for "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) %0, %1 = "tfl.custom"(%arg0) {custom_option = #tfl, custom_code = "MaxPoolingWithArgmax2D"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir new file mode 100644 index 00000000000000..e0f2d6b3ede707 --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir @@ -0,0 +1,84 @@ +// RUN: tf-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s +// RUN: tf-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + +// Operations for testing tfl-to-tosa-pipeline + +// ----- + +module attributes {tf_saved_model.semantics, tfl.description = "Test.", tfl.schema_version = 3 : i32} { + // CHECK: tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32> + // CHECK-LABEL: test_stateful_ops + // CHECK: tosa.variable.write @var_x, %arg0 : tensor<1xf32> + // CHECK: %[[VAL_0:.*]] = tosa.variable.read @var_x : tensor<1xf32> + // CHECK: return %[[VAL_0]] : tensor<1xf32> + func.func @test_stateful_ops(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["placeholder_0"]}) + -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) + attributes {tf_saved_model.exported_names = ["serving_default"]} { + "tfl.call_once"() {session_init_function = "InitializeX"} : () -> () + %0 = "tfl.var_handle"() {container = "", shared_name = "var_x"} : () -> tensor + "tfl.assign_variable"(%0, %arg0) : (tensor, tensor<1xf32>) -> () + %1 = "tfl.read_variable"(%0) : (tensor) -> tensor<1xf32> + return %1 : tensor<1xf32> + } + + // initialize variable var_x to 7.0 + func.func private @InitializeX() { + %0 = "tfl.var_handle"() {container = "", shared_name = "var_x"} : () -> tensor + %1 = "tfl.pseudo_const"() {value = dense<7.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32> + "tfl.assign_variable"(%0, %1) : (tensor, tensor<1xf32>) -> () + return + } +} + +// ----- + +module { + // CHECK: tosa.variable @Variable = dense<42> : tensor<2x3xi8> + // CHECK-LABEL: readAssignQuant + // CHECK: %[[VAL_0:.*]] = tosa.variable.read @Variable : tensor<2x3xi8> + // CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : tensor<2x3xi8> to tensor<2x3x!quant.uniform> + // CHECK: %[[VAL_2:.*]] = tosa.rescale %[[VAL_1]] {double_round = true, input_zp = 2 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<2x3x!quant.uniform>) -> tensor<2x3xi32> + // CHECK: %[[VAL_3:.*]] = tosa.rescale %[[VAL_4:.*]] {double_round = true, input_zp = 2 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<2x3x!quant.uniform>) -> tensor<2x3xi32> + // CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_2]], %[[VAL_3]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + // CHECK: %[[VAL_6:.*]] = tosa.rescale %[[VAL_5]] {double_round = true, input_zp = 0 : 
i32, multiplier = array, output_zp = 2 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<2x3xi32>) -> tensor<2x3x!quant.uniform> + // CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : tensor<2x3x!quant.uniform> to tensor<2x3xi8> + // CHECK: tosa.variable.write @Variable, %[[VAL_7]] : tensor<2x3xi8> + // CHECK: return %[[VAL_6]] : tensor<2x3x!quant.uniform> + func.func @readAssignQuant(%arg0: tensor<2x3x!quant.uniform>) -> (tensor<2x3x!quant.uniform>) { + "tfl.call_once"() {session_init_function = "ReadAssignInit"} : () -> () + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<2x3x!quant.uniform> + %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<2x3x!quant.uniform> + "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<2x3x!quant.uniform>) -> () + return %2 : tensor<2x3x!quant.uniform> + } + func.func private @ReadAssignInit() { + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + %1 = "tfl.pseudo_const"() {qtype = tensor<2x3x!quant.uniform>, value = dense<42> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform> + "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<2x3x!quant.uniform>) -> () + return + } +} + +// ----- + +module { + // CHECK-LABEL: @nostate + // CHECK: %[[VAL_0:.*]]: tensor<16x16xf32>) -> tensor<16x16xf32> { + // CHECK: %[[VAL_1:.*]] = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + // CHECK: %[[VAL_2:.*]] = "tfl.read_variable"(%[[VAL_1]]) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32> + // CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_2]], %[[VAL_0]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + // CHECK: "tfl.assign_variable"(%[[VAL_1]], %[[VAL_3]]) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () + // CHECK: return %[[VAL_3]] : tensor<16x16xf32> + func.func @nostate(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) { + "tfl.call_once"() {session_init_function = "NoStateInit"} : () -> () + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32> + %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<16x16xf32> + "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () + return %2 : tensor<16x16xf32> + } + func.func private @NoStateInit() { + return + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tosa/tf_tfl_passes.cc b/tensorflow/compiler/mlir/tosa/tf_tfl_passes.cc index 2b31e3246fd598..81ea9f6393216c 100644 --- a/tensorflow/compiler/mlir/tosa/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/tosa/tf_tfl_passes.cc @@ -30,6 +30,9 @@ void createTFTFLtoTOSALegalizationPipeline( //---------------------------------------------------------------------------- // Prepare TFL module for conversion //---------------------------------------------------------------------------- + // For stateful ops + pm.addPass(createRetainCallOnceFuncsPass()); + // Inline all functions into main and then delete the functions themselves. 
pm.addPass(mlir::createInlinerPass()); @@ -52,6 +55,7 @@ void createTFTFLtoTOSALegalizationPipeline( if (opts.dequantize_tfl_softmax) { pm.addPass(mlir::tosa::createDequantizeTFLSoftmaxPass()); } + pm.addPass(mlir::tosa::createLegalizeTFLStatefulPass()); pm.addPass(mlir::tosa::createLegalizeTFTFLPass()); //---------------------------------------------------------------------------- diff --git a/tensorflow/compiler/mlir/tosa/tfl_passes.cc b/tensorflow/compiler/mlir/tosa/tfl_passes.cc index ff3c38e381e8ba..2eb98a4415f668 100644 --- a/tensorflow/compiler/mlir/tosa/tfl_passes.cc +++ b/tensorflow/compiler/mlir/tosa/tfl_passes.cc @@ -30,16 +30,14 @@ void createTFLtoTOSALegalizationPipeline( //---------------------------------------------------------------------------- // Prepare TFL module for conversion //---------------------------------------------------------------------------- - if (opts.target_compilation_backend) { - pm.addPass(createRetainCallOnceFuncsPass()); - } + pm.addPass(createRetainCallOnceFuncsPass()); + // Inline all functions into main and then delete the functions themselves. pm.addPass(mlir::createInlinerPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createSymbolDCEPass()); if (opts.target_compilation_backend) { pm.nest().addPass(createConvertFunctionMetadataPass()); - pm.addPass(createLowerGlobalTensorsPass()); } // Add pass to decompose TFLite mixed quantization to non-quantized variants. @@ -59,6 +57,7 @@ void createTFLtoTOSALegalizationPipeline( if (opts.dequantize_tfl_softmax) { pm.addPass(mlir::tosa::createDequantizeTFLSoftmaxPass()); } + pm.addPass(mlir::tosa::createLegalizeTFLStatefulPass()); pm.addPass(mlir::tosa::createLegalizeTFLPass(opts.disabled_patterns, opts.enabled_patterns)); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 7fed578c78f86c..b454dfecbaca98 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -3514,21 +3514,26 @@ std::optional convertQuantizeOp(PatternRewriter& rewriter, Operation* op, } ShapedType output_fp_type = output_type.clone(rewriter.getF32Type()); - - Value zp_val = - getTosaConstTensorSingleF32(rewriter, op, static_cast(zeropoint)); - - auto op1_mul_in = CreateOpAndInfer( + Value result = CreateOpAndInfer( rewriter, op->getLoc(), output_fp_type, input_value, getTosaConstTensorSingleF32(rewriter, op, static_cast(scale)), 0); - auto op2_add_op1 = CreateOpAndInfer( - rewriter, op->getLoc(), output_fp_type, op1_mul_in.getResult(), zp_val); + if (zeropoint != 0) { + // cast to i32 to add zeropoint + ShapedType output_i32_type = output_type.clone(rewriter.getI32Type()); + Value cast_i32 = CreateOpAndInfer(rewriter, op->getLoc(), + output_i32_type, result); + + Value zp_val = getTosaConstTensorSingleI32(rewriter, op, zeropoint); + + result = CreateOpAndInfer(rewriter, op->getLoc(), + output_i32_type, cast_i32, zp_val); + } - auto op3_cast_op2 = CreateOpAndInfer( - rewriter, op->getLoc(), output_type, op2_add_op1.getResult()); + Value final_result = CreateOpAndInfer(rewriter, op->getLoc(), + output_type, result); - return op3_cast_op2.getResult(); + return final_result; } // Lowers Dequantize to a sequence of TOSA dequantization ops. 
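// The convertQuantizeOp change above replaces the old float zero-point add (mul -> add -> cast) with a
// sequence that only materializes the zero point when it is nonzero, and adds it in 32-bit integer
// arithmetic: scale with a float tosa.mul, tosa.cast to i32, tosa.add an i32 zero-point constant, then
// tosa.cast to the quantized storage type. This matches the reordered mul/cast/add/cast CHECK lines in
// test_conv3d_qi8 above. The following is a hand-written, illustrative MLIR sketch of the nonzero
// zero-point case; the function name, shape, scale multiplier, and zero point (-4) are made up for the
// example and are not taken from this patch:
//
//   func.func @quantize_example(%input: tensor<8xf32>) -> tensor<8xi8> {
//     %scale  = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
//     %zp     = "tosa.const"() <{value = dense<-4> : tensor<1xi32>}> : () -> tensor<1xi32>
//     %scaled = tosa.mul %input, %scale {shift = 0 : i8} : (tensor<8xf32>, tensor<1xf32>) -> tensor<8xf32>
//     %wide   = tosa.cast %scaled : (tensor<8xf32>) -> tensor<8xi32>
//     %biased = tosa.add %wide, %zp : (tensor<8xi32>, tensor<1xi32>) -> tensor<8xi32>
//     %out    = tosa.cast %biased : (tensor<8xi32>) -> tensor<8xi8>
//     return %out : tensor<8xi8>
//   }
//
// When the zero point is zero, the cast/add pair is skipped and the scaled value is cast directly to the
// quantized output type.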
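// The tosa-legalize-tfl-stateful pass registered in the pipelines above (and exercised by
// tfl-to-tosa-stateful.mlir) maps TFL resource-variable ops onto TOSA variable ops: the constant
// assigned inside the tfl.call_once session-init function becomes a module-level tosa.variable,
// tfl.assign_variable becomes tosa.variable.write, tfl.read_variable becomes tosa.variable.read, and
// the now-unused var_handle and call_once ops are erased. An illustrative hand-written sketch of the
// resulting IR, with an assumed variable name and initial value (not taken from this patch):
//
//   tosa.variable @acc = dense<0.000000e+00> : tensor<1xf32>
//   func.func @update(%arg0: tensor<1xf32>) -> tensor<1xf32> {
//     tosa.variable.write @acc, %arg0 : tensor<1xf32>
//     %0 = tosa.variable.read @acc : tensor<1xf32>
//     return %0 : tensor<1xf32>
//   }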
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl_stateful.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl_stateful.cc new file mode 100644 index 00000000000000..4028093f547a3a --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl_stateful.cc @@ -0,0 +1,187 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legalize TensorFlow Lite StatefulOps to TOSA + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h" + +#define PASS_NAME "tosa-legalize-tfl-stateful" + +namespace mlir { +namespace tosa { +namespace { + +#define GEN_PASS_DEF_TOSALEGALIZETFLSTATEFULPASS +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc" + +// Performs lowering tfl stateful operators to TOSA +class TosaLegalizeTFLStateful + : public impl::TosaLegalizeTFLStatefulPassBase { + public: + explicit TosaLegalizeTFLStateful() = default; + void runOnOperation() override; +}; + +void TosaLegalizeTFLStateful::runOnOperation() { + auto moduleOp = getOperation(); + mlir::OpBuilder builder(moduleOp.getBodyRegion()); + + DenseMap symNameToFunction; + for (auto func : moduleOp.getOps()) { + symNameToFunction[func.getSymName()] = func; + } + + llvm::SmallVector handleOps; + llvm::SmallVector assignOps; + llvm::SmallVector readOps; + SmallVector callOnceOps; + DenseMap symbolRefMap; + + for (auto it : symNameToFunction) { + auto func = std::get<1>(it); + // We also want to grab the list of operations to replace. + for (auto& op : func.getOps()) { + if (auto handle = dyn_cast(op)) + handleOps.push_back(handle); + if (auto assign = dyn_cast(op)) + assignOps.push_back(assign); + if (auto read = dyn_cast(op)) + readOps.push_back(read); + } + } + + for (auto func : moduleOp.getOps()) { + for (auto init : func.getOps()) { + callOnceOps.push_back(init); + } + } + + // Look through the initialization functions and find the assigned values + // for each handle, save out the constant value. + for (auto init : callOnceOps) { + auto findInitFunc = + symNameToFunction.find(init.getSessionInitFunctionAttr()); + if (findInitFunc == symNameToFunction.end()) { + init.emitError("unable to find initialization function: "); + continue; + } + func::FuncOp initFunc = std::get<1>(*findInitFunc); + for (auto assign : initFunc.getOps()) { + // 1. 
var_handle part + auto handle = dyn_cast( + assign.getResourceId().getDefiningOp()); + if (!handle) continue; + + // 2. pseudo_const part + DenseElementsAttr constant; + if (!matchPattern(assign.getValue(), m_Constant(&constant))) { + // Quantized types we can not use the m_Constant matcher. + if (auto constOp = dyn_cast( + assign.getValue().getDefiningOp())) { + constant = cast(constOp.getValue()); + } + } + if (!constant) continue; + + // Create TOSA VariableOps + auto name = handle.getSharedName(); + auto global = builder.create( + handle.getLoc(), name, constant.getType(), constant); + symbolRefMap[name] = global; + } + } + // TF::CallOnceOps are no longer needed as we have already extracted their + // state. + for (auto op : callOnceOps) op.erase(); + + // Replace the assign ops with a tosa store operation. + for (auto assign : assignOps) { + auto handle = dyn_cast( + assign.getResourceId().getDefiningOp()); + if (!handle) continue; + + Value value = assign.getValue(); + auto globalOpIt = symbolRefMap.find(handle.getSharedName()); + if (globalOpIt == symbolRefMap.end()) { + assign->emitError( + "unable to find corresponding TosaOp for op's VarHandle"); + continue; + } + auto globalOp = std::get<1>(*globalOpIt); + + builder.setInsertionPoint(assign); + if (globalOp.getType() != value.getType()) { + value = builder + .create(assign.getLoc(), + globalOp.getType(), value) + .getResult(0); + } + + builder.create( + assign.getLoc(), llvm::StringRef(globalOp.getName()), value); + assign.erase(); + } + + for (auto read : readOps) { + auto handle = + dyn_cast(read.getResourceId().getDefiningOp()); + if (!handle) continue; + + auto globalOpIt = symbolRefMap.find(handle.getSharedName()); + if (globalOpIt == symbolRefMap.end()) continue; + auto globalOp = std::get<1>(*globalOpIt); + + builder.setInsertionPoint(read); + + Value load = builder.create( + read.getLoc(), globalOp.getType(), llvm::StringRef(globalOp.getName())); + + if (read.getType() != load.getType()) { + load = builder + .create(read.getLoc(), + read.getType(), load) + .getResult(0); + } + read.getResult().replaceAllUsesWith(load); + read.erase(); + } + + for (auto handle : handleOps) { + if (handle.getResult().use_empty()) { + handle.erase(); + } + } +} + +} // namespace + +// Creates an instance of the TensorFlow Lite dialect LegalizeTFLStateful pass. +std::unique_ptr> createLegalizeTFLStatefulPass() { + return std::make_unique(); +} + +} // namespace tosa +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tosa/transforms/lower_global_tensors.cc b/tensorflow/compiler/mlir/tosa/transforms/lower_global_tensors.cc deleted file mode 100644 index de30f7c2fb0507..00000000000000 --- a/tensorflow/compiler/mlir/tosa/transforms/lower_global_tensors.cc +++ /dev/null @@ -1,206 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include "mlir/Dialect/MLProgram/IR/MLProgram.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/tosa/transforms/passes.h" - -#define PASS_NAME "tosa-lower-global-tensors" -#define DEBUG_TYPE PASS_NAME - -namespace mlir::tosa { - -#define GEN_PASS_DEF_LOWERGLOBALTENSORS -#include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc" - -namespace { - -class LowerGlobalTensorsPass - : public impl::LowerGlobalTensorsBase { - public: - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } - - // Converts TFLite state operations to the MLProgram equivalent. - void runOnOperation() override { - auto* context = &getContext(); - auto moduleOp = getOperation(); - mlir::OpBuilder builder(moduleOp.getBodyRegion()); - - DenseMap symNameToFunction; - for (auto func : moduleOp.getOps()) { - symNameToFunction[func.getSymName()] = func; - } - - DenseMap sharedNameToConstant; - DenseMap sharedNameToLoc; - - SmallVector handleOps; - SmallVector assignOps; - SmallVector readOps; - for (auto it : symNameToFunction) { - auto func = std::get<1>(it); - // Look through the initialization functions and find the assigned values - // for each handle, save out the constant value. - for (auto init : func.getOps()) { - auto findInitFunc = - symNameToFunction.find(init.getSessionInitFunction()); - if (findInitFunc == symNameToFunction.end()) { - init.emitError("unable to find initialization function: " + - init.getSessionInitFunction()); - continue; - } - func::FuncOp initFunc = std::get<1>(*findInitFunc); - for (auto assign : initFunc.getOps()) { - auto handle = dyn_cast( - assign.getResourceId().getDefiningOp()); - if (!handle) continue; - - DenseElementsAttr constant; - if (!matchPattern(assign.getValue(), m_Constant(&constant))) { - // Quantized types we can not use the m_Constant matcher. - if (auto constOp = dyn_cast( - assign.getValue().getDefiningOp())) { - constant = constOp.getValue().cast(); - } - } - if (!constant) continue; - - auto name = handle.getSharedName(); - sharedNameToConstant[name] = constant; - sharedNameToLoc[name] = handle.getLoc(); - } - } - - // We also want to grab the list of operations to replace. - for (auto& op : func.getOps()) { - if (auto handle = dyn_cast(op)) - handleOps.push_back(handle); - if (auto assign = dyn_cast(op)) - assignOps.push_back(assign); - if (auto read = dyn_cast(op)) - readOps.push_back(read); - } - } - - // TF::CallOnceOps are no longer needed as we have already extracted their - // state. - SmallVector callOnceOps; - for (auto func : moduleOp.getOps()) { - for (auto init : func.getOps()) { - callOnceOps.push_back(init); - } - } - for (auto op : callOnceOps) op.erase(); - - // Create the ml_program::GlobalOps to store our new global variables. 
- DenseMap symbolRefMap; - for (auto it : sharedNameToConstant) { - auto name = std::get<0>(it); - auto attribute = std::get<1>(it); - auto locIt = sharedNameToLoc.find(name); - LocationAttr loc = mlir::UnknownLoc(); - if (locIt != sharedNameToLoc.end()) { - loc = std::get<1>(*locIt); - } - - // TODO(suderman): Determine the global type based on all store - // operations. - auto global = builder.create( - loc, name, attribute.getType(), /*is_mutable=*/true, attribute, - nullptr); - global.setPrivate(); - - symbolRefMap[name] = global; - } - - // Replace the assign ops with a global store operation. - for (auto assign : assignOps) { - auto handle = dyn_cast( - assign.getResourceId().getDefiningOp()); - if (!handle) continue; - - Value value = assign.getValue(); - auto globalOpIt = symbolRefMap.find(handle.getSharedName()); - if (globalOpIt == symbolRefMap.end()) { - assign->emitError( - "unable to find corresponding GlobalOp for op's VarHandle"); - continue; - } - auto globalOp = std::get<1>(*globalOpIt); - - builder.setInsertionPoint(assign); - if (globalOp.getType() != value.getType()) { - value = builder - .create( - assign.getLoc(), globalOp.getType(), value) - .getResult(0); - } - - auto globalSymbolRef = SymbolRefAttr::get(context, globalOp.getSymName()); - builder.create(assign.getLoc(), - globalSymbolRef, value); - assign.erase(); - } - - for (auto read : readOps) { - auto handle = dyn_cast( - read.getResourceId().getDefiningOp()); - if (!handle) continue; - - auto globalOpIt = symbolRefMap.find(handle.getSharedName()); - if (globalOpIt == symbolRefMap.end()) continue; - auto globalOp = std::get<1>(*globalOpIt); - - builder.setInsertionPoint(read); - - auto globalSymbolRef = SymbolRefAttr::get(context, globalOp.getSymName()); - Value load = builder.create( - read.getLoc(), globalOp.getType(), globalSymbolRef); - - if (read.getType() != load.getType()) { - load = builder - .create(read.getLoc(), - read.getType(), load) - .getResult(0); - } - read.getResult().replaceAllUsesWith(load); - read.erase(); - } - - for (auto handle : handleOps) { - if (handle.getResult().use_empty()) { - handle.erase(); - } - } - } -}; - -} // namespace - -std::unique_ptr> createLowerGlobalTensorsPass() { - return std::make_unique(); -} - -} // namespace mlir::tosa diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.h b/tensorflow/compiler/mlir/tosa/transforms/passes.h index 99f9465c8a639c..e41453b0b9af8b 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/passes.h +++ b/tensorflow/compiler/mlir/tosa/transforms/passes.h @@ -56,7 +56,6 @@ std::unique_ptr> createLegalizeTFLPass( ArrayRef disabled_patterns = std::nullopt, ArrayRef enabled_patterns = std::nullopt); -std::unique_ptr> createLowerGlobalTensorsPass(); std::unique_ptr> createRetainCallOnceFuncsPass(); std::unique_ptr> createStripModuleMetadataPass(); std::unique_ptr> createConvertTFLUint8Pass(); @@ -68,6 +67,7 @@ std::unique_ptr> createLowerComplexTypesPass(); std::unique_ptr> createStripFunctionMetadataPass(); std::unique_ptr> createStripQuantTypesPass(); std::unique_ptr> createVerifyFullyConvertedPass(); +std::unique_ptr> createLegalizeTFLStatefulPass(); #define GEN_PASS_REGISTRATION #define GEN_PASS_CLASSES @@ -79,12 +79,12 @@ std::unique_ptr> createVerifyFullyConvertedPass(); #define GEN_PASS_DECL_TOSASTRIPQUANTTYPESPASS #define GEN_PASS_DECL_TOSALOWERCOMPLEXTYPESPASS #define GEN_PASS_DECL_TOSADEQUANTIZETFLSOFTMAXPASS -#define GEN_PASS_DECL_LOWERGLOBALTENSORS #define GEN_PASS_DECL_RETAINCALLONCEFUNCS #define 
GEN_PASS_DECL_STRIPFUNCTIONMETADATA #define GEN_PASS_DECL_STRIPMODULEMETADATA #define GEN_PASS_DECL_VERIFYFULLYCONVERTED #define GEN_PASS_DECL_CONVERTFUNCTIONMETADATA +#define GEN_PASS_DECL_TOSALEGALIZESTATEFULPASS #include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc" diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.td b/tensorflow/compiler/mlir/tosa/transforms/passes.td index e623760a4e9aca..3cf7749d875f9d 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/passes.td +++ b/tensorflow/compiler/mlir/tosa/transforms/passes.td @@ -89,12 +89,6 @@ def TosaDequantizeTFLSoftmaxPass : Pass<"tosa-dequantize-tfl-softmax", "mlir::fu let dependentDialects = ["mlir::TFL::TFLDialect", "quantfork::QuantizationForkDialect"]; } -def LowerGlobalTensors : - Pass<"tflite-lower-global-tensors", "mlir::ModuleOp"> { - let summary = "Lowers TFLite global tensors to MLProgram dialect variables."; - let constructor = "createLowerGlobalTensorsPass()"; -} - def RetainCallOnceFuncs : Pass<"tflite-retain-call-once-funcs", "mlir::ModuleOp"> { let summary = "Guarantees that functions used by tfl.call_once are retained."; @@ -125,3 +119,11 @@ def ConvertFunctionMetadata : let constructor = "createConvertFunctionMetadataPass()"; } +def TosaLegalizeTFLStatefulPass : Pass<"tosa-legalize-tfl-stateful-tensors", "mlir::ModuleOp"> { + let summary = "Legalize tfl stateful operators to tosa stateful operators"; + let description = [{ + This pass is legalizing the tfl.call_once op to tosa stateful operators + }]; + let constructor = "createLegalizeTFLStatefulPass()"; + let dependentDialects = ["mlir::TFL::TFLDialect"]; +} diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 21dd107643be51..ae803f5d16dd04 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1017,6 +1017,7 @@ tf_xla_py_strict_test( enable_mlir_bridge = True, python_version = "PY3", tags = [ + "no_aarch64", # TODO(b/315533266) "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "optonly", ], @@ -1239,7 +1240,7 @@ tf_xla_py_strict_test( ], enable_mlir_bridge = True, python_version = "PY3", - shard_count = 5, + shard_count = 1, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "optonly", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index a312df10e1f0f5..b54c2e54fa3552 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -1061,8 +1061,7 @@ def testMatMul(self): expected=np.array([[4.2384180773686798]], dtype=dtype), rtol=1e-14) - # TODO(phawkins): failing on GPU, no registered kernel. - def DISABLED_testSparseMatMul(self): + def testSparseMatMul(self): # Binary wrappers for sparse_matmul with different hints def SparseMatmulWrapperTF(a, b): return math_ops.sparse_matmul(a, b, a_is_sparse=True) @@ -1073,10 +1072,13 @@ def SparseMatmulWrapperFT(a, b): def SparseMatmulWrapperTT(a, b): return math_ops.sparse_matmul(a, b, a_is_sparse=True, b_is_sparse=True) - self._testMatMul(math_ops.sparse_matmul, self.float_types) - self._testMatMul(SparseMatmulWrapperTF, self.float_types) - self._testMatMul(SparseMatmulWrapperFT, self.float_types) - self._testMatMul(SparseMatmulWrapperTT, self.float_types) + # TODO(b/314165739): SparseMatmul XlaBuilder lowering does not support + # float16 and float64. 
+ float_types = self.float_types - {np.float16, np.float64} + self._testMatMul(math_ops.sparse_matmul, float_types) + self._testMatMul(SparseMatmulWrapperTF, float_types) + self._testMatMul(SparseMatmulWrapperFT, float_types) + self._testMatMul(SparseMatmulWrapperTT, float_types) def testBatchMatMul(self): # Tests with batches of matrices. diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 7c48f5e3ec6518..01142082ae24f5 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -284,8 +284,9 @@ def testRandomNormalIsFinite(self): @parameterized.named_parameters( (f'_{dtype.name}_{seed}', dtype, seed) # pylint: disable=g-complex-comprehension - for seed in ([1, 2], [12, 23], [123, 456], [25252, 314159]) - for dtype in _allowed_types()) + for seed in ([1, 2], [12, 23], [25252, 314159]) + for dtype in _allowed_types() + ) def testDistributionOfStatelessRandomNormal(self, dtype, seed): """Use Anderson-Darling test to test distribution appears normal.""" with self.session() as sess, self.test_scope(): diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 47cc309b45452b..46f192648ecaa6 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -1270,13 +1270,13 @@ def assert_output_shapes(output, expected_shape): ): reduce_with_shapes((None, 4, 5), (3, None, 5), (13, 4, 5)) - @parameterized.parameters( - random_ops_util.Algorithm.THREEFRY, - random_ops_util.Algorithm.PHILOX, - random_ops_util.Algorithm.AUTO_SELECT, + @parameterized.product( + algorithm=[random_ops_util.Algorithm.THREEFRY, + random_ops_util.Algorithm.PHILOX, + random_ops_util.Algorithm.AUTO_SELECT], + dtype=[np.uint8, np.uint64], ) - def testRngBitGenerator(self, algorithm): - dtype = np.uint64 + def testRngBitGenerator(self, algorithm, dtype): initial_state = array_ops.placeholder(np.uint64, shape=(2,)) shape = (2, 3) res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 448e1cbc9e61ba..b91fb494667c5f 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -3,9 +3,9 @@ # and provide TensorRT operators and converter package. # APIs are meant to change over time. 
-# Placeholder: load py_proto_library load("//tensorflow:strict.default.bzl", "py_strict_library") -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +# Placeholder: load py_proto_library load( "//tensorflow:tensorflow.bzl", "VERSION", @@ -21,17 +21,18 @@ load( "tf_additional_all_protos", "tf_proto_library", ) -load( - "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", - "cuda_rpath_flags", -) -load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") # Platform specific build config load( "//tensorflow/core/platform:build_config_root.bzl", "if_static", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load( + "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) +load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -103,6 +104,8 @@ tf_cuda_cc_test( "no_cuda_on_cpu_tap", "no_windows", "nomac", + # TODO(b/303453873): Re-enable test once TensorRT has been updated + "notap", ], deps = [ ":trt_logging", @@ -157,6 +160,8 @@ tf_cuda_cc_test( "no_cuda_on_cpu_tap", "no_windows", "nomac", + # TODO(b/303453873): Re-enable test once TensorRT has been updated + "notap", ], deps = [ ":common_utils", @@ -239,6 +244,8 @@ tf_cuda_cc_test( "no_cuda_on_cpu_tap", "no_windows", "nomac", + # TODO(b/303453873): Re-enable test once TensorRT has been updated + "notap", ], deps = [ ":testutils", @@ -318,6 +325,8 @@ tf_cuda_cc_test( "no_cuda_on_cpu_tap", "no_windows", "nomac", + # TODO(b/303453873): Re-enable test once TensorRT has been updated + "notap", ], deps = [ ":common_utils", @@ -354,6 +363,8 @@ tf_cuda_cc_test( "no_cuda_on_cpu_tap", "no_windows", "nomac", + # TODO(b/303453873): Re-enable test once TensorRT has been updated + "notap", ], deps = [ ":testutils", @@ -411,6 +422,7 @@ tf_cuda_library( "utils/trt_execution_context.h", "utils/trt_shape_optimization_profiles.h", ], + features = ["-layering_check"], deps = [ ":common_utils", ":trt_allocator", @@ -431,6 +443,7 @@ tf_cuda_library( name = "trt_logging", srcs = ["utils/trt_logger.cc"], hdrs = ["utils/trt_logger.h"], + features = ["-layering_check"], visibility = ["//visibility:public"], deps = [ ":common_utils", @@ -515,6 +528,7 @@ tf_cuda_library( name = "trt_allocator", srcs = ["utils/trt_allocator.cc"], hdrs = ["utils/trt_allocator.h"], + features = ["-layering_check"], deps = [ "//tensorflow/core:framework_headers_lib", "//tensorflow/core:framework_lite", @@ -561,6 +575,8 @@ tf_cuda_cc_test( "no_cuda_on_cpu_tap", "no_windows", "nomac", + # TODO(b/303453873): Re-enable test once TensorRT has been updated + "notap", ], deps = [ ":trt_resources", @@ -576,6 +592,7 @@ tf_cuda_library( "convert/logger_registry.h", ], copts = tf_copts(), + features = ["-layering_check"], deps = [ "//tensorflow/core:lib", "@com_google_absl//absl/strings", @@ -780,6 +797,8 @@ tf_cuda_cc_test( "no_cuda_on_cpu_tap", "no_windows", "nomac", + # TODO(b/303453873): Re-enable test once TensorRT has been updated + "notap", ], deps = [ ":testutils", @@ -816,6 +835,8 @@ tf_cuda_cc_test( "no_cuda_on_cpu_tap", "no_windows", "nomac", + # TODO(b/303453873): Re-enable test once TensorRT has been updated + "notap", ], deps = [ ":testutils", diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 2ae20bdfe8f07d..332be3f50bf342 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ 
b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -4024,7 +4024,7 @@ TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertFill) { Reset(); // random data AddTestWeights("dims", {2}, {2, 2}, DT_INT32); - AddTestWeights("value", {1}, {42.0}, tf_type_); + AddTestWeights("value", {1}, {42}, tf_type_); RunValidationAndConversion( node_def, absl::StatusCode::kUnimplemented, convert_not_supported_implicit(node_def.op(), node_def.name())); @@ -4042,16 +4042,19 @@ TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertFill) { for (auto output_dims : output_dims_params) { for (auto value_dims : value_dims_params) { Reset(); - std::vector dims_dims = {output_dims.size()}; + std::vector dims_dims = { + static_cast(output_dims.size())}; if (dims_is_tensor) { AddTestTensor("dims", dims_dims, DT_INT32, output_dims, dims_dims); } else { AddTestWeights("dims", dims_dims, output_dims, DT_INT32); } if (value_is_tensor) { - AddTestTensor("value", value_dims, tf_type_, {val}); + AddTestTensor("value", value_dims, tf_type_, + {static_cast(val)}); } else { - AddTestWeights("value", value_dims, {val}, tf_type_); + AddTestWeights("value", value_dims, {static_cast(val)}, + tf_type_); } size_t nb_el = 1; for (auto d : output_dims) { @@ -4084,7 +4087,7 @@ TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertRange) { // (a) for all parameters, when shape_idx > 3 // (b) for all parameters, except shape_idx, when shape_idx >= 0 // (c) for none of the shape_idx < 0 - if (shape_idx > 3 || shape_idx >= 0 && shape_idx != i) { + if (shape_idx > 3 || (shape_idx >= 0 && shape_idx != i)) { partial_shape_dims = {1}; } AddTestTensor(name[i], {1}, type[i], value[i], partial_shape_dims); @@ -4140,7 +4143,7 @@ TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertRange) { limit_type == DT_INT32 && delta_type == DT_INT32; - if (all_weights || all_integers && !config[2]) { + if (all_weights || (all_integers && !config[2])) { // Reject invalid parameters if delta = 0 and it's passed as a weight. param_value[2] = {0}; set_parameters(param_name, param_value, param_type, config); @@ -9435,8 +9438,8 @@ void OpConverter_Select::RunTest(const string& opName) { std::accumulate(std::begin(expect_dims), std::end(expect_dims), 1, std::multiplies()); - assert(rank_out == expected_out ? expected_out->size() - : rank[use_indices >= 0 ? 0 : 1]); + assert(rank_out == (expected_out ? expected_out->size() + : rank[use_indices >= 0 ? 0 : 1])); expected_output.resize(rank_out); const auto& data_then = *par_value[1]; @@ -9476,7 +9479,7 @@ void OpConverter_Select::RunTest(const string& opName) { const auto nMax = testing_SelectV2 ? 
2 : 1; for (int n = 0; n < nMax; n++) { set_parameters(); - if (testing_SelectV2 || same_then_else_shapes && same_cond_chape) { + if (testing_SelectV2 || (same_then_else_shapes && same_cond_chape)) { TestOpConverter(node, exp_dims, OkStatus(), OkStatus(), ElementsAreArray(expected_output)); } else { diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/fill_ops.cc b/tensorflow/compiler/tf2tensorrt/convert/ops/fill_ops.cc index fc5fc589211ec1..0e01bcaadb9f63 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/ops/fill_ops.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/ops/fill_ops.cc @@ -288,7 +288,7 @@ class ConvertRange : public ConvertFillBase { }; std::string convert_range_error_msg(float start, float limit, float delta) { - constexpr char* format_string = + constexpr const char* format_string = "For parameters (start, limit) = (%.2f, %.2f) " "of the Range operation delta cannot be %s, got %.2f"; return absl::StrFormat(format_string, start, limit, diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/selectv2.cc b/tensorflow/compiler/tf2tensorrt/convert/ops/selectv2.cc index a68ffceb1534a1..4c21e49f12bf0f 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/ops/selectv2.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/ops/selectv2.cc @@ -31,10 +31,10 @@ class ConvertSelectBase : public OpConverterBase { public: explicit ConvertSelectBase(const OpConverterParams* params, const std::string& layer_name) - : layer_name_(layer_name), - OpConverterBase( + : OpConverterBase( params, - {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}) {} + {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}), + layer_name_(layer_name) {} static constexpr std::array InputSpec() { return std::array{ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index c8dc6721853ccb..5ae1c907f0138d 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -14,6 +14,7 @@ load( ) load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -506,9 +507,11 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", + "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_xla//xla:executable_run_options", "@local_xla//xla:protobuf_util", "@local_xla//xla:shape_util", @@ -522,10 +525,11 @@ cc_library( "@local_xla//xla/client:xla_computation", "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/service:computation_placer_hdr", + "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/translate/mhlo_to_hlo:layout_util", ] + if_libtpu([ ":xla_tpu_backend_registration", - ]), + ]) + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), alwayslink = 1, ) @@ -901,6 +905,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_xla//xla:literal", "@local_xla//xla:literal_util", "@local_xla//xla:statusor", @@ -908,7 +913,7 @@ tf_cc_test( "@local_xla//xla/client:local_client", "@local_xla//xla/client:xla_computation", "@local_xla//xla/service:cpu_plugin", - ], + ] + 
if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), ) tf_cc_test( diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index e40ae462cab0f0..81bbbe1955642d 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -32,6 +32,8 @@ class BatchMatMulOp : public XlaOpKernel { explicit BatchMatMulOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("adj_x", &adj_x_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("adj_y", &adj_y_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("grad_x", &grad_x_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("grad_y", &grad_y_)); if (ctx->HasAttr("Tout")) { DataType output_type; @@ -48,15 +50,18 @@ class BatchMatMulOp : public XlaOpKernel { tsl::tensor_float_32_execution_enabled() ? xla::PrecisionConfig::DEFAULT : xla::PrecisionConfig::HIGHEST; - auto result = xla::BatchDot(MaybeConjugate(ctx->Input(0), adj_x_), adj_x_, - MaybeConjugate(ctx->Input(1), adj_y_), adj_y_, - precision, preferred_element_type_); + auto result = + xla::BatchDot(MaybeConjugate(ctx->Input(0), adj_x_), adj_x_, + MaybeConjugate(ctx->Input(1), adj_y_), adj_y_, precision, + preferred_element_type_, grad_x_, grad_y_); ctx->SetOutput(0, result); } private: bool adj_x_; bool adj_y_; + bool grad_x_; + bool grad_y_; std::optional preferred_element_type_; }; diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index ed0930e6243b7b..5a2a6e781cd65d 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -16,6 +16,7 @@ limitations under the License. // XLA-specific MatMul Op. #include +#include #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -36,9 +37,16 @@ constexpr std::array kMatmulTypes = { class MatMulOp : public XlaOpKernel { public: explicit MatMulOp(OpKernelConstruction* ctx, bool is_sparse = false) - : XlaOpKernel(ctx), is_sparse_(is_sparse) { + : XlaOpKernel(ctx), + is_sparse_(is_sparse), + grad_a_(false), + grad_b_(false) { OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); + if (!is_sparse) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("grad_a", &grad_a_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("grad_b", &grad_b_)); + } if (is_sparse) { OP_REQUIRES_OK(ctx, ctx->GetAttr("Ta", &a_type_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tb", &b_type_)); @@ -95,14 +103,16 @@ class MatMulOp : public XlaOpKernel { tsl::tensor_float_32_execution_enabled() ? 
xla::PrecisionConfig::DEFAULT : xla::PrecisionConfig::HIGHEST; - ctx->SetOutput(0, - xla::BatchDot(a, transpose_a_, b, transpose_b_, precision)); + ctx->SetOutput(0, xla::BatchDot(a, transpose_a_, b, transpose_b_, precision, + std::nullopt, grad_a_, grad_b_)); } private: bool is_sparse_; bool transpose_a_; bool transpose_b_; + bool grad_a_; + bool grad_b_; DataType a_type_; DataType b_type_; }; diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc index 098ecf39792e21..cc0cdfc2036fa7 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc @@ -428,7 +428,7 @@ REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp); REGISTER_XLA_OP(Name("XlaRngBitGenerator") .CompileTimeConstantInput("algorithm") .CompileTimeConstantInput("shape") - .TypeConstraint("dtype", {DT_UINT32, DT_UINT64}), + .TypeConstraint("dtype", {DT_UINT8, DT_UINT32, DT_UINT64}), MlirXlaOpKernel); } // namespace diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 0cfcc0a5a7a78a..c8a4984c356359 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -237,13 +237,13 @@ MlirOptimizationPassState GetPassStateImpl( return MlirOptimizationPassState::FallbackEnabled; case MlirBridgeRolloutPolicy::kDisabledByUser: VLOG(1) << "Skipping MLIR CPU/GPU Bridge, disabled by user."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "tfxla", false, + metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "v2", false, "disabled_by_user"); return MlirOptimizationPassState::Disabled; default: // This case should never be hit. Added here to be consistent with OSS // implementation. - metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "ftxla", false, + metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "v2", false, "invalid_graph"); return MlirOptimizationPassState::Disabled; } diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 480dc474410359..edb2a40f4d332b 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -816,7 +816,7 @@ REGISTER_OP("XlaRngBitGenerator") .Input("shape: Tshape") .Output("output_key: uint64") .Output("output: dtype") - .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64") + .Attr("dtype: {uint8, int8, int32, int64, uint32, uint64} = DT_UINT64") .Attr("Tshape: {int32, int64} = DT_INT32") .SetShapeFn([](shape_inference::InferenceContext* c) { shape_inference::ShapeHandle algorithm; diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 1336d58521404a..01bb69d16ee264 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.h" +#include + #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "xla/client/client_library.h" #include "xla/client/local_client.h" @@ -25,20 +27,61 @@ limitations under the License. 
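For context on the MatMul and BatchMatMul kernel changes above: both kernels now derive the XLA dot precision from the global TensorFloat-32 setting and forward the new grad_a/grad_b (grad_x/grad_y) attributes to xla::BatchDot. A minimal illustrative sketch of that selection logic follows; the helper name is hypothetical, while xla::PrecisionConfig and the tsl TF32 query are the same symbols the kernels in this change rely on.

#include "xla/xla_data.pb.h"                     // xla::PrecisionConfig
#include "tsl/platform/tensor_float_32_utils.h"  // tensor_float_32_execution_enabled()

// Hypothetical helper mirroring the pattern in matmul_op.cc and
// batch_matmul_op.cc: DEFAULT lets the backend use TF32-style math for
// float32 dots, HIGHEST forces full float32 precision once TF32 execution
// has been disabled.
inline xla::PrecisionConfig::Precision ChooseDotPrecision() {
  return tsl::tensor_float_32_execution_enabled()
             ? xla::PrecisionConfig::DEFAULT
             : xla::PrecisionConfig::HIGHEST;
}

Callers can flip the behavior with tsl::enable_tensor_float_32_execution(false), which is exactly what the new tf2xla_test.cc fixture below does before checking that dot and convolution instructions end up with HIGHEST operand precision.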
#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/test.h" +#include "tsl/platform/tensor_float_32_utils.h" namespace tensorflow { namespace { +class ConvertGraphDefToXlaWithTF32Disabled : public ::testing::Test { + public: + ConvertGraphDefToXlaWithTF32Disabled() { + tsl::enable_tensor_float_32_execution(false); + } + ~ConvertGraphDefToXlaWithTF32Disabled() override { + tsl::enable_tensor_float_32_execution(true); + } +}; + AttrValue TypeAttrValue(DataType type) { AttrValue attr_value; SetAttrValue(type, &attr_value); return attr_value; } +AttrValue StringAttrValue(StringPiece str) { + AttrValue attr_value; + SetAttrValue(str, &attr_value); + return attr_value; +} + +AttrValue IntAttrValue(int i) { + AttrValue attr_value; + SetAttrValue(i, &attr_value); + return attr_value; +} + +AttrValue IntVectorAttrValue(const std::vector& ints) { + AttrValue attr_value; + SetAttrValue(ints, &attr_value); + return attr_value; +} + +TensorShapeProto TensorShape(const std::vector& dims) { + TensorShapeProto shape; + for (int i = 0; i < dims.size(); ++i) { + shape.add_dim(); + shape.mutable_dim(i)->set_size(dims[i]); + } + return shape; +} + GraphDef SumGraph() { GraphDef graph_def; NodeDef* x = graph_def.add_node(); @@ -97,6 +140,190 @@ TEST(ConvertGraphDefToXla, Sum) { ConvertGraphDefToXla(graph_def, config, client, &computation))); } +GraphDef EinsumGraph() { + GraphDef graph_def; + NodeDef* x = graph_def.add_node(); + x->set_name("x"); + x->set_op("Placeholder"); + (*x->mutable_attr())["dtype"] = TypeAttrValue(DT_FLOAT); + NodeDef* y = graph_def.add_node(); + y->set_name("y"); + y->set_op("Placeholder"); + (*y->mutable_attr())["dtype"] = TypeAttrValue(DT_FLOAT); + NodeDef* einsum = graph_def.add_node(); + einsum->set_name("einsum"); + einsum->set_op("Einsum"); + einsum->add_input("x"); + einsum->add_input("y"); + (*einsum->mutable_attr())["equation"] = StringAttrValue("ij,jk->ik"); + (*einsum->mutable_attr())["T"] = TypeAttrValue(DT_FLOAT); + (*einsum->mutable_attr())["N"] = IntAttrValue(2); + return graph_def; +} + +tf2xla::Config EinsumConfig() { + tf2xla::Config config; + + tf2xla::Feed* x_feed = config.add_feed(); + x_feed->mutable_id()->set_node_name("x"); + *x_feed->mutable_shape() = TensorShape({2, 2}); + + tf2xla::Feed* y_feed = config.add_feed(); + y_feed->mutable_id()->set_node_name("y"); + *y_feed->mutable_shape() = TensorShape({2, 2}); + + config.add_fetch()->mutable_id()->set_node_name("einsum"); + return config; +} + +TEST(ConvertGraphDefToXla, EinsumIsConvertedToDotWithDefaultPrecision) { + GraphDef graph_def = EinsumGraph(); + tf2xla::Config config = EinsumConfig(); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + xla::XlaComputation computation; + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); + + int num_dots = 0; + const xla::HloModuleProto& module_proto = computation.proto(); + for (const xla::HloComputationProto& computation_proto : + module_proto.computations()) { + for (const xla::HloInstructionProto& instruction_proto : + computation_proto.instructions()) { + if 
(instruction_proto.opcode() == "dot") { + num_dots++; + ASSERT_EQ(instruction_proto.precision_config().operand_precision_size(), + 2); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(0), + xla::PrecisionConfig::DEFAULT); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(1), + xla::PrecisionConfig::DEFAULT); + } + } + } + EXPECT_EQ(num_dots, 1); +} + +TEST_F(ConvertGraphDefToXlaWithTF32Disabled, + EinsumIsConvertedToDotWithHighestPrecision) { + GraphDef graph_def = EinsumGraph(); + tf2xla::Config config = EinsumConfig(); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + xla::XlaComputation computation; + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); + + int num_dots = 0; + const xla::HloModuleProto& module_proto = computation.proto(); + for (const xla::HloComputationProto& computation_proto : + module_proto.computations()) { + for (const xla::HloInstructionProto& instruction_proto : + computation_proto.instructions()) { + if (instruction_proto.opcode() == "dot") { + num_dots++; + ASSERT_EQ(instruction_proto.precision_config().operand_precision_size(), + 2); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(0), + xla::PrecisionConfig::HIGHEST); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(1), + xla::PrecisionConfig::HIGHEST); + } + } + } + EXPECT_EQ(num_dots, 1); +} + +GraphDef Conv2DGraph() { + GraphDef graph_def; + NodeDef* x = graph_def.add_node(); + x->set_name("x"); + x->set_op("Placeholder"); + (*x->mutable_attr())["dtype"] = TypeAttrValue(DT_FLOAT); + NodeDef* y = graph_def.add_node(); + y->set_name("y"); + y->set_op("Placeholder"); + (*y->mutable_attr())["dtype"] = TypeAttrValue(DT_FLOAT); + NodeDef* einsum = graph_def.add_node(); + einsum->set_name("conv2d"); + einsum->set_op("Conv2D"); + einsum->add_input("x"); + einsum->add_input("y"); + (*einsum->mutable_attr())["T"] = TypeAttrValue(DT_FLOAT); + (*einsum->mutable_attr())["padding"] = StringAttrValue("VALID"); + (*einsum->mutable_attr())["strides"] = IntVectorAttrValue({1, 1, 1, 1}); + return graph_def; +} + +tf2xla::Config Conv2DConfig() { + tf2xla::Config config; + tf2xla::Feed* x_feed = config.add_feed(); + x_feed->mutable_id()->set_node_name("x"); + *x_feed->mutable_shape() = TensorShape({1, 1, 2, 2}); + + tf2xla::Feed* y_feed = config.add_feed(); + y_feed->mutable_id()->set_node_name("y"); + *y_feed->mutable_shape() = TensorShape({1, 1, 2, 2}); + config.add_fetch()->mutable_id()->set_node_name("conv2d"); + return config; +} + +TEST(ConvertGraphDefToXla, Conv2DIsConvertedToConvolutionWithDefaultPrecision) { + GraphDef graph_def = Conv2DGraph(); + tf2xla::Config config = Conv2DConfig(); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + xla::XlaComputation computation; + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); + + int num_convolutions = 0; + const xla::HloModuleProto& module_proto = computation.proto(); + for (const xla::HloComputationProto& computation_proto : + module_proto.computations()) { + for (const xla::HloInstructionProto& instruction_proto : + computation_proto.instructions()) { + if (instruction_proto.opcode() == "convolution") { + num_convolutions++; + ASSERT_EQ(instruction_proto.precision_config().operand_precision_size(), + 2); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(0), + xla::PrecisionConfig::DEFAULT); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(1), + 
xla::PrecisionConfig::DEFAULT); + } + } + } + EXPECT_EQ(num_convolutions, 1); +} + +TEST_F(ConvertGraphDefToXlaWithTF32Disabled, + Conv2DIsConvertedToConvolutionWithHighestPrecision) { + GraphDef graph_def = Conv2DGraph(); + tf2xla::Config config = Conv2DConfig(); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + xla::XlaComputation computation; + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); + + int num_convolutions = 0; + const xla::HloModuleProto& module_proto = computation.proto(); + for (const xla::HloComputationProto& computation_proto : + module_proto.computations()) { + for (const xla::HloInstructionProto& instruction_proto : + computation_proto.instructions()) { + if (instruction_proto.opcode() == "convolution") { + num_convolutions++; + ASSERT_EQ(instruction_proto.precision_config().operand_precision_size(), + 2); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(0), + xla::PrecisionConfig::HIGHEST); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(1), + xla::PrecisionConfig::HIGHEST); + } + } + } + EXPECT_EQ(num_convolutions, 1); +} + TEST(ConvertGraphDefToXla, SumWithUnusedArgument) { GraphDef graph_def = SumGraph(); tf2xla::Config config = SumConfig(); diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 917d775c80011d..dc4109f52f96b6 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -228,4 +228,34 @@ int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const { return LookupNameIndex(name, result_names_); } +const char* XlaCompiledCpuFunction::GetArgName(const int index) const { + assert(arg_names_ != nullptr); + if (index < 0 || index >= num_args_) { + std::cerr << "XlaCompiledCpuFunction::GetArgName: index '" << index + << "' out of range [0, " << num_args_ << "].\n"; + return nullptr; + } + return arg_names_[index]; +} + +const char* XlaCompiledCpuFunction::GetVariableName(int index) const { + assert(variable_names_ != nullptr); + if (index < 0 || index >= num_variables_) { + std::cerr << "XlaCompiledCpuFunction::GetVariableName: index '" << index + << "' out of range [0, " << num_variables_ << ").\n"; + return nullptr; + } + return variable_names_[index]; +} + +const char* XlaCompiledCpuFunction::GetResultName(int index) const { + assert(result_names_ != nullptr); + if (index < 0 || index >= num_results_) { + std::cerr << "XlaCompiledCpuFunction::GetResultName: index '" << index + << "' out of range [0, " << num_results_ << ").\n"; + return nullptr; + } + return result_names_[index]; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 871ddb32d2652a..d03f06e14f5bce 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -294,6 +294,18 @@ class XlaCompiledCpuFunction { // Recommended usage is to capture this in a variable for re-use. int LookupResultIndex(const string& name) const; + // Returns the name of the argument at `index`. + // Returns nullptr if `HasNameIndices() == false` or `index` is out of range. + const char* GetArgName(int index) const; + + // Returns the name of the variable at `index`. + // Returns nullptr if `HasNameIndices() == false` or `index` is out of range. 
+ const char* GetVariableName(int index) const; + + // Returns the name of the result at `index`. + // Returns nullptr if `HasNameIndices() == false` or `index` is out of range. + const char* GetResultName(int index) const; + // Returns the shape of the args and results. May return nullptr if the // program shape isn't available. const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index aa2c761ccb6e26..bb8b29de5b9acf 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include +#include #include #include #include @@ -27,9 +28,11 @@ limitations under the License. #include #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" @@ -52,6 +55,7 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" #include "xla/protobuf_util.h" +#include "xla/service/hlo.pb.h" #include "xla/shape_util.h" #include "xla/util.h" #include "tensorflow/core/common_runtime/device.h" @@ -72,6 +76,7 @@ limitations under the License. #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/util/dump_graph.h" +#include "tsl/platform/tensor_float_32_utils.h" namespace tensorflow { namespace { @@ -1435,6 +1440,38 @@ class DummyStackTrace : public AbstractStackTrace { StackFrame({"dummy_file_name", 10, "dummy_function_name"})}; }; +namespace { + +// Add precisions configs to the HLO module to avoid TensorFloat32 computations +// in XLA. +// +// Some operations, such as Einsum are converted through MlirXlaOpKernel, which +// doesn't set the precisions, so we set them all here. +// +// TODO(tdanyluk): We may want to restrict this logic to only set the operand +// precision for F32 operands. (Historically, it was set without regard to +// operand type in other parts of TF2XLA.) 
+void IncreasePrecisionsToAvoidTF32(xla::HloModuleProto& module) { + static constexpr std::array kOpsPossiblyUsingTF32 = { + "dot", "convolution"}; + + xla::PrecisionConfig precision_config; + precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST); + precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST); + + for (xla::HloComputationProto& computation : *module.mutable_computations()) { + for (xla::HloInstructionProto& instruction : + *computation.mutable_instructions()) { + if (absl::c_find(kOpsPossiblyUsingTF32, instruction.opcode()) != + kOpsPossiblyUsingTF32.end()) { + *instruction.mutable_precision_config() = precision_config; + } + } + } +} + +} // namespace + Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, @@ -1571,6 +1608,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, *result->host_compute_metadata.add_host_to_device() = recv; } + if (!tsl::tensor_float_32_execution_enabled()) { + IncreasePrecisionsToAvoidTF32(*result->computation->mutable_proto()); + } + VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; VLOG(2) << "XLA output shape: " diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index ff6b72ca562976..ad65c1708794fd 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -213,6 +213,26 @@ TEST(XlaJitCompiledCpuFunction, Sum) { EXPECT_EQ(0, function.num_variables()); EXPECT_EQ(function.LookupVariableIndex("x"), -1); + // Expect that name and index lookups match. + for (int i = 0; i < function.num_args(); ++i) { + const char* name = function.GetArgName(i); + ASSERT_NE(name, nullptr); + const int roundtrip_i = function.LookupArgIndex(name); + EXPECT_EQ(roundtrip_i, i) << " name= " << name; + } + for (int i = 0; i < function.num_results(); ++i) { + const char* name = function.GetResultName(i); + ASSERT_NE(name, nullptr); + const int roundtrip_i = function.LookupResultIndex(name); + EXPECT_EQ(roundtrip_i, i) << " name= " << name; + } + // Expect correct handling of invalid indices. + EXPECT_EQ(function.GetArgName(-1), nullptr); + EXPECT_EQ(function.GetArgName(function.num_args()), nullptr); + EXPECT_EQ(function.GetResultName(-1), nullptr); + EXPECT_EQ(function.GetResultName(function.num_results()), nullptr); + EXPECT_EQ(function.GetVariableName(0), nullptr); + // Check program shape. using xla::ShapeUtil; const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); @@ -263,6 +283,11 @@ TEST(XlaJitCompiledCpuFunction, SumVariable) { EXPECT_EQ(1, function.num_variables()); EXPECT_EQ(function.LookupVariableIndex("myvar"), 1); + const char* name = function.GetVariableName(0); + EXPECT_EQ(std::string(name), "myvar"); + EXPECT_EQ(function.GetVariableName(1), nullptr); + EXPECT_EQ(function.GetVariableName(-1), nullptr); + // Check program shape. 
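The GetArgName, GetVariableName, and GetResultName accessors exercised above are the inverse of the existing Lookup*Index methods: per the header comments, they return nullptr for out-of-range indices or when the compiled function carries no name table. A minimal usage sketch, assuming a compiled XlaCompiledCpuFunction named fn and a hypothetical PrintArgNames helper:

#include <cstdio>

#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"

// Hypothetical helper: print every argument name and confirm that looking the
// name back up yields the same index, mirroring the roundtrip check in the
// test above.
void PrintArgNames(const tensorflow::XlaCompiledCpuFunction& fn) {
  for (int i = 0; i < fn.num_args(); ++i) {
    if (const char* name = fn.GetArgName(i)) {  // nullptr only if out of range
      std::printf("arg %d -> %s (roundtrip index %d)\n", i, name,
                  fn.LookupArgIndex(name));
    }
  }
}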
using xla::ShapeUtil; const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD deleted file mode 100644 index 58cf8a80e3e751..00000000000000 --- a/tensorflow/compiler/xrt/BUILD +++ /dev/null @@ -1,172 +0,0 @@ -# Description: Operations defined for XRT - -# Placeholder: load py_proto_library -load("//tensorflow:strict.default.bzl", "py_strict_library") -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load( - "//tensorflow:tensorflow.bzl", - "tf_gen_op_wrapper_py", -) -load("//tensorflow:tensorflow.default.bzl", "tf_custom_op_py_strict_library", "tf_gen_op_libs") -load( - "//tensorflow/core/platform:build_config.bzl", - "tf_proto_library", -) -load( - "@local_config_cuda//cuda:build_defs.bzl", - "if_cuda", -) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//learning/brain:__subpackages__", - "//tensorflow/compiler/xrt:__subpackages__", - ], - licenses = ["notice"], -) - -tf_proto_library( - name = "xrt_proto", - srcs = ["xrt.proto"], - cc_api_version = 2, - protodeps = [ - "//tensorflow/compiler/tf2xla:host_compute_metadata_proto", - "@local_xla//xla:xla_data_proto", - "@local_xla//xla:xla_proto", - "@local_xla//xla/service:hlo_proto", - ], - visibility = ["//visibility:public"], -) - -cc_library( - name = "xrt_tpu_utils", - srcs = [ - "xrt_tpu_device.cc", - ], - hdrs = [ - "xrt_tpu_device.h", - ], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/compiler/jit:xla_device", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core/tpu:tpu_configuration", - "@local_xla//xla/client:local_client", - "@local_xla//xla/stream_executor/tpu:tpu_node_context", - ], -) - -cc_library( - name = "xrt_utils", - srcs = [ - "xrt_compilation_cache.cc", - "xrt_device.cc", - "xrt_memory_manager.cc", - "xrt_metrics.cc", - "xrt_state.cc", - "xrt_util.cc", - ], - hdrs = [ - "xrt_compilation_cache.h", - "xrt_device.h", - "xrt_memory_manager.h", - "xrt_metrics.h", - "xrt_refptr.h", - "xrt_state.h", - "xrt_util.h", - ], - copts = if_cuda(["-DGOOGLE_CUDA=1"]), - visibility = ["//visibility:public"], - deps = [ - ":xrt_proto_cc", - "//tensorflow/compiler/jit:xla_device", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/common_runtime/gpu:gpu_runtime", - "//tensorflow/core/platform:regexp", - "//tensorflow/core/profiler/lib:traceme", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - "@local_xla//xla:debug_options_flags", - "@local_xla//xla:literal", - "@local_xla//xla:shape_util", - "@local_xla//xla:status_macros", - "@local_xla//xla:statusor", - "@local_xla//xla:types", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla:xla_proto_cc", - "@local_xla//xla/client:local_client", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:backend", - "@local_xla//xla/service:executable", - "@local_xla//xla/service:shaped_buffer", - "@local_xla//xla/stream_executor", - "@local_xla//xla/stream_executor:device_memory_allocator", - "@local_xla//xla/stream_executor/integrations:tf_allocator_adapter", - ], -) - -tf_gen_op_libs( - op_lib_names = [ - "xrt_compile_ops", - "xrt_state_ops", - "xrt_execute_op", - ], - deps = [ - 
"//tensorflow/compiler/jit:common", - "//tensorflow/core:lib", - ], -) - -tf_gen_op_wrapper_py( - name = "xrt_ops_wrapper_py", - out = "xrt_ops.py", - extra_py_deps = [ - "//tensorflow/python:pywrap_tfe", - "//tensorflow/python/util:dispatch", - "//tensorflow/python/util:deprecation", - "//tensorflow/python/util:tf_export", - ], - py_lib_rule = py_strict_library, - deps = [ - ":xrt_compile_ops_op_lib", - ":xrt_execute_op_op_lib", - ":xrt_state_ops_op_lib", - ], -) - -tf_custom_op_py_strict_library( - name = "xrt_ops", - kernels = ["//tensorflow/compiler/xrt/kernels:xrt_ops"], - visibility = ["//visibility:public"], - deps = [ - ":xrt_ops_wrapper_py", - ], -) - -cc_library( - name = "xrt_server", - visibility = ["//visibility:public"], - deps = [ - ":xrt_compile_ops_op_lib", - ":xrt_execute_op_op_lib", - ":xrt_state_ops_op_lib", - "//tensorflow/compiler/xrt/kernels:xrt_ops", - ], -) - -# copybara:uncomment_begin(google-only) -# py_proto_library( -# name = "xrt_proto_py_pb2", -# api_version = 2, -# visibility = ["//visibility:public"], -# deps = [":xrt_proto"], -# ) -# copybara:uncomment_end diff --git a/tensorflow/compiler/xrt/cc/BUILD b/tensorflow/compiler/xrt/cc/BUILD deleted file mode 100644 index 9783aeaafa0815..00000000000000 --- a/tensorflow/compiler/xrt/cc/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tensorflow:tensorflow.default.bzl", "tf_gen_op_wrappers_cc") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -tf_gen_op_wrappers_cc( - name = "xrt_ops", - op_lib_names = [ - "xrt_compile_ops", - "xrt_state_ops", - "xrt_execute_op", - ], - pkg = "//tensorflow/compiler/xrt", -) diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD deleted file mode 100644 index e4c4075a392c3a..00000000000000 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ /dev/null @@ -1,146 +0,0 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//learning/brain:__subpackages__", - "//tensorflow/compiler/xrt:__subpackages__", - ], - licenses = ["notice"], -) - -package_group( - name = "friends", - includes = [ - "//tensorflow/compiler/tf2xla:friends", - ], -) - -WITH_TPU_SUPPORT = "//tensorflow:with_tpu_support" - -DEFAULT = "//conditions:default" - -cc_library( - name = "xrt_state_ops", - hdrs = ["xrt_state_ops.h"], - visibility = [":friends"], - deps = [ - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xrt:xrt_proto_cc", - "//tensorflow/compiler/xrt:xrt_utils", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "@local_xla//xla:literal", - "@local_xla//xla:shape_util", - "@local_xla//xla:status_macros", - "@local_xla//xla:statusor", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:local_client", - "@local_xla//xla/service:computation_placer", - ], - alwayslink = 1, -) - -cc_library( - name = "xrt_tpu_ops", - srcs = [ - "tpu_compile_ops.cc", - "tpu_execute_op.cc", - "tpu_state_op.cc", - ], - visibility = [":friends"], - deps = [ - ":xrt_state_ops", - "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", - 
"//tensorflow/compiler/xrt:xrt_proto_cc", - "//tensorflow/compiler/xrt:xrt_tpu_utils", - "//tensorflow/compiler/xrt:xrt_utils", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/core/tpu:tpu_configuration", - "//tensorflow/core/tpu:tpu_defs", - "//tensorflow/core/tpu:tpu_execute", - "//tensorflow/core/tpu/kernels:tpu_compilation_cache_entry", - "//tensorflow/core/tpu/kernels:tpu_compilation_cache_interface", - "//tensorflow/core/tpu/kernels:tpu_compilation_cache_key", - "//tensorflow/core/tpu/kernels:tpu_compilation_cache_lookup", - "//tensorflow/core/tpu/kernels:tpu_compile_op_common", - "//tensorflow/core/tpu/kernels:tpu_compile_op_hdrs", - "//tensorflow/core/tpu/kernels:tpu_mesh_state_interface", - "//tensorflow/core/tpu/kernels:tpu_op_consts", - "//tensorflow/core/tpu/kernels:tpu_op_util", - "//tensorflow/core/tpu/kernels:tpu_program_group", - "//tensorflow/core/tpu/kernels:tpu_program_group_interface", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@local_xla//xla:debug_options_flags", - "@local_xla//xla:shape_util", - "@local_xla//xla:status_macros", - "@local_xla//xla:statusor", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:client_library", - "@local_xla//xla/client:compile_only_client", - "@local_xla//xla/client:local_client", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:compiler", - "@local_xla//xla/service:computation_placer", - "@local_xla//xla/service:dump", - "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/stream_executor", - "@local_xla//xla/stream_executor/tpu:tpu_api", - ], - alwayslink = 1, -) - -cc_library( - name = "xrt_ops", - srcs = [ - "xrt_compile_ops.cc", - "xrt_execute_op.cc", - "xrt_state_ops.cc", - ], - visibility = [":friends"], - deps = select({ - WITH_TPU_SUPPORT: [":xrt_tpu_ops"], - DEFAULT: [], - }) + [ - ":xrt_state_ops", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xrt:xrt_compile_ops_op_lib", - "//tensorflow/compiler/xrt:xrt_execute_op_op_lib", - "//tensorflow/compiler/xrt:xrt_proto_cc", - "//tensorflow/compiler/xrt:xrt_state_ops_op_lib", - "//tensorflow/compiler/xrt:xrt_utils", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", - "@local_xla//xla:literal_util", - "@local_xla//xla:shape_util", - "@local_xla//xla:status_macros", - "@local_xla//xla:statusor", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:client_library", - "@local_xla//xla/client:local_client", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:compiler", - "@local_xla//xla/service:computation_placer", - "@local_xla//xla/service/gpu:gpu_executable_run_options", - "@local_xla//xla/stream_executor", - ], - alwayslink = 1, -) diff --git a/tensorflow/compiler/xrt/kernels/tpu_compile_ops.cc b/tensorflow/compiler/xrt/kernels/tpu_compile_ops.cc deleted file mode 100644 index 8c3d3aa7300208..00000000000000 --- a/tensorflow/compiler/xrt/kernels/tpu_compile_ops.cc +++ /dev/null @@ -1,277 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Classes for compiling XLA computations and managing handles that refer to -// them. - -#include -#include - -#include "absl/cleanup/cleanup.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "xla/client/client_library.h" -#include "xla/client/compile_only_client.h" -#include "xla/client/xla_computation.h" -#include "xla/debug_options_flags.h" -#include "xla/service/compiler.h" -#include "xla/service/dump.h" -#include "xla/service/hlo.pb.h" -#include "xla/status_macros.h" -#include "xla/statusor.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/tpu/tpu_api.h" -#include "xla/xla_data.pb.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/compiler/xrt/xrt_metrics.h" -#include "tensorflow/compiler/xrt/xrt_util.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/monitoring/timed.h" -#include "tensorflow/core/lib/strings/proto_serialization.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/casts.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" -#include "tensorflow/core/tpu/kernels/tpu_compile_op.h" -#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h" -#include "tensorflow/core/tpu/kernels/tpu_op_consts.h" -#include "tensorflow/core/tpu/kernels/tpu_op_util.h" -#include "tensorflow/core/tpu/kernels/tpu_program_group.h" -#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" -#include "tensorflow/core/tpu/tpu_configuration.h" -#include "tensorflow/core/tpu/tpu_defs.h" - -namespace tensorflow { - -class XRTCompileOp : public OpKernel { - public: - explicit XRTCompileOp(OpKernelConstruction* ctx); - ~XRTCompileOp() override; - XRTCompileOp(const XRTCompileOp&) = delete; - XRTCompileOp& operator=(const XRTCompileOp&) = delete; - - void Compute(OpKernelContext* ctx) override; - - private: - Status Compile(const XLA_TpuMeshState* xla_mesh_state, - const xrt::XLAComputation& computation_proto, - tensorflow::tpu::TpuProgramGroupInterface* tpu_program_group); -}; - -XRTCompileOp::XRTCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - -Status XRTCompileOp::Compile( - const XLA_TpuMeshState* xla_mesh_state, - const xrt::XLAComputation& computation_proto, - tensorflow::tpu::TpuProgramGroupInterface* tpu_program_group) { - 
return tensorflow::tpu::TpuProgramGroup::CompileAndBuild( - computation_proto, xla_mesh_state, tpu_program_group); -} - -tpu::TpuCompilationCacheKey CompilationCacheKey( - const xrt::XLAComputation& computation, - tensorflow::tpu::TpuMeshStateInterface* mesh_state, int num_replicas, - int num_cores_per_replica) { - string computation_serialized; - CHECK(SerializeToStringDeterministic(computation, &computation_serialized)); - tpu::TPUCompileMetadataProto metadata; - metadata.set_num_replicas(num_replicas); - metadata.set_num_cores_per_replica(num_cores_per_replica); - const tpu::TpuCompilationCacheKey key = CreateCompilationCacheKey( - "compile", 0, tensorflow::Fingerprint64(computation_serialized), {}, {}, - metadata, *mesh_state); - return key; -} - -void ExitCountdown(Env* env, std::shared_ptr> done) { - const int kSleepSeconds = 300; - LOG(INFO) << "TpuCompileOp was cancelled. Sleeping for " << kSleepSeconds - << " seconds to give time for TPUCompileOp to finished."; - env->SleepForMicroseconds(kSleepSeconds * 1000000); - if (done->load()) { - // If the TpuCompileOp has finished, then terminate peacefully. - return; - } - - LOG(ERROR) << "Aborting process due to cancelled TpuCompileOp. This " - << "termination is to ensure a consistent state."; - std::exit(42); -} - -void XRTCompileOp::Compute(OpKernelContext* ctx) { - VLOG(1) << "XRTCompileOp::Compute"; - auto timed = monitoring::MakeTimed(xrt_metrics::GetCompileCell()); - - std::shared_ptr> done(new std::atomic(false)); - CancellationToken token = - ctx->cancellation_manager()->get_cancellation_token(); - const bool already_cancelled = - !ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() { - if (stream_executor::tpu::OpsApiFn() - ->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) { - return; - } - - // Sleep and exit in another thread so the cancellation manager can - // continue running callbacks. - Env* env = ctx->env(); - env->SchedClosure([env, done]() { ExitCountdown(env, done); }); - }); - - // If the RPC was cancelled before we registered the cancellation callback, - // don't compile the TPU program. - OP_REQUIRES(ctx, !already_cancelled, - absl::CancelledError("RPC cancelled, not compiling TPU program")); - - // We only want to abort the process if a cancellation actually occurs during - // compilation; we must deregister the callback in the success case. It - // doesn't hurt to also deregister the callback in the failure case; the - // CancellationManager ensures that already-registered callbacks will be run - // once cancellation has started. 
- auto cancellation_cleanup = absl::MakeCleanup([ctx, token, done] { - ctx->cancellation_manager()->DeregisterCallback(token); - done->store(true); - }); - - VLOG(1) << "Retrieving pod state"; - // Retrieve the topology from the resource manager - ResourceMgr* rm = GetTPUConfigResourceMgr(); - tensorflow::tpu::TpuMeshStateInterface* mesh_state; - OP_REQUIRES_OK(ctx, - rm->Lookup(rm->default_container(), - tensorflow::tpu::kTpuMeshStateInterfaceResourceName, - &mesh_state)); - core::ScopedUnref mesh_state_unref(mesh_state); - - const Tensor& computation_input = ctx->input(0); - OP_REQUIRES( - ctx, TensorShapeUtils::IsScalar(computation_input.shape()), - absl::InternalError("computation input should be a string scalar")); - - xrt::XLAComputation computation_proto; - OP_REQUIRES( - ctx, - computation_proto.ParseFromString(computation_input.scalar()()), - absl::InvalidArgumentError( - "Unable to parse computation input to XLAComputation")); - - const xrt::XLAComputationConfig& config = computation_proto.config(); - int num_replicas = config.num_replicas() ? config.num_replicas() : 1; - CHECK_GT(num_replicas, 0); - int num_cores_per_replica = - config.num_cores_per_replica() ? config.num_cores_per_replica() : 1; - - const tpu::TpuCompilationCacheKey key = CompilationCacheKey( - computation_proto, mesh_state, num_replicas, num_cores_per_replica); - - // Process-wide cache of Tpu executables. - tpu::TpuCompilationCacheInterface* cache; - OP_REQUIRES_OK(ctx, rm->Lookup( - rm->default_container(), - tpu::kCompilationCacheResourceName, &cache)); - core::ScopedUnref cache_unref(cache); - - int64_t uid; - std::vector proto_key; - std::vector shard_key; - std::vector may_modify_variables; - absl::Span hlo_metadata; - OP_REQUIRES_OK( - ctx, cache->CompileIfKeyAbsent( - key, /*session_metadata=*/nullptr, - /*per_step_ref_holder=*/nullptr, &uid, &proto_key, &shard_key, - &may_modify_variables, &hlo_metadata, - [&](tpu::TpuProgramGroupInterface* tpu_program_group) { - VLOG(1) << "Compiling TPU executable"; - return Compile(mesh_state->data(), computation_proto, - tpu_program_group); - })); - - Tensor output(DT_INT64, TensorShape({})); - output.scalar()() = uid; - ctx->set_output(0, output); - - Tensor program_shape_output(DT_STRING, TensorShape({num_cores_per_replica})); - for (int64_t i = 0; i < num_cores_per_replica; ++i) { - xla::ProgramShapeProto program_shape = - hlo_metadata[i]->hlo_module().host_program_shape(); - program_shape_output.vec()(i) = program_shape.SerializeAsString(); - } - ctx->set_output(1, program_shape_output); -} - -XRTCompileOp::~XRTCompileOp() = default; - -class XRTReleaseCompilationRefOp : public OpKernel { - public: - explicit XRTReleaseCompilationRefOp(OpKernelConstruction* ctx); - ~XRTReleaseCompilationRefOp() override; - XRTReleaseCompilationRefOp(const XRTReleaseCompilationRefOp&) = delete; - XRTReleaseCompilationRefOp& operator=(const XRTReleaseCompilationRefOp&) = - delete; - - void Compute(OpKernelContext* ctx) override; -}; - -XRTReleaseCompilationRefOp::XRTReleaseCompilationRefOp( - OpKernelConstruction* ctx) - : OpKernel(ctx) {} - -XRTReleaseCompilationRefOp::~XRTReleaseCompilationRefOp() = default; - -void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) { - VLOG(1) << "XRTReleaseCompilationRefOp::Compute"; - auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseCompilationCell()); - ResourceMgr* rm = GetTPUConfigResourceMgr(); - OP_REQUIRES(ctx, rm != nullptr, absl::InternalError("No resource manager.")); - - // Process-wide cache of Tpu executables. 
- tpu::TpuCompilationCacheInterface* cache; - OP_REQUIRES_OK(ctx, rm->Lookup( - rm->default_container(), - tpu::kCompilationCacheResourceName, &cache)); - core::ScopedUnref cache_unref(cache); - - const Tensor& keys_tensor = ctx->input(0); - auto flat_keys = keys_tensor.flat(); - for (int64_t i = 0; i < flat_keys.size(); ++i) { - int64_t key = flat_keys(i); - OP_REQUIRES_OK(ctx, cache->Release(key)); - VLOG(2) << "Released computation handle " << key; - } -} - -REGISTER_KERNEL_BUILDER(Name("XRTCompile") - .Device(DEVICE_TPU_NODE) - .HostMemory("computation") - .HostMemory("handle"), - XRTCompileOp); - -REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle") - .Device(DEVICE_TPU_NODE) - .HostMemory("handle"), - XRTReleaseCompilationRefOp); - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/kernels/tpu_execute_op.cc b/tensorflow/compiler/xrt/kernels/tpu_execute_op.cc deleted file mode 100644 index 1073a103c8369a..00000000000000 --- a/tensorflow/compiler/xrt/kernels/tpu_execute_op.cc +++ /dev/null @@ -1,490 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/jit/xla_device.h" -#include "xla/hlo/ir/hlo_input_output_alias_config.h" -#include "xla/service/computation_placer.h" -#include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/statusor.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/stream_executor_internal.h" -#include "xla/xla_data.pb.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/compiler/xrt/xrt_memory_manager.h" -#include "tensorflow/compiler/xrt/xrt_metrics.h" -#include "tensorflow/compiler/xrt/xrt_state.h" -#include "tensorflow/compiler/xrt/xrt_util.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/monitoring/timed.h" -#include "tensorflow/core/platform/casts.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/lib/traceme.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" -#include "tensorflow/core/tpu/kernels/tpu_op_consts.h" -#include "tensorflow/core/tpu/kernels/tpu_program_group.h" -#include "tensorflow/core/tpu/tpu_configuration.h" -#include "tensorflow/core/tpu/tpu_defs.h" -#include "tensorflow/core/tpu/tpu_execute.h" - -namespace tensorflow { -namespace { - -using tensorflow::tpu::CompilationCacheEntryRef; -using tensorflow::tpu::TpuCompilationCacheEntry; -using tensorflow::tpu::TpuCompilationCacheLookup; -using GetBufferFunction = - std::function>()>; - -// Looks up the input `key` in the compilation cache. 
-Status GetComputationCacheEntry( - ResourceMgr* rm, int64_t key, int core_index_in_replica, - std::unique_ptr* entry) { - profiler::TraceMe trace_me("XRTExecuteOp::LookupProto", /*level=*/2); - TpuCompilationCacheLookup* proto_lookup; - TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(), - tpu::kCompiledProtoCacheResourceName, - &proto_lookup)); - core::ScopedUnref lookup_unref(proto_lookup); - TF_RETURN_IF_ERROR(proto_lookup->Lookup(key, core_index_in_replica, entry)); - return OkStatus(); -} - -std::vector GetDynamicInputInfo( - const TPUExecutableInfoProto& executable_proto) { - std::vector input_is_dynamic; - input_is_dynamic.reserve(executable_proto.input_shapes().size()); - for (int64_t i = 0; i < executable_proto.input_shapes().size(); ++i) { - input_is_dynamic.push_back( - !xla::Shape(executable_proto.input_shapes(i)).is_static()); - } - return input_is_dynamic; -} - -xla::StatusOr>> GetChainedOpInputs( - const xrt::XRTChainedExecuteOp& op, - absl::Span> op_inputs, - const TPUExecutableInfoProto& executable_proto) { - if (op.inputs_size() != executable_proto.input_shapes_size()) { - return errors::InvalidArgument( - "Number of inputs does not match executable proto input shapes: ", - op.inputs_size(), " vs. ", executable_proto.input_shapes_size()); - } - - std::vector> input_tuples; - input_tuples.reserve(op.inputs_size()); - for (int i = 0; i < op.inputs_size(); ++i) { - auto& input = op.inputs(i); - const RefPtr& tuple = op_inputs[i]; - // Thanks to the greatness of proto3, there is no way to query for - // explicitly set fields, so the default for output_index (zero) means no - // sub-index. As consequence, the real index is output_index - 1. - if (input.output_index() == 0) { - input_tuples.push_back(tuple); - } else { - XRTTupleAllocation* sub_tuple; - TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( - tuple.get(), {input.output_index() - 1}, &sub_tuple, - /*alias_parent_allocation=*/true)); - input_tuples.emplace_back(sub_tuple); - } - if (!InputShapeMatches(xla::Shape(executable_proto.input_shapes(i)), - input_tuples.back()->on_host_shape())) { - return errors::InvalidArgument( - "Run-time shape mismatch for XRTExecute argument[", i, "] (", - op.computation_handle(), "). 
Expected ", - executable_proto.input_shapes(i).DebugString(), "; got ", - tuple->on_host_shape().DebugString()); - } - } - return std::move(input_tuples); -} - -xla::StatusOr GetExecutableAliasConfig( - const tpu::TpuProgramGroup* tpu_program_group, xla::Backend* const backend, - int core_index) { - const TPUExecutableInfoProto& executable = - tpu_program_group->executable_info(core_index); - return xla::HloInputOutputAliasConfig::CreateFromProto( - backend->transfer_manager()->HostShapeToDeviceShape( - xla::Shape(executable.output_shape())), - tpu_program_group->hlo_metadata(core_index) - ->hlo_module() - .input_output_alias()); -} - -xla::StatusOr> AllocateOutputTuple( - tpu::TpuNodeContext* node_context, se::Stream* stream, - absl::Span> input_tuples, - const xla::HloInputOutputAliasConfig& input_output_alias, - xla::ScopedShapedBuffer output_scoped_buffer, int device_ordinal) { - auto output_shaped_buffer = output_scoped_buffer.release(); - - xla::Shape output_device_shape = output_shaped_buffer.on_device_shape(); - if (!output_device_shape.is_static()) { - TF_RETURN_IF_ERROR( - node_context->backend()->transfer_manager()->ReadDynamicShapes( - stream, &output_shaped_buffer, &output_device_shape)); - } - - XRTTupleAllocation* output_tuple; - xla::Shape output_host_shape = - xla::ShapeUtil::DeviceShapeToHostShape(output_device_shape); - - TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( - output_shaped_buffer, output_host_shape, output_device_shape, - node_context->backend(), device_ordinal, &output_tuple, - node_context->backend()->memory_allocator())); - RefPtr output_tuple_ptr(output_tuple); - - // If the input tuples had to release some buffers in order to provide the - // proper temporary ownership transfer, we patch the holes here by alising the - // buffers from the result tuple. The device address we patch back here, will - // essentially be the same one we carved out in the DoWork() function. 
- TF_RETURN_IF_ERROR( - RebuildOutputAliases(output_tuple_ptr, input_tuples, input_output_alias)); - - return std::move(output_tuple_ptr); -} - -Status AllocateOutputTensors( - OpKernelContext* context, XRTMemoryManager* memory_manager, - tpu::TpuNodeContext* node_context, se::Stream* stream, - const xrt::XRTExecutionConfig& config_proto, - const TPUExecutableInfoProto& executable_proto, - absl::Span> input_tuples, - const xla::HloInputOutputAliasConfig& input_output_alias, - xla::ScopedShapedBuffer output_scoped_buffer, int device_ordinal) { - TF_ASSIGN_OR_RETURN( - RefPtr output_tuple, - AllocateOutputTuple(node_context, stream, input_tuples, - input_output_alias, std::move(output_scoped_buffer), - device_ordinal)); - return CreateExecuteOutput(context, memory_manager, std::move(output_tuple), - config_proto.return_exploded_tuple()); -} - -xla::StatusOr RunExecutable( - OpKernelContext* context, tpu::TpuNodeContext* node_context, - const TPUExecutableInfoProto& executable, - std::vector arguments, const string& execution_id, - const uint32 rng_seed, const tpu::TpuProgramGroup* tpu_program_group, - xla::Backend* const backend, se::Stream* stream, int core_index, - int device_ordinal, string rendezvous_key_base) { - profiler::TraceMe trace_me("RunExecutable", /*level=*/2); - - // se::StreamExecutor* executor = node->stream_executor(); - - std::unique_ptr device_assignment; - if (executable.has_device_assignment()) { - TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize( - executable.device_assignment())); - } - // Ideally this should be the host-to-device stream from XlaDeviceContext. - // The particular anti-dependency this is avoiding (why we need a separate - // transfer stream) is between the executable writing tuple tables and - // TPUExecute()'s deregister_stream; if they come from the same stream pool - // antidependencies will occur. XlaBackend has a different pool of streams - // to the stream->GetOrCreateSubStream() that TPUExecute() uses, so these - // will never refer to the same stream. 
- TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr, - backend->BorrowStream(device_ordinal)); - const TPUHostTransferInfoProto& host_transfer_info = - tpu_program_group->host_transfer_info(core_index); - TF_ASSIGN_OR_RETURN( - xla::ExecutionOutput output, - TPUExecute(executable, host_transfer_info, - *tpu_program_group->hlo_metadata(core_index), - std::move(arguments), rendezvous_key_base, rng_seed, - node_context, device_assignment.get(), - context->cancellation_manager(), context, stream, - transfer_stream_ptr.get(), - tpu_program_group->tpu_program(core_index))); - - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - - return output; -} - -xla::StatusOr ExecuteTPUProgram( - OpKernelContext* context, tpu::TpuNodeContext* node_context, - XRTMemoryManager* memory_manager, const TPUExecutableInfoProto& executable, - const GetBufferFunction& get_buffers_fn, const string& execution_id, - const uint32 rng_seed, const tpu::TpuProgramGroup* tpu_program_group, - xla::Backend* const backend, se::Stream* stream, int core_index, - int device_ordinal, string rendezvous_key_base) { - auto runfn = [&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(auto arguments, get_buffers_fn()); - return RunExecutable(context, node_context, executable, - std::move(arguments), execution_id, rng_seed, - tpu_program_group, backend, stream, core_index, - device_ordinal, rendezvous_key_base); - }; - return memory_manager->Run( - runfn, backend, device_ordinal, /*requested_free_size=*/0, - backend->memory_allocator()); -} - -// XRTExecuteOp - -class XRTExecuteOp : public AsyncOpKernel { - public: - explicit XRTExecuteOp(OpKernelConstruction* context); - - void ComputeAsync(OpKernelContext* context, DoneCallback done) override; - - private: - Status DoWork(OpKernelContext* context); -}; - -XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context) - : AsyncOpKernel(context, /* is_deferred = */ true) {} - -void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) { - // Schedule onto the default queue, for unbounded concurrency. See b/73520706 - OP_REQUIRES_OK_ASYNC(context, DoWork(context), done); - done(); -} - -Status XRTExecuteOp::DoWork(OpKernelContext* context) { - VLOG(1) << "XRTExecuteOp::Compute"; - - const XlaDevice::Metadata* metadata; - TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata)); - const int device_ordinal = metadata->device_ordinal(); - // We are guaranteed that the object underlying TpuNodeContext won't be - // deleted out from under us, while node_context is alive. 
- TF_ASSIGN_OR_RETURN(std::unique_ptr node_context, - tpu::TpuNodeContext::Create(device_ordinal)); - xla::Backend* const backend = node_context->backend(); - se::Stream* stream = context->op_device_context()->stream(); - - auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteCell()); - profiler::TraceMe trace_me( - [context] { - return profiler::TraceMeEncode("TpuExecuteOp", - {{"step_id", context->step_id()}}); - }, - /*level=*/2); - profiler::TraceMe trace_me_init("XRTExecuteOp::Init", /*level=*/2); - - auto* rm = GetTPUConfigResourceMgr(); - TF_RET_CHECK(rm != nullptr); - - const Tensor& execution_input = context->input(0); - TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape())); - int64_t compilation_handle = execution_input.scalar()(); - - const Tensor& execution_config = context->input(1); - TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape())); - xrt::XRTExecutionConfig config_proto; - TF_RET_CHECK( - config_proto.ParseFromString(execution_config.scalar()())); - - int core_index_in_replica = config_proto.core_index_in_replica(); - bool release_inputs = config_proto.release_input_handles(); - bool release_compilation = config_proto.release_compilation_handle(); - - string rendezvous_key_base = std::to_string(compilation_handle); - std::unique_ptr entry; - TF_RETURN_IF_ERROR(GetComputationCacheEntry(rm, compilation_handle, - core_index_in_replica, &entry)); - - TpuCompilationCacheEntry centry = entry->get(); - const tpu::TpuProgramGroup* tpu_program_group = - tensorflow::down_cast( - centry.tpu_program_group()); - CHECK_NE(tpu_program_group, nullptr); - - if (release_compilation) { - // Process-wide cache of Tpu executables. - tpu::TpuCompilationCacheInterface* cache; - TF_RETURN_IF_ERROR(rm->Lookup( - rm->default_container(), tpu::kCompilationCacheResourceName, &cache)); - core::ScopedUnref cache_unref(cache); - TF_RETURN_IF_ERROR(cache->Release(compilation_handle)); - VLOG(2) << "Released compilation handle " << compilation_handle; - } - - const int core_index = centry.core_index(); - const TPUExecutableInfoProto& executable = - tpu_program_group->executable_info(core_index); - - std::vector input_is_dynamic = GetDynamicInputInfo(executable); - - TF_ASSIGN_OR_RETURN( - xla::HloInputOutputAliasConfig input_output_alias, - GetExecutableAliasConfig(tpu_program_group, backend, core_index)); - TF_ASSIGN_OR_RETURN(std::vector input_coords, - GetComputationInputs(context, "input_handles")); - - RefPtr memory_manager = XRTMemoryManager::Get(rm); - XRTMemoryManager::WorkingSet working_set(memory_manager); - TF_ASSIGN_OR_RETURN( - std::vector> input_tuples, - GetInputTupleAllocations( - input_coords, &working_set, backend, executable.input_shapes_size(), - [&](int64_t i) { return xla::Shape(executable.input_shapes(i)); }, - release_inputs, backend->memory_allocator())); - auto get_buffers_fn = [&]() { - return GetArgumentsBuffers(input_output_alias, input_tuples, - input_is_dynamic, release_inputs); - }; - trace_me_init.Stop(); - - TF_ASSIGN_OR_RETURN( - xla::ExecutionOutput output, - ExecuteTPUProgram( - context, node_context.get(), memory_manager.get(), executable, - get_buffers_fn, config_proto.execution_instance_key(), - config_proto.rng_seed(), tpu_program_group, backend, stream, - core_index, device_ordinal, rendezvous_key_base)); - - // AllocateComputationOutput writes the output tuple handle to the output - // tensor return value from the Op. 
- TF_RETURN_IF_ERROR(AllocateOutputTensors( - context, memory_manager.get(), node_context.get(), stream, config_proto, - executable, input_tuples, input_output_alias, output.ConsumeResult(), - device_ordinal)); - return OkStatus(); -} - -class XRTExecuteChainedOp : public AsyncOpKernel { - public: - explicit XRTExecuteChainedOp(OpKernelConstruction* context); - - void ComputeAsync(OpKernelContext* context, DoneCallback done) override; - - private: - Status DoWork(OpKernelContext* context); -}; - -XRTExecuteChainedOp::XRTExecuteChainedOp(OpKernelConstruction* context) - : AsyncOpKernel(context, /* is_deferred = */ true) {} - -void XRTExecuteChainedOp::ComputeAsync(OpKernelContext* context, - DoneCallback done) { - // Schedule onto the default queue, for unbounded concurrency. See b/73520706 - OP_REQUIRES_OK_ASYNC(context, DoWork(context), done); - done(); -} - -Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { - VLOG(1) << "XRTExecuteChainedOp::Compute"; - const XlaDevice::Metadata* metadata; - TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata)); - const int device_ordinal = metadata->device_ordinal(); - // We are guaranteed that the object underlying TpuNodeContext won't be - // deleted out from under us, while node_context is alive. - TF_ASSIGN_OR_RETURN(std::unique_ptr node_context, - tpu::TpuNodeContext::Create(device_ordinal)); - xla::Backend* const backend = node_context->backend(); - se::Stream* stream = context->op_device_context()->stream(); - auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteChainedCell()); - profiler::TraceMe trace_me( - [context] { - return profiler::TraceMeEncode("TpuExecuteChainedOp", - {{"step_id", context->step_id()}}); - }, - /*level=*/2); - ResourceMgr* rm = GetTPUConfigResourceMgr(); - TF_RET_CHECK(rm != nullptr); - - const Tensor& execution_plan = context->input(0); - TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_plan.shape())); - xrt::XRTChainedExecutePlan plan; - TF_RET_CHECK(plan.ParseFromString(execution_plan.scalar()())); - - const Tensor& execution_config = context->input(1); - TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape())); - xrt::XRTChainedExecuteConfig config; - TF_RET_CHECK(config.ParseFromString(execution_config.scalar()())); - - TpuCompilationCacheLookup* proto_lookup; - TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(), - tpu::kCompiledProtoCacheResourceName, - &proto_lookup)); - core::ScopedUnref lookup_unref(proto_lookup); - RefPtr memory_manager = XRTMemoryManager::Get(rm); - auto execute_op = [&](const xrt::XRTChainedExecuteOp& op, - absl::Span> op_inputs) - -> xla::StatusOr> { - std::unique_ptr entry; - TF_RETURN_IF_ERROR(proto_lookup->Lookup( - op.computation_handle(), config.core_index_in_replica(), &entry)); - string rendezvous_key_base = std::to_string(op.computation_handle()); - TpuCompilationCacheEntry centry = entry->get(); - const tpu::TpuProgramGroup* tpu_program_group = - tensorflow::down_cast( - centry.tpu_program_group()); - CHECK_NE(tpu_program_group, nullptr); - const int core_index = centry.core_index(); - const TPUExecutableInfoProto& executable = - tpu_program_group->executable_info(core_index); - std::vector input_is_dynamic = GetDynamicInputInfo(executable); - - TF_ASSIGN_OR_RETURN( - xla::HloInputOutputAliasConfig input_output_alias, - GetExecutableAliasConfig(tpu_program_group, backend, core_index)); - TF_ASSIGN_OR_RETURN(std::vector> input_tuples, - GetChainedOpInputs(op, op_inputs, executable)); - auto get_buffers_fn = [&]() { - return 
GetArgumentsBuffers(input_output_alias, input_tuples, - input_is_dynamic, - /*release_inputs=*/false); - }; - TF_ASSIGN_OR_RETURN( - xla::ExecutionOutput output, - ExecuteTPUProgram(context, node_context.get(), memory_manager.get(), - executable, get_buffers_fn, - config.execution_instance_key(), config.rng_seed(), - tpu_program_group, backend, stream, core_index, - device_ordinal, rendezvous_key_base)); - return AllocateOutputTuple(node_context.get(), stream, input_tuples, - input_output_alias, output.ConsumeResult(), - device_ordinal); - }; - - return ExecuteChained(context, memory_manager, backend, device_ordinal, plan, - config, execute_op, backend->memory_allocator()); -} - -} // namespace - -REGISTER_KERNEL_BUILDER(Name("XRTExecute") - .Device(DEVICE_TPU_NODE) - .HostMemory("computation_handle") - .HostMemory("execution_config") - .HostMemory("input_handles") - .HostMemory("output_handle"), - XRTExecuteOp); - -REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained") - .Device(DEVICE_TPU_NODE) - .HostMemory("execution_plan") - .HostMemory("execution_config") - .HostMemory("output_handle"), - XRTExecuteChainedOp); - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/kernels/tpu_state_op.cc b/tensorflow/compiler/xrt/kernels/tpu_state_op.cc deleted file mode 100644 index 6fe1321c413887..00000000000000 --- a/tensorflow/compiler/xrt/kernels/tpu_state_op.cc +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Classes for allocating XLA literals in device memory and managing handles -// that refer to them. 
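// Note added for clarity (not in the original file): the registrations below
// all place the kernels on DEVICE_TPU_NODE while pinning their handle and
// serialized-proto tensors to host memory with .HostMemory(...), since
// handles are plain int64 scalars and protos are strings that the kernels
// parse on the CPU.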
- -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/local_client.h" -#include "tensorflow/compiler/xrt/kernels/xrt_state_ops.h" -#include "tensorflow/compiler/xrt/xrt_tpu_device.h" -#include "tensorflow/core/tpu/tpu_defs.h" - -namespace tensorflow { -REGISTER_KERNEL_BUILDER(Name("XRTAllocate") - .Device(DEVICE_TPU_NODE) - .HostMemory("allocation") - .HostMemory("handle"), - XRTAllocateOp); - -REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized") - .Device(DEVICE_TPU_NODE) - .HostMemory("handle"), - XRTAllocateUninitializedOp); - -REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") - .Device(DEVICE_TPU_NODE) - .HostMemory("inputs") - .HostMemory("handle"), - XRTAllocateFromTensorOp); - -REGISTER_KERNEL_BUILDER(Name("XRTSubTuple") - .Device(DEVICE_TPU_NODE) - .HostMemory("base_handle") - .HostMemory("shape_index") - .HostMemory("output_handle"), - XRTSubTupleOp); - -REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease") - .Device(DEVICE_TPU_NODE) - .HostMemory("base_handle") - .HostMemory("shape_index") - .HostMemory("output_handle"), - XRTSubTupleOp); - -REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple") - .Device(DEVICE_TPU_NODE) - .HostMemory("tuple_description") - .HostMemory("input_handles") - .HostMemory("output_handle"), - XRTMakeTupleOp); - -REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") - .Device(DEVICE_TPU_NODE) - .HostMemory("handle") - .HostMemory("literal"), - XRTReadLiteralOp); - -REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") - .Device(DEVICE_TPU_NODE) - .HostMemory("handle") - .HostMemory("literal") - .HostMemory("output_handle"), - XRTWriteLiteralOp); - -REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") - .Device(DEVICE_TPU_NODE) - .HostMemory("handle") - .HostMemory("literal"), - XRTReadLiteralOp); - -REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") - .Device(DEVICE_TPU_NODE) - .HostMemory("handles") - .HostMemory("tensors"), - XRTReadToTensorOp); - -REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") - .Device(DEVICE_TPU_NODE) - .HostMemory("handle"), - XRTReleaseAllocationOp); - -REGISTER_KERNEL_BUILDER( - Name("XRTReleaseAllAllocations").Device(DEVICE_TPU_NODE), - XRTReleaseAllAllocationsOp); - -REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_TPU_NODE), - XRTCompactAllocationsOp); - -REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_TPU_NODE), - XRTMemoryInfoOp); - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc deleted file mode 100644 index ec6a9c56dfbdab..00000000000000 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ /dev/null @@ -1,301 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Classes for compiling XLA computations and managing handles that refer to -// them. 
- -#include -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/client_library.h" -#include "xla/client/xla_computation.h" -#include "xla/service/compiler.h" -#include "xla/status_macros.h" -#include "xla/statusor.h" -#include "xla/xla_data.pb.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/compiler/xrt/xrt_compilation_cache.h" -#include "tensorflow/compiler/xrt/xrt_device.h" -#include "tensorflow/compiler/xrt/xrt_metrics.h" -#include "tensorflow/compiler/xrt/xrt_util.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/monitoring/timed.h" -#include "tensorflow/core/lib/strings/proto_serialization.h" -#include "tensorflow/core/platform/fingerprint.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -namespace { - -Status GenerateXlaDeviceAssignment( - const xrt::DeviceAssignment& xrt_device_assignment, int num_replicas, - int num_cores_per_replica, xla::DeviceAssignment* device_assignment) { - if (num_cores_per_replica != - xrt_device_assignment.computation_devices_size()) { - return errors::InvalidArgument( - "Device assignment does not have the correct number of " - "computation_devices: num_cores_per_replica=", - num_cores_per_replica, " computation_devices=", - xrt_device_assignment.computation_devices_size()); - } - for (int64_t c = 0; c < xrt_device_assignment.computation_devices_size(); - ++c) { - const auto& computation_devices = - xrt_device_assignment.computation_devices(c); - if (num_replicas != computation_devices.replica_devices_size()) { - return errors::InvalidArgument( - "Device assignment does not have the correct number of " - "replica_device_ids: num_replicas=", - num_replicas, - " replica_devices=", computation_devices.replica_devices_size()); - } - for (int64_t r = 0; r < computation_devices.replica_devices_size(); ++r) { - const auto& coords = computation_devices.replica_devices(r); - if (coords.value_size() != 4) { - return errors::InvalidArgument( - "Device assignment mesh coordinates must have 4 entries, got ", - coords.value_size()); - } - for (int n = 0; n < 3; ++n) { - if (coords.value(n) != 0) { - return errors::InvalidArgument("Mesh coordinate at index ", n, - " must be 0, got ", coords.value(n)); - } - } - (*device_assignment)(r, c) = coords.value(3); - } - } - return OkStatus(); -} - -class XRTCompileOp : public OpKernel { - public: - explicit XRTCompileOp(OpKernelConstruction* ctx); - ~XRTCompileOp() override; - XRTCompileOp(const XRTCompileOp&) = delete; - XRTCompileOp& operator=(const XRTCompileOp&) = delete; - - void Compute(OpKernelContext* ctx) override; - - private: - Status Compile(OpKernelContext* ctx, - const xrt::XLAComputation& computation_proto, - std::unique_ptr* program); -}; - -Status CompilationCacheKey(const xrt::XLAComputation& computation, - string* key) { - const size_t size = computation.ByteSizeLong(); - auto serialized = absl::make_unique(size); - TF_RET_CHECK( - SerializeToBufferDeterministic(computation, serialized.get(), size)); - uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size)); - *key = absl::StrCat(fingerprint); - return OkStatus(); 
-} - -XRTCompileOp::XRTCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - -Status XRTCompileOp::Compile(OpKernelContext* ctx, - const xrt::XLAComputation& computation_proto, - std::unique_ptr* program) { - const xrt::XLAComputationConfig& config = computation_proto.config(); - // Sanity checks for options not yet supported. - int num_cores_per_replica = std::max(config.num_cores_per_replica(), 1); - TF_RET_CHECK(num_cores_per_replica == 1); - TF_RET_CHECK(config.per_core_program_shape_size() == 0); - - // The default config value is 0; treat it as 1 for convenience. - int num_replicas = config.num_replicas() ? config.num_replicas() : 1; - - // We are guaranteed that the underlying device object won't be deleted out - // from under us, while the ScopedRef is live. - class XRTGenericDeviceAccessor::ScopedRef device_ref; - TF_RETURN_IF_ERROR(XRTGenericDeviceAccessor::InitScopedRef(ctx, &device_ref)); - - xla::LocalClient* client = device_ref.client(); - - // There is officially no way to use XLA in a client/server architecture where - // client and server are built from different revisions, because the XLA team - // does not want to give any guarantees about the stability of the Hlo - // proto. For cloud TPU this is fine because server and client versions can be - // assumed to be synced to the same version. For general use the mechanism - // here (using a snapshot from XlaComputation) works as well as the "official" - // XLA client/server design, which serializes the same proto between client - // and server, so in reality is probably fine. - TF_ASSIGN_OR_RETURN(xla::XlaComputation computation, - client->LoadSnapshot(computation_proto.hlo_snapshot())); - - std::vector argument_layouts( - config.program_shape().parameters_size()); - std::vector argument_layout_ptrs( - config.program_shape().parameters_size()); - for (int i = 0; i < config.program_shape().parameters_size(); ++i) { - argument_layouts[i] = xla::Shape(config.program_shape().parameters(i)); - argument_layout_ptrs[i] = &argument_layouts[i]; - } - xla::ExecutableBuildOptions build_options; - build_options.set_device_ordinal(device_ref.device_ordinal()); - build_options.set_num_replicas(num_replicas); - build_options.set_result_layout(xla::Shape(config.program_shape().result())); - build_options.set_device_allocator(device_ref.allocator()); - if (config.has_debug_options()) { - *build_options.mutable_debug_options() = - BuildXlaDebugOptions(config.debug_options()); - } - if (config.has_device_assignment()) { - xla::DeviceAssignment device_assignment(num_replicas, - num_cores_per_replica); - TF_RETURN_IF_ERROR( - GenerateXlaDeviceAssignment(config.device_assignment(), num_replicas, - num_cores_per_replica, &device_assignment)); - build_options.set_device_assignment(device_assignment); - } - - VLOG(1) << "Building executable"; - TF_ASSIGN_OR_RETURN( - auto executables, - client->Compile(computation, argument_layout_ptrs, build_options)); - TF_RET_CHECK(executables.size() == 1); - *program = std::move(executables[0]); - return OkStatus(); -} - -void XRTCompileOp::Compute(OpKernelContext* ctx) { - VLOG(1) << "XRTCompileOp::Compute"; - auto timed = monitoring::MakeTimed(xrt_metrics::GetCompileCell()); - - ResourceMgr* rm; - OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm)); - - const Tensor& computation_input = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(computation_input.shape()), - errors::Internal("computation input should be a string scalar")); - - xrt::XLAComputation 
computation_proto; - OP_REQUIRES(ctx, - ParseFromTString(computation_input.scalar()(), - &computation_proto), - errors::InvalidArgument( - "Unable to parse computation input to XLAComputation")); - - string key; - OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key)); - - // Process-wide cache of XLA executables. - auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache( - ctx, /*max_number_of_entries=*/0); - OP_REQUIRES_OK(ctx, cache_or.status()); - auto cache = std::move(cache_or).value(); - - int64_t uid; - OP_REQUIRES_OK( - ctx, cache->CompileIfKeyAbsent( - key, &uid, [&](std::unique_ptr* program) { - VLOG(1) << "Compiling XLA executable"; - return Compile(ctx, computation_proto, program); - })); - std::unique_ptr entry; - OP_REQUIRES_OK(ctx, cache->Lookup(uid, &entry)); - - Tensor handle_output(DT_INT64, TensorShape({})); - handle_output.scalar()() = uid; - ctx->set_output(0, handle_output); - - xla::LocalExecutable* executable = entry->get().get_executable(); - xla::ProgramShapeProto program_shape = executable->executable() - ->module() - .config() - .entry_computation_layout() - .ComputeProgramShape() - .ToProto(); - Tensor program_shape_output(DT_STRING, TensorShape({1})); - program_shape_output.vec()(0) = program_shape.SerializeAsString(); - ctx->set_output(1, program_shape_output); -} - -XRTCompileOp::~XRTCompileOp() = default; - -class XRTReleaseCompilationRefOp : public OpKernel { - public: - explicit XRTReleaseCompilationRefOp(OpKernelConstruction* ctx); - ~XRTReleaseCompilationRefOp() override; - XRTReleaseCompilationRefOp(const XRTReleaseCompilationRefOp&) = delete; - XRTReleaseCompilationRefOp& operator=(const XRTReleaseCompilationRefOp&) = - delete; - - void Compute(OpKernelContext* ctx) override; -}; - -XRTReleaseCompilationRefOp::XRTReleaseCompilationRefOp( - OpKernelConstruction* ctx) - : OpKernel(ctx) {} - -XRTReleaseCompilationRefOp::~XRTReleaseCompilationRefOp() = default; - -void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) { - VLOG(1) << "XRTReleaseCompilationRefOp::Compute"; - auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseCompilationCell()); - - // Process-wide cache of XLA executables. 
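// Clarifying note (an assumption about the cache's refcounting, not stated in
// this file): Release(key) below only drops the cache's own reference to the
// entry; an XRTExecute op that has already looked the handle up keeps its
// compilation-cache entry ref alive, so in-flight executions are unaffected.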
- auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache( - ctx, /*max_number_of_entries=*/0); - OP_REQUIRES_OK(ctx, cache_or.status()); - auto cache = std::move(cache_or).value(); - - const Tensor& keys_tensor = ctx->input(0); - auto flat_keys = keys_tensor.flat(); - for (int64_t i = 0; i < flat_keys.size(); ++i) { - int64_t key = flat_keys(i); - OP_REQUIRES_OK(ctx, cache->Release(key)); - VLOG(2) << "Released computation handle " << key; - } -} - -} // namespace - -REGISTER_KERNEL_BUILDER(Name("XRTCompile") - .Device(DEVICE_XLA_CPU) - .HostMemory("computation") - .HostMemory("handle"), - XRTCompileOp); -REGISTER_KERNEL_BUILDER(Name("XRTCompile") - .Device(DEVICE_XLA_GPU) - .HostMemory("computation") - .HostMemory("handle"), - XRTCompileOp); - -REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle") - .Device(DEVICE_XLA_CPU) - .HostMemory("handle"), - XRTReleaseCompilationRefOp); -REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle") - .Device(DEVICE_XLA_GPU) - .HostMemory("handle"), - XRTReleaseCompilationRefOp); - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc deleted file mode 100644 index 47c2fa2f2b92c6..00000000000000 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ /dev/null @@ -1,618 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/hlo/ir/hlo_input_output_alias_config.h" -#include "xla/literal_util.h" -#include "xla/service/computation_placer.h" -#include "xla/service/gpu/gpu_executable_run_options.h" -#include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/stream_executor_internal.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/compiler/xrt/xrt_compilation_cache.h" -#include "tensorflow/compiler/xrt/xrt_device.h" -#include "tensorflow/compiler/xrt/xrt_memory_manager.h" -#include "tensorflow/compiler/xrt/xrt_metrics.h" -#include "tensorflow/compiler/xrt/xrt_state.h" -#include "tensorflow/compiler/xrt/xrt_util.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/monitoring/timed.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -namespace { - -uint32 InitialRandomSeed() { - // Support plumbing the TF seed through to XLA is being worked on. - // If a user wants deterministic behavior, their best option - // is to start with a known checkpoint. This also handles issues when - // multiple random calls can be invoked in any order by TF executor. - // Another option is to use stateless random ops. They have much cleaner - // semantics. - // If a user really wants to set a deterministic seed for XLA-based - // devices, this is the place to do it. - std::random_device rd; - // Make the starting value odd. - return rd() | 1; -} - -uint32 GetXLARandomSeed() { - // We initialize counter with an odd number and increment it by two - // everytime. This ensures that it will never be zero, even - // after an overflow. When seeded with zero, some XLA backends - // can return all zeros instead of random numbers. 
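// Added illustration (not in the original source): the counter below starts
// odd and every fetch_add(2) preserves that parity, so the value handed out
// is never zero even after a uint32 wraparound, e.g.
//   0xFFFFFFFD -> 0xFFFFFFFF -> 0x00000001 -> 0x00000003 -> ...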
- static std::atomic counter(InitialRandomSeed()); - return counter.fetch_add(2); -} - -std::vector GetDynamicInputInfo( - const xla::ComputationLayout& computation_layout) { - std::vector input_is_dynamic; - input_is_dynamic.reserve(computation_layout.parameter_count()); - for (int64_t i = 0; i < computation_layout.parameter_count(); ++i) { - input_is_dynamic.push_back( - !computation_layout.parameter_shape(i).is_static()); - } - return input_is_dynamic; -} - -xla::StatusOr>> GetInputTuples( - xla::LocalExecutable* executable, XRTMemoryManager::WorkingSet* working_set, - xla::Backend* backend, const std::vector& input_coords, - bool release_inputs, se::DeviceMemoryAllocator* allocator) { - const xla::ComputationLayout& computation_layout = - executable->executable()->module_config().entry_computation_layout(); - - return GetInputTupleAllocations( - input_coords, working_set, backend, computation_layout.parameter_count(), - [&](int64_t i) { return computation_layout.parameter_shape(i); }, - release_inputs, allocator); -} - -xla::StatusOr>> GetChainedOpInputTuples( - const xrt::XRTChainedExecuteOp& op, - absl::Span> op_inputs) { - std::vector> input_tuples; - input_tuples.reserve(op.inputs_size()); - for (int i = 0; i < op.inputs_size(); ++i) { - auto& input = op.inputs(i); - // Thanks to the greatness of proto3, there is no way to query for - // explicitly set fields, so the default for output_index (zero) means no - // sub-index. As consequence, the real index is output_index - 1. - if (input.output_index() == 0) { - input_tuples.emplace_back(op_inputs[i]); - } else { - XRTTupleAllocation* sub_tuple; - TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( - op_inputs[i].get(), {input.output_index() - 1}, &sub_tuple, - /*alias_parent_allocation=*/true)); - input_tuples.emplace_back(sub_tuple); - } - } - return input_tuples; -} - -// Given a shape, returns a byte array representing the shape metadata of the -// shape. The shape metadata contains dimensions sizes stored as contiguous S32. -std::vector PrepareMetadata(const xla::Shape& shape) { - DCHECK(shape.is_static()); - DCHECK(shape.IsArray()); - // Each dimension size is stored as a S32. - std::vector result(shape.dimensions_size()); - for (int64_t i = 0; i < shape.dimensions_size(); ++i) { - result[i] = shape.dimensions(i); - } - return result; -} - -// Given a buffer with dynamic shape, update buffer metadata at the correct -// offset starting from that buffer. -// -// +-----------+ -// |Payload | -// +-----------+ -// | Padding | -// +-----------+ -// |dim_size_0 | (each dim_size is a S32): -// +-----------+ -// |dim_size_1 | -// +-----------+ -// .......... 
-// +-----------+ -// -// Size of payload = ByteSizeOf(runtime_shape) -// Size of payload + padding = ByteSizeOf(compile_time_shape_static) -// Size of payload + padding + metadata = ByteSizeOf(compile_time_shape) -Status UpdateMetadata(se::Stream* stream, se::DeviceMemory* buffer, - const xla::Shape& compile_time_shape, - const xla::Shape& runtime_shape) { - TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( - stream->parent()->platform())); - TF_ASSIGN_OR_RETURN( - auto transfer_manager, - xla::TransferManager::GetForPlatform(stream->parent()->platform())); - auto shape_size_fn = compiler->ShapeSizeBytesFunction(); - xla::Shape compile_time_shape_static = - xla::ShapeUtil::MakeStaticShape(compile_time_shape); - uint64 offset = shape_size_fn(compile_time_shape_static); - uint64 metadata_size = shape_size_fn(compile_time_shape) - offset; - auto metadata_buffer = - stream->parent()->GetSubBuffer(buffer, offset, metadata_size); - - auto metadata_literal = std::make_shared( - xla::LiteralUtil::CreateR1(PrepareMetadata(runtime_shape))); - TF_RETURN_IF_ERROR(transfer_manager->TransferArrayToDeviceAsync( - stream, *metadata_literal, metadata_buffer)); - // Retain the literal until the end of the transfer. - stream->ThenDoHostCallback([keep_alive = std::move(metadata_literal)] {}); - return OkStatus(); -} - -// Given a static input buffer, convert it to dynamic form by expanding it to -// the bounded size and attaching a metadata filled with dimension sizes. -// -// From: -// +--------+ -// |Payload | -// +--------+ -// -// To: -// -// +--------+ -// |Payload | -// +--------+ -// | Padding| -// +--------+ -// |Metadata| -// +--------+ -// -// As we can't expand the size of an existing memory allocation, a reallocation -// is required. A list of new allocations are returned after this function. The -// caller is reponsible for maintaining those allocations. -Status UpdateDynamicInputs( - se::Stream* stream, se::DeviceMemoryAllocator* allocator, - std::vector* execution_inputs, - const std::vector& compile_time_shapes) { - TF_RET_CHECK(execution_inputs->size() == compile_time_shapes.size()); - TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( - stream->parent()->platform())); - auto shape_size_fn = compiler->ShapeSizeBytesFunction(); - for (int64_t i = 0; i < compile_time_shapes.size(); i++) { - const xla::Shape& compile_time_shape = compile_time_shapes[i].shape(); - if (compile_time_shape.is_static()) { - continue; - } - xla::ExecutionInput* execution_input = &(*execution_inputs)[i]; - bool element_modified = false; - TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus( - compile_time_shape, - [&](const xla::Shape& sub_shape, - const xla::ShapeIndex& index) -> Status { - if (sub_shape.IsTuple() || sub_shape.is_static()) { - return OkStatus(); - } - TF_ASSIGN_OR_RETURN( - const xla::Shape* runtime_shape, - xla::ShapeUtil::TryGetSubshape(execution_input->shape(), index)); - TF_RET_CHECK(!runtime_shape->IsTuple()); - TF_RET_CHECK(xla::ShapeUtil::DynamicArrayShapeIsCompatible( - *runtime_shape, sub_shape)); - TF_ASSIGN_OR_RETURN( - se::OwningDeviceMemory dynamic_input, - allocator->Allocate(stream->parent()->device_ordinal(), - shape_size_fn(sub_shape))); - - se::DeviceMemoryBase static_input = - execution_input->Buffer(index).AsDeviceMemoryBase(); - se::DeviceMemory* dynamic_input_base = dynamic_input.ptr(); - // Send the original data to the new location. 
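// Worked example added for illustration (assuming, as PrepareMetadata() above
// suggests, that the compiler's size function reserves one int32 of metadata
// per dimension of a dynamic shape): for a bounded shape f32[<=4,8] whose
// runtime shape is f32[2,8],
//   payload                        = 2*8*4 bytes =  64
//   payload + padding (static)     = 4*8*4 bytes = 128
//   metadata (dims {2,8} as S32)   = 2*4   bytes =   8
// so the reallocated input is 136 bytes and UpdateMetadata() writes the two
// dimension sizes at offset 128.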
- stream->ThenMemcpyD2D(dynamic_input_base, static_input, - static_input.size()); - TF_RETURN_IF_ERROR(UpdateMetadata(stream, dynamic_input_base, - sub_shape, *runtime_shape)); - // Modify the memory location in the input shape tree to point to the - // new input. - execution_input->SetBuffer( - index, xla::MaybeOwningDeviceMemory(std::move(dynamic_input))); - execution_input->ClearUnownedIndex(index); - element_modified = true; - return OkStatus(); - })); - if (element_modified) { - TF_RETURN_IF_ERROR(execution_input->SetDynamicShape(compile_time_shape)); - TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, - execution_input->ToShapedBuffer( - allocator, stream->parent()->device_ordinal())); - // The input location has been modified, need to fix tuple table to - // point to the correct address. - TF_ASSIGN_OR_RETURN( - auto transfer_manager, - xla::TransferManager::GetForPlatform(stream->parent()->platform())); - TF_RETURN_IF_ERROR( - transfer_manager->WriteTupleIndexTablesAsync(stream, shaped_buffer)); - } - } - return OkStatus(); -} - -xla::StatusOr> CreateOutputTuple( - se::Stream* stream, xla::ExecutionOutput run_result, xla::Backend* backend, - int device_ordinal, se::DeviceMemoryAllocator* allocator) { - XRTTupleAllocation* output_tuple; - xla::ScopedShapedBuffer* shaped_buffer = run_result.MutableResult(); - if (shaped_buffer->on_device_shape().is_dynamic()) { - // Update dynamic shapes from output buffer, and create a XRT tensor with - // dimension sizes read from metadata. - xla::Shape output_device_shape = shaped_buffer->on_device_shape(); - TF_ASSIGN_OR_RETURN( - auto transfer_manager, - xla::TransferManager::GetForPlatform(stream->parent()->platform())); - TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes( - stream, shaped_buffer, &output_device_shape)); - TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( - *shaped_buffer, - xla::ShapeUtil::DeviceShapeToHostShape(output_device_shape), - output_device_shape, backend, device_ordinal, &output_tuple, - allocator)); - } else { - // Fast-path: Don't copy shapes of output buffer. - TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( - *shaped_buffer, backend, device_ordinal, &output_tuple, allocator)); - } - // After the output tuple is created, we can release the output result - // buffers, to make sure they won't be cleared by its destructor. 
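// Clarifying note added here (consistent with the comment above, not new
// behavior): CreateFromBuffer() has taken its own references on the device
// buffers, so the ScopedShapedBuffer inside run_result is released without
// freeing; letting its destructor run would deallocate memory that the
// freshly created XRTTupleAllocation still points at.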
- (void)run_result.ConsumeResult().release(); - return RefPtr(output_tuple); -} - -xla::StatusOr> RunExecutable( - OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref, - xla::LocalExecutable* executable, - absl::Span> input_tuples, - bool release_inputs, se::Stream* stream, int rng_seed, - const xrt::CommonExecutionConfig& config) { - const xla::ComputationLayout& computation_layout = - executable->executable()->module_config().entry_computation_layout(); - std::vector input_is_dynamic = GetDynamicInputInfo(computation_layout); - TF_ASSIGN_OR_RETURN( - std::vector execution_inputs, - GetArgumentsBuffers( - executable->executable()->module().input_output_alias_config(), - input_tuples, input_is_dynamic, release_inputs)); - - se::DeviceMemoryAllocator* allocator = device_ref->allocator(); - xla::ExecutableRunOptions run_options; - run_options.set_stream(stream); - run_options.set_allocator(allocator); - run_options.set_intra_op_thread_pool(&context->eigen_cpu_device()); - run_options.set_rng_seed(rng_seed); - if (config.run_id() != 0) { - run_options.set_run_id(xla::RunId(config.run_id())); - } - if (executable->executable() - ->module_config() - .has_static_device_assignment()) { - run_options.set_device_assignment( - &executable->executable()->module_config().static_device_assignment()); - } - xla::gpu::GpuExecutableRunOptions gpu_options; - std::map gpu_global_ids; - if (config.local_replica_mapping_size() > 0) { - int i = 0; - for (auto& gid : config.local_replica_mapping()) { - gpu_global_ids[i++] = xla::GlobalDeviceId(gid); - } - gpu_options.set_gpu_global_device_ids(gpu_global_ids); - } - std::shared_ptr nccl_factory = GetNcclUniqueIdFactory(); - if (nccl_factory != nullptr) { - auto uid_callback = - [&](const xla::gpu::NcclCliqueKey& key) -> xla::StatusOr { - std::vector replicas; - const auto key_devices = key.devices(); - replicas.reserve(key_devices.size()); - for (auto& device : key_devices) { - replicas.push_back(device.value()); - } - return nccl_factory->GetUniqueId(replicas); - }; - gpu_options.set_nccl_unique_id_callback(uid_callback); - } - run_options.set_gpu_executable_run_options(&gpu_options); - - const std::vector& shape_layouts = - executable->executable() - ->module_config() - .entry_computation_layout() - .parameter_layouts(); - TF_RETURN_IF_ERROR(UpdateDynamicInputs(stream, run_options.allocator(), - &execution_inputs, shape_layouts)); - TF_ASSIGN_OR_RETURN( - xla::ExecutionOutput run_result, - executable->Run(std::move(execution_inputs), run_options)); - - TF_ASSIGN_OR_RETURN( - RefPtr output_tuple_ptr, - CreateOutputTuple(stream, std::move(run_result), device_ref->backend(), - device_ref->device_ordinal(), allocator)); - // The ScopedShapedBuffer returned by the executable Run() API, in case of - // input/output buffer aliasing, might have holes in it, which need to be - // filled using the proper input tuples buffers which are the source of - // aliasing. 
- TF_RETURN_IF_ERROR(RebuildOutputAliases( - output_tuple_ptr, input_tuples, - executable->executable()->module().input_output_alias_config())); - - return std::move(output_tuple_ptr); -} - -xla::StatusOr> ExecuteComputation( - OpKernelContext* context, XRTMemoryManager* memory_manager, - XRTGenericDeviceAccessor::ScopedRef* device_ref, - xla::LocalExecutable* executable, - absl::Span> input_tuples, - bool release_inputs, se::Stream* stream, int rng_seed, - const xrt::CommonExecutionConfig& config) { - auto runfn = [&]() { - return RunExecutable(context, device_ref, executable, input_tuples, - release_inputs, stream, rng_seed, config); - }; - - // We pass zero as requested_free_size as there is no simple way to get the - // peak heap size. Upon zero, the Run() API will try to free chunks of device - // memory, until either the runfn can run, or we run out of freeable memory. - return memory_manager->Run>( - runfn, device_ref->backend(), device_ref->device_ordinal(), - /*requested_free_size=*/0, device_ref->allocator()); -} - -xla::StatusOr> ExecuteComputation( - OpKernelContext* context, const RefPtr& memory_manager, - XRTGenericDeviceAccessor::ScopedRef* device_ref, - xla::LocalExecutable* executable, - const std::vector& input_coords, bool release_inputs, - se::Stream* stream, int rng_seed, - const xrt::CommonExecutionConfig& config) { - XRTMemoryManager::WorkingSet working_set(memory_manager); - TF_ASSIGN_OR_RETURN( - std::vector> input_tuples, - GetInputTuples(executable, &working_set, device_ref->backend(), - input_coords, release_inputs, device_ref->allocator())); - return ExecuteComputation(context, memory_manager.get(), device_ref, - executable, input_tuples, release_inputs, stream, - rng_seed, config); -} - -// XRTExecuteOp - -class XRTExecuteOp : public AsyncOpKernel { - public: - explicit XRTExecuteOp(OpKernelConstruction* context); - ~XRTExecuteOp() override; - - void ComputeAsync(OpKernelContext* context, DoneCallback done) override; - - private: - Status DoWork(OpKernelContext* context); -}; - -XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context) - : AsyncOpKernel(context) {} - -void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) { - // Schedule onto the default queue, for unbounded concurrency. 
See b/73520706 - Env::Default()->SchedClosure([this, context, done]() { - OP_REQUIRES_OK_ASYNC(context, DoWork(context), done); - done(); - }); -} - -Status XRTExecuteOp::DoWork(OpKernelContext* context) { - VLOG(1) << "XRTExecuteOp::Compute"; - auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteCell()); - ResourceMgr* rm; - TF_RETURN_IF_ERROR( - XRTGenericDeviceAccessor::GetResourceManager(context, &rm)); - - const Tensor& execution_input = context->input(0); - TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape())); - int64_t compilation_handle = execution_input.scalar()(); - - const Tensor& execution_config = context->input(1); - TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape())); - xrt::XRTExecutionConfig config_proto; - TF_RET_CHECK( - ParseFromTString(execution_config.scalar()(), &config_proto)); - - int core_index_in_replica = config_proto.core_index_in_replica(); - TF_RET_CHECK(core_index_in_replica == 0); - bool release_inputs = config_proto.release_input_handles(); - bool release_compilation = config_proto.release_compilation_handle(); - - TF_ASSIGN_OR_RETURN(auto cache, - XRTGenericDeviceAccessor::GetOrCreateCompilationCache( - context, /*max_number_of_entries=*/0)); - // We are guaranteed that the underlying device object won't be deleted out - // from under us, while the ScopedRef is live. - class XRTGenericDeviceAccessor::ScopedRef device_ref; - TF_RETURN_IF_ERROR( - XRTGenericDeviceAccessor::InitScopedRef(context, &device_ref)); - - int rng_seed = config_proto.rng_seed(); - if (rng_seed == 0) { - rng_seed = GetXLARandomSeed(); - } - - se::Stream* stream = context->op_device_context() - ? context->op_device_context()->stream() - : nullptr; - RefPtr memory_manager = XRTMemoryManager::Get(rm); - TF_ASSIGN_OR_RETURN(std::vector input_coords, - GetComputationInputs(context, "input_handles")); - - std::unique_ptr entry; - TF_RETURN_IF_ERROR(cache->Lookup(compilation_handle, &entry)); - xla::LocalExecutable* executable = entry->get().get_executable(); - if (release_compilation) { - // Process-wide cache of XLA executables. - TF_RETURN_IF_ERROR(cache->Release(compilation_handle)); - VLOG(2) << "Released compilation handle " << compilation_handle; - } - - TF_ASSIGN_OR_RETURN( - RefPtr output_tuple, - ExecuteComputation(context, memory_manager, &device_ref, executable, - input_coords, release_inputs, stream, rng_seed, - config_proto.common_config())); - - return CreateExecuteOutput(context, memory_manager.get(), - std::move(output_tuple), - config_proto.return_exploded_tuple()); -} - -XRTExecuteOp::~XRTExecuteOp() = default; - -class XRTExecuteChainedOp : public AsyncOpKernel { - public: - explicit XRTExecuteChainedOp(OpKernelConstruction* context); - ~XRTExecuteChainedOp() override; - - void ComputeAsync(OpKernelContext* context, DoneCallback done) override; - - private: - Status DoWork(OpKernelContext* context); -}; - -XRTExecuteChainedOp::XRTExecuteChainedOp(OpKernelConstruction* context) - : AsyncOpKernel(context) {} - -void XRTExecuteChainedOp::ComputeAsync(OpKernelContext* context, - DoneCallback done) { - // Schedule onto the default queue, for unbounded concurrency. 
See b/73520706 - Env::Default()->SchedClosure([this, context, done]() { - OP_REQUIRES_OK_ASYNC(context, DoWork(context), done); - done(); - }); -} - -Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { - VLOG(1) << "XRTExecuteChainedOp::Compute"; - auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteChainedCell()); - ResourceMgr* rm; - TF_RETURN_IF_ERROR( - XRTGenericDeviceAccessor::GetResourceManager(context, &rm)); - - const Tensor& execution_plan = context->input(0); - TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_plan.shape())); - xrt::XRTChainedExecutePlan plan; - TF_RET_CHECK(ParseFromTString(execution_plan.scalar()(), &plan)); - - const Tensor& execution_config = context->input(1); - TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape())); - xrt::XRTChainedExecuteConfig config; - TF_RET_CHECK(ParseFromTString(execution_config.scalar()(), &config)); - - TF_ASSIGN_OR_RETURN(auto cache, - XRTGenericDeviceAccessor::GetOrCreateCompilationCache( - context, /*max_number_of_entries=*/0)); - // We are guaranteed that the underlying device object won't be deleted out - // from under us, while the ScopedRef is live. - class XRTGenericDeviceAccessor::ScopedRef device_ref; - TF_RETURN_IF_ERROR( - XRTGenericDeviceAccessor::InitScopedRef(context, &device_ref)); - - int rng_seed = config.rng_seed(); - if (rng_seed == 0) { - rng_seed = GetXLARandomSeed(); - } - - se::Stream* stream = context->op_device_context() - ? context->op_device_context()->stream() - : nullptr; - RefPtr memory_manager = XRTMemoryManager::Get(rm); - auto execute_op = [&](const xrt::XRTChainedExecuteOp& op, - absl::Span> op_inputs) - -> xla::StatusOr> { - std::unique_ptr entry; - TF_RETURN_IF_ERROR(cache->Lookup(op.computation_handle(), &entry)); - xla::LocalExecutable* executable = entry->get().get_executable(); - - TF_ASSIGN_OR_RETURN(std::vector> input_tuples, - GetChainedOpInputTuples(op, op_inputs)); - - return ExecuteComputation( - context, memory_manager.get(), &device_ref, executable, input_tuples, - /*release_inputs=*/false, stream, rng_seed, config.common_config()); - }; - - return ExecuteChained(context, memory_manager, device_ref.backend(), - device_ref.device_ordinal(), plan, config, execute_op, - device_ref.allocator()); -} - -XRTExecuteChainedOp::~XRTExecuteChainedOp() = default; - -} // namespace - -REGISTER_KERNEL_BUILDER(Name("XRTExecute") - .Device(DEVICE_XLA_CPU) - .HostMemory("computation_handle") - .HostMemory("execution_config") - .HostMemory("input_handles") - .HostMemory("output_handle"), - XRTExecuteOp); - -REGISTER_KERNEL_BUILDER(Name("XRTExecute") - .Device(DEVICE_XLA_GPU) - .HostMemory("computation_handle") - .HostMemory("execution_config") - .HostMemory("input_handles") - .HostMemory("output_handle"), - XRTExecuteOp); - -REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained") - .Device(DEVICE_XLA_CPU) - .HostMemory("execution_plan") - .HostMemory("execution_config") - .HostMemory("output_handle"), - XRTExecuteChainedOp); - -REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained") - .Device(DEVICE_XLA_GPU) - .HostMemory("execution_plan") - .HostMemory("execution_config") - .HostMemory("output_handle"), - XRTExecuteChainedOp); - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc deleted file mode 100644 index 09ca1ef948aaf1..00000000000000 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ /dev/null @@ -1,204 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Classes for allocating XLA literals in device memory and managing handles -// that refer to them. - -#include "tensorflow/compiler/xrt/kernels/xrt_state_ops.h" - -#include -#include -#include - -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/local_client.h" -#include "tensorflow/compiler/xrt/xrt_metrics.h" - -namespace tensorflow { -namespace { - -class XRTMetricsCollectOp : public OpKernel { - public: - explicit XRTMetricsCollectOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "XRTMetricsCollectOp::Compute"; - - const Tensor& metrics_proto = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(metrics_proto.shape()), - errors::Internal("request input should be a string scalar")); - xrt::XRTMetricsCollect metrics; - OP_REQUIRES(ctx, - ParseFromTString(metrics_proto.scalar()(), &metrics), - errors::InvalidArgument( - "Unable to parse request input to XRTMetricsCollect")); - - xla::StatusOr collected_metrics_or = - CollectMetrics(metrics); - OP_REQUIRES_OK(ctx, collected_metrics_or.status()); - xrt::MetricsReport collected_metrics = - std::move(collected_metrics_or).value(); - Tensor output(DT_STRING, TensorShape({})); - output.scalar()() = collected_metrics.SerializeAsString(); - ctx->set_output(0, output); - } -}; - -} // namespace - -REGISTER_KERNEL_BUILDER(Name("XRTAllocate") - .Device(DEVICE_XLA_GPU) - .HostMemory("allocation") - .HostMemory("handle"), - XRTAllocateOp); -REGISTER_KERNEL_BUILDER(Name("XRTAllocate") - .Device(DEVICE_XLA_CPU) - .HostMemory("allocation") - .HostMemory("handle"), - XRTAllocateOp); - -REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized") - .Device(DEVICE_XLA_GPU) - .HostMemory("handle"), - XRTAllocateUninitializedOp); -REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized") - .Device(DEVICE_XLA_CPU) - .HostMemory("handle"), - XRTAllocateUninitializedOp); - -REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") - .Device(DEVICE_XLA_GPU) - .HostMemory("inputs") - .HostMemory("handle"), - XRTAllocateFromTensorOp); -REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") - .Device(DEVICE_XLA_CPU) - .HostMemory("inputs") - .HostMemory("handle"), - XRTAllocateFromTensorOp); - -REGISTER_KERNEL_BUILDER(Name("XRTSubTuple") - .Device(DEVICE_XLA_GPU) - .HostMemory("base_handle") - .HostMemory("shape_index") - .HostMemory("output_handle"), - XRTSubTupleOp); -REGISTER_KERNEL_BUILDER(Name("XRTSubTuple") - .Device(DEVICE_XLA_CPU) - .HostMemory("base_handle") - .HostMemory("shape_index") - .HostMemory("output_handle"), - XRTSubTupleOp); - -REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease") - .Device(DEVICE_XLA_GPU) - .HostMemory("base_handle") - .HostMemory("shape_index") - .HostMemory("output_handle"), - XRTSubTupleOp); -REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease") - .Device(DEVICE_XLA_CPU) - .HostMemory("base_handle") - .HostMemory("shape_index") - 
.HostMemory("output_handle"), - XRTSubTupleOp); - -REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple") - .Device(DEVICE_XLA_GPU) - .HostMemory("tuple_description") - .HostMemory("input_handles") - .HostMemory("output_handle"), - XRTMakeTupleOp); -REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple") - .Device(DEVICE_XLA_CPU) - .HostMemory("tuple_description") - .HostMemory("input_handles") - .HostMemory("output_handle"), - XRTMakeTupleOp); - -REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") - .Device(DEVICE_XLA_GPU) - .HostMemory("handle") - .HostMemory("literal"), - XRTReadLiteralOp); -REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") - .Device(DEVICE_XLA_CPU) - .HostMemory("handle") - .HostMemory("literal"), - XRTReadLiteralOp); - -REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") - .Device(DEVICE_XLA_GPU) - .HostMemory("handle") - .HostMemory("literal") - .HostMemory("output_handle"), - XRTWriteLiteralOp); -REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") - .Device(DEVICE_XLA_CPU) - .HostMemory("handle") - .HostMemory("literal") - .HostMemory("output_handle"), - XRTWriteLiteralOp); - -REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") - .Device(DEVICE_XLA_GPU) - .HostMemory("handle") - .HostMemory("literal"), - XRTReadLiteralOp); -REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") - .Device(DEVICE_XLA_CPU) - .HostMemory("handle") - .HostMemory("literal"), - XRTReadLiteralOp); - -REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") - .Device(DEVICE_XLA_GPU) - .HostMemory("handles") - .HostMemory("tensors"), - XRTReadToTensorOp); -REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") - .Device(DEVICE_XLA_CPU) - .HostMemory("handles") - .HostMemory("tensors"), - XRTReadToTensorOp); - -REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") - .Device(DEVICE_XLA_GPU) - .HostMemory("handle"), - XRTReleaseAllocationOp); -REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") - .Device(DEVICE_XLA_CPU) - .HostMemory("handle"), - XRTReleaseAllocationOp); - -REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_GPU), - XRTReleaseAllAllocationsOp); -REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_CPU), - XRTReleaseAllAllocationsOp); - -REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_XLA_GPU), - XRTCompactAllocationsOp); -REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_XLA_CPU), - XRTCompactAllocationsOp); - -REGISTER_KERNEL_BUILDER(Name("XRTMetricsCollect").Device(DEVICE_CPU), - XRTMetricsCollectOp); - -REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_XLA_GPU), - XRTMemoryInfoOp); -REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_XLA_CPU), - XRTMemoryInfoOp); - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h deleted file mode 100644 index 5faf034af023d9..00000000000000 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ /dev/null @@ -1,784 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Classes for allocating XLA literals in device memory and managing handles -// that refer to them. - -#ifndef TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_ -#define TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_ - -#include -#include -#include - -#include "tensorflow/compiler/tf2xla/literal_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "xla/client/local_client.h" -#include "xla/layout_util.h" -#include "xla/literal.h" -#include "xla/status_macros.h" -#include "xla/statusor.h" -#include "xla/xla_data.pb.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/compiler/xrt/xrt_device.h" -#include "tensorflow/compiler/xrt/xrt_memory_manager.h" -#include "tensorflow/compiler/xrt/xrt_metrics.h" -#include "tensorflow/compiler/xrt/xrt_state.h" -#include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/monitoring/percentile_sampler.h" -#include "tensorflow/core/lib/monitoring/timed.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -// Helper functions for templated ops. -class XRTStateHelpers { - public: - // The Status return value allows us to use the - // TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an - // OpKernel::Compute method. - static Status MakeLiteral(const xla::LiteralProto& proto, - xla::Literal* literal) { - TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto)); - return OkStatus(); - } - - // ParseTupleNode is the recursive function used to parse a recursive - // xrt::XLATupleNode proto and generate the xla::Shape of the 'spine' i.e. the - // tuple shape where every leaf is an existing allocation. As a side-effect it - // fills in input_vector by looking up allocations from handles in the - // input_tensor_list as they are referenced by nodes in the proto. - static Status ParseTupleNode( - const xrt::XLATupleNode& tuple_node, const OpInputList& input_tensor_list, - std::vector* input_vector, - xla::Shape* shape, ResourceMgr* rm) { - if (tuple_node.tuples_size() > 0) { - // This is an internal node in the proto so descend recursively. - xla::Shape dummy = xla::ShapeUtil::MakeShapeWithType({}); - std::vector subshapes(tuple_node.tuples_size(), dummy); - *xla::ShapeUtil::GetMutableSubshape(shape, {}) = - xla::ShapeUtil::MakeTupleShape(subshapes); - for (int i = 0; i < tuple_node.tuples_size(); ++i) { - TF_RETURN_IF_ERROR(ParseTupleNode( - tuple_node.tuples(i), input_tensor_list, input_vector, - xla::ShapeUtil::GetMutableSubshape(shape, {i}), rm)); - } - } else { - // This is a leaf node in the proto so look up the referenced input. 
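// Illustrative example added here (schematic textproto, not from the original
// source): a request assembling the tuple ((h0, h1), h2) from three input
// handles would arrive as a root node
//   tuples { tuples { input_index: 0 } tuples { input_index: 1 } }
//   tuples { input_index: 2 }
// and only the leaves, i.e. the nodes without nested `tuples`, carry an
// input_index that gets dereferenced in the branch below.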
- int input_index = tuple_node.input_index(); - if (input_index < 0 || input_index >= input_vector->size()) { - return errors::InvalidArgument("Invalid tuple input index ", - input_index, ": MakeTuple has ", - input_vector->size(), " inputs."); - } - bool release_this_input = tuple_node.release_input_handle(); - XRTTupleAllocation::ExpandedTupleInput& input = - input_vector->at(input_index); - if (input.allocation != nullptr && - (input.release_allocation_after_use || release_this_input)) { - return errors::InvalidArgument( - "Invalid tuple tree: input index ", input_index, - " is repeated but release_input_handle is true."); - } - if (input.allocation == nullptr) { - // We haven't dereferenced this handle yet. - TF_RET_CHECK( - TensorShapeUtils::IsScalar(input_tensor_list[input_index].shape())); - int64_t key = input_tensor_list[input_index].scalar()(); - TF_ASSIGN_OR_RETURN(input.allocation, - XRTMemoryManager::Get(rm)->Lookup(key)); - input.release_allocation_after_use = release_this_input; - } - } - return OkStatus(); - } - - // Parses a xrt::XLATupleNode proto recursively and returns the corresponding - // ShapeTree where each leaf is an allocation corresponding to a handle in - // input_tensor_list. The ordinal of one of the allocations is returned in - // device_ordinal. Since it's not possible to specify a xrt::XLATupleNode with - // no leaves, device_ordinal will always be filled in by a successful call to - // ParseTupleTree. - static Status ParseTupleTree( - const xrt::XLATupleNode& tuple_tree_root, - const OpInputList& input_tensor_list, - std::vector* input_vector, - xla::ShapeTree* tuple_shape_tree, - int* device_ordinal, ResourceMgr* rm) { - // First get the shape of the 'spine' of the new tuple, where every leaf is - // an existing allocation. As a side-effect dereference the input handles - // into allocations in input_vector. - xla::Shape tuple_tree_shape; - TF_RETURN_IF_ERROR(ParseTupleNode(tuple_tree_root, input_tensor_list, - input_vector, &tuple_tree_shape, rm)); - // Make the shape tree of allocations where the shape is the spine and each - // leaf is one of the allocations looked up in input_vector. Internal nodes - // have nullptr allocations. - *tuple_shape_tree = xla::ShapeTree( - tuple_tree_shape); - tuple_shape_tree->ForEachMutableElement( - [&](const xla::ShapeIndex& index, - XRTTupleAllocation::ExpandedTupleInput* element) { - if (tuple_shape_tree->IsLeaf(index)) { - // Find the matching leaf in the proto tree. - const xrt::XLATupleNode* tuple_node = &tuple_tree_root; - for (int i = 0; i < index.size(); ++i) { - tuple_node = &tuple_node->tuples(index[i]); - } - // Copy the appropriate input allocation to the leaf of the - // tuple_shape_tree. - int input_index = tuple_node->input_index(); - *element = input_vector->at(input_index); - CHECK(element->release_allocation_after_use == - tuple_node->release_input_handle()); - // We just need to know the device_ordinal of one of the - // allocations. We will validate later that they are all the same. - *device_ordinal = (*element).allocation->device_ordinal(); - } - }); - return OkStatus(); - } -}; - -// Op that allocates memory for a literal and transfers it to the device. 
-template -class XRTAllocateOp : public OpKernel { - public: - explicit XRTAllocateOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - ~XRTAllocateOp() override = default; - XRTAllocateOp(const XRTAllocateOp&) = delete; - XRTAllocateOp& operator=(const XRTAllocateOp&) = delete; - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "XRTAllocateOp::Compute"; - auto timed = monitoring::MakeTimed(xrt_metrics::GetAllocateCell()); - - const Tensor& allocation_info = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_info.shape()), - errors::Internal("allocation input should be a string scalar")); - xrt::XLAAllocation allocation_proto; - OP_REQUIRES(ctx, - ParseFromTString(allocation_info.scalar()(), - &allocation_proto), - errors::InvalidArgument( - "Unable to parse allocation input to XLAAllocation")); - - xla::Literal literal; - OP_REQUIRES_OK( - ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal)); - - ResourceMgr* rm; - OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - - // We are guaranteed that the underlying device object won't be deleted out - // from under us, while the ScopedRef is live. - class DeviceAccessor::ScopedRef device_ref; - OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); - - RefPtr memory_manager = XRTMemoryManager::Get(rm); - XRTTupleAllocation* allocation; - OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( - literal, memory_manager.get(), device_ref.backend(), - device_ref.device_ordinal(), &allocation, - device_ref.allocator())); - - Tensor output(DT_INT64, TensorShape({})); - output.scalar()() = memory_manager->Register(allocation); - ctx->set_output(0, output); - } -}; - -// Op that allocates uninitialized memory on the device for a tensor of -// a particular shape. -template -class XRTAllocateUninitializedOp : public OpKernel { - public: - explicit XRTAllocateUninitializedOp(OpKernelConstruction* ctx) - : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tf_shape_)); - OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, tf_shape_, &xla_shape_)); - } - ~XRTAllocateUninitializedOp() override = default; - XRTAllocateUninitializedOp(const XRTAllocateUninitializedOp&) = delete; - XRTAllocateUninitializedOp& operator=(const XRTAllocateUninitializedOp&) = - delete; - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "XRTAllocateUninitializedOp::Compute"; - auto timed = - monitoring::MakeTimed(xrt_metrics::GetAllocateUninitializedCell()); - ResourceMgr* rm; - OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - - // We are guaranteed that the underlying device object won't be deleted out - // from under us, while the ScopedRef is live. - class DeviceAccessor::ScopedRef device_ref; - OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); - - RefPtr memory_manager = XRTMemoryManager::Get(rm); - XRTTupleAllocation* allocation; - OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateUninitialized( - xla_shape_, memory_manager.get(), - device_ref.backend(), device_ref.device_ordinal(), - &allocation, device_ref.allocator())); - - Tensor output(DT_INT64, TensorShape({})); - output.scalar()() = memory_manager->Register(allocation); - ctx->set_output(0, output); - } - - private: - DataType dtype_; - TensorShape tf_shape_; - xla::Shape xla_shape_; -}; - -// Op that allocates memory for a tensor (with optional layout) and transfers it -// to the device, returning an allocation handle. 
-template -class XRTAllocateFromTensorOp : public OpKernel { - public: - explicit XRTAllocateFromTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - bool make_tuple = false; - OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple)); - std::vector minor_to_major; - if (ctx->HasAttr("layouts")) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major)); - } - OP_REQUIRES( - ctx, tf_shapes_.size() == dtypes_.size(), - errors::InvalidArgument("shapes and dtypes must be the same length")); - std::vector xla_shapes; - xla_shapes.reserve(tf_shapes_.size()); - for (int i = 0; i < tf_shapes_.size(); i++) { - xla::Shape xla_shape; - OP_REQUIRES_OK( - ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape)); - xla_shapes.push_back(std::move(xla_shape)); - } - if (xla_shapes.size() > 1 || make_tuple) { - shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes); - } else { - shape_.Swap(&xla_shapes.front()); - } - if (!minor_to_major.empty()) { - xla::Shape shape_with_layouts; - OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major, - /*layout_func=*/nullptr, - &shape_with_layouts)); - shape_.Swap(&shape_with_layouts); - } - } - - ~XRTAllocateFromTensorOp() override = default; - XRTAllocateFromTensorOp(const XRTAllocateFromTensorOp&) = delete; - XRTAllocateFromTensorOp& operator=(const XRTAllocateFromTensorOp&) = delete; - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "XRTAllocateFromTensorOp::Compute"; - auto timed = - monitoring::MakeTimed(xrt_metrics::GetAllocateFromTensorCell()); - - OpInputList values; - OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values)); - OP_REQUIRES(ctx, values.size() == tf_shapes_.size(), - errors::InvalidArgument( - "Wrong number of inputs to XRTAllocateFromTensor: ", - values.size(), " vs. ", tf_shapes_.size())); - - std::vector tensors_data; - for (size_t i = 0; i < values.size(); ++i) { - const Tensor& input_tensor = values[i]; - OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i], - errors::InvalidArgument( - "Input tensor type and input dtype do not match")); - // We allow the requested on-device shape to differ from the shape of the - // input tensor, as long as they have the same number of elements. - OP_REQUIRES( - ctx, - input_tensor.shape().num_elements() == tf_shapes_[i].num_elements(), - errors::InvalidArgument( - "Input tensor must have the number of elements specified " - "in the matching input shape: ", - input_tensor.shape().num_elements(), " vs. ", - tf_shapes_[i].num_elements(), " at index ", i)); - tensors_data.push_back( - static_cast(DMAHelper::base(&input_tensor))); - } - // Use the buffer straight out of the input tensors to create the literal. - xla::BorrowingLiteral literal = - shape_.IsTuple() ? xla::BorrowingLiteral(tensors_data, shape_) - : xla::BorrowingLiteral(tensors_data.front(), shape_); - ResourceMgr* rm; - OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - - // We are guaranteed that the underlying device object won't be deleted out - // from under us, while the ScopedRef is live. 
- class DeviceAccessor::ScopedRef device_ref; - OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); - - RefPtr memory_manager = XRTMemoryManager::Get(rm); - XRTTupleAllocation* allocation; - OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( - literal, memory_manager.get(), device_ref.backend(), - device_ref.device_ordinal(), &allocation, - device_ref.allocator())); - - Tensor output(DT_INT64, TensorShape({})); - output.scalar()() = memory_manager->Register(allocation); - ctx->set_output(0, output); - } - - private: - std::vector tf_shapes_; - DataTypeVector dtypes_; - xla::Shape shape_; -}; - -// Op that takes a tuple handle input and returns a handle to a sub-tuple of the -// input. -template -class XRTSubTupleOp : public OpKernel { - public: - explicit XRTSubTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - ~XRTSubTupleOp() override = default; - XRTSubTupleOp(const XRTSubTupleOp&) = delete; - XRTSubTupleOp& operator=(const XRTSubTupleOp&) = delete; - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "XRTSubTupleOp::Compute"; - auto timed = monitoring::MakeTimed(xrt_metrics::GetSubTupleCell()); - - const Tensor& handle_tensor = ctx->input(0); - OP_REQUIRES( - ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()), - errors::Internal("computation input should be an int64 scalar")); - int64_t allocation_handle = handle_tensor.scalar()(); - - const Tensor& subtuple_info = ctx->input(1); - OP_REQUIRES( - ctx, TensorShapeUtils::IsVector(subtuple_info.shape()), - errors::Internal("tuple index input should be an int32 vector")); - xla::ShapeIndex shape_index; - for (int i = 0; i < subtuple_info.dim_size(0); ++i) { - shape_index.push_back(subtuple_info.vec()(i)); - } - - ResourceMgr* rm; - OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - - RefPtr memory_manager = XRTMemoryManager::Get(rm); - RefPtr allocation; - OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation)); - - if (discard_) { - VLOG(2) << "Releasing handle " << allocation_handle; - OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle)); - } - - XRTTupleAllocation* suballocation; - OP_REQUIRES_OK( - ctx, XRTTupleAllocation::MakeSubBuffer(allocation.get(), shape_index, - &suballocation, !discard_)); - - Tensor output(DT_INT64, TensorShape({})); - output.scalar()() = memory_manager->Register(suballocation); - ctx->set_output(0, output); - } -}; - -// Op that allocates memory for a literal and transfers it to the device. -template -class XRTMakeTupleOp : public OpKernel { - public: - explicit XRTMakeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - ~XRTMakeTupleOp() override = default; - XRTMakeTupleOp(const XRTMakeTupleOp&) = delete; - XRTMakeTupleOp& operator=(const XRTMakeTupleOp&) = delete; - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "XRTMakeTupleOp::Compute"; - auto timed = monitoring::MakeTimed(xrt_metrics::GetMakeTupleCell()); - - const Tensor& tuple_info = ctx->input(0); - OP_REQUIRES( - ctx, TensorShapeUtils::IsScalar(tuple_info.shape()), - errors::Internal("tuple description input should be a string scalar")); - xrt::XLATupleNode tuple_proto; - OP_REQUIRES( - ctx, ParseFromTString(tuple_info.scalar()(), &tuple_proto), - errors::InvalidArgument("Unable to parse tuple input to XLATupleNode")); - - OpInputList arg_list; - OP_REQUIRES_OK(ctx, ctx->input_list("input_handles", &arg_list)); - - // For each input, the allocation it corresponds to and a flag indicating - // whether or not it should be released, i.e. 
discarded from the resource
- // manager. One ref on each allocation is owned by this vector, and freed on
- // exit.
- std::vector<XRTTupleAllocation::ExpandedTupleInput> input_vector(
- arg_list.size());
- ResourceMgr* rm;
- OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
-
- xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> tuple_shape_tree;
- // device_ordinal is filled in by ParseTupleTree with the ordinal of one of
- // the allocations. It is guaranteed that there is at least one allocation in
- // any legal tree. We validate below in XRTTupleAllocation::MakeTuple that
- // all the allocations are on the same device.
- int device_ordinal;
- OP_REQUIRES_OK(ctx, XRTStateHelpers::ParseTupleTree(
- tuple_proto, arg_list, &input_vector,
- &tuple_shape_tree, &device_ordinal, rm));
-
- // We are guaranteed that the underlying device object won't be deleted out
- // from under us, while the ScopedRef is live.
- class DeviceAccessor::ScopedRef device_ref;
- OP_REQUIRES_OK(
- ctx, DeviceAccessor::InitScopedRef(ctx, device_ordinal, &device_ref));
-
- RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
- XRTTupleAllocation* output_allocation;
- OP_REQUIRES_OK(ctx, XRTTupleAllocation::MakeTuple(
- memory_manager.get(), device_ref.backend(),
- device_ref.device_ordinal(), tuple_shape_tree,
- &output_allocation, device_ref.allocator()));
- RefPtr<XRTTupleAllocation> output_ptr(output_allocation);
- for (int i = 0; i < input_vector.size(); ++i) {
- if (input_vector[i].release_allocation_after_use) {
- OP_REQUIRES_OK(
- ctx, memory_manager->Release(arg_list[i].scalar<int64_t>()()));
- }
- }
-
- Tensor output(DT_INT64, TensorShape({}));
- output.scalar<int64_t>()() =
- memory_manager->Register(std::move(output_ptr));
- ctx->set_output(0, output);
- }
-};
-
-// Op that reads a device-resident tuple to host memory and returns it as a
-// literal.
-template
-class XRTReadLiteralOp : public OpKernel {
- public:
- explicit XRTReadLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
- ~XRTReadLiteralOp() override = default;
- XRTReadLiteralOp(const XRTReadLiteralOp&) = delete;
- XRTReadLiteralOp& operator=(const XRTReadLiteralOp&) = delete;
-
- void Compute(OpKernelContext* ctx) override {
- VLOG(1) << "XRTReadLiteralOp::Compute";
- auto timed = monitoring::MakeTimed(xrt_metrics::GetReadLiteralCell());
-
- const Tensor& handle_tensor = ctx->input(0);
- OP_REQUIRES(
- ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
- errors::Internal("computation input should be an int64 scalar"));
- int64_t allocation_handle = handle_tensor.scalar<int64_t>()();
-
- ResourceMgr* rm;
- OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
-
- RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
- RefPtr<XRTTupleAllocation> allocation;
- OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));
-
- if (discard_) {
- VLOG(2) << "Releasing handle " << allocation_handle;
- OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
- }
-
- // We are guaranteed that the underlying device object won't be deleted out
- // from under us, while the ScopedRef is live.
- class DeviceAccessor::ScopedRef device_ref;
- OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
- ctx, allocation->device_ordinal(), &device_ref));
-
- xla::Literal literal(allocation->on_host_shape());
- OP_REQUIRES_OK(ctx, allocation->ToLiteral(device_ref.backend(), &literal));
- xla::LiteralProto literal_proto = literal.ToProto();
-
- Tensor output(DT_STRING, TensorShape({}));
- SerializeToTString(literal_proto, &output.scalar<tstring>()());
- ctx->set_output(0, output);
- }
-};
-
-// Op that reads a device-resident tuple to host memory and returns it as a
-// literal.
-template <class DeviceAccessor>
-class XRTReadToTensorOp : public OpKernel {
- public:
- explicit XRTReadToTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("release_handles", &discard_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
- }
- ~XRTReadToTensorOp() override = default;
- XRTReadToTensorOp(const XRTReadToTensorOp&) = delete;
- XRTReadToTensorOp& operator=(const XRTReadToTensorOp&) = delete;
-
- void Compute(OpKernelContext* ctx) override {
- VLOG(1) << "XRTReadToTensorOp::Compute";
- auto timed = monitoring::MakeTimed(xrt_metrics::GetReadToTensorCell());
-
- const Tensor& handle_tensor = ctx->input(0);
- // TODO(phawkins,dlibenzi): accept multiple handles (i.e., vectors, not
- // just scalars.)
- OP_REQUIRES(
- ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
- errors::Internal("computation input should be an int64 scalar"));
- int64_t allocation_handle = handle_tensor.scalar<int64_t>()();
-
- ResourceMgr* rm;
- OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
-
- RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
- RefPtr<XRTTupleAllocation> allocation;
- OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));
-
- if (discard_) {
- VLOG(2) << "Releasing handle " << allocation_handle;
- OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
- }
-
- // We are guaranteed that the underlying device object won't be deleted out
- // from under us, while the ScopedRef is live.
- class DeviceAccessor::ScopedRef device_ref; - OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( - ctx, allocation->device_ordinal(), &device_ref)); - - xla::Shape shape = allocation->on_host_shape(); - int output = 0; - Status status = xla::ShapeUtil::ForEachMutableSubshapeWithStatus( - &shape, - [&](xla::Shape* subshape, const xla::ShapeIndex& index) -> Status { - if (subshape->IsTuple()) return OkStatus(); - - xla::PrimitiveType xla_type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType( - ctx->expected_output_dtype(output), &xla_type)); - if (xla_type != subshape->element_type()) { - return errors::InvalidArgument( - "Type mismatch between buffer type (", subshape->ToString(), - ") and tensor type (", - DataTypeString(ctx->expected_output_dtype(output)), - ") for output tensor ", output); - } - - TensorShape output_shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(*subshape, &output_shape)); - - Tensor* output_tensor; - TF_RETURN_IF_ERROR( - ctx->allocate_output(output, output_shape, &output_tensor)); - - XRTTupleAllocation* sub; - TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( - allocation.get(), index, &sub, /*alias_parent_allocation=*/true)); - core::ScopedUnref sub_unref(sub); - - xla::MutableBorrowingLiteral literal; - TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral( - xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor, - &literal)); - TF_RETURN_IF_ERROR(sub->ToLiteral(device_ref.backend(), &literal)); - - ++output; - return OkStatus(); - }); - OP_REQUIRES_OK(ctx, status); - } - bool discard_; - DataTypeVector dtypes_; -}; - -// Op that writes a new literal value into device-resident memory. -template -class XRTWriteLiteralOp : public OpKernel { - public: - explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - ~XRTWriteLiteralOp() override = default; - XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete; - XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete; - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "XRTWriteLiteralOp::Compute"; - auto timed = monitoring::MakeTimed(xrt_metrics::GetWriteLiteralCell()); - - const Tensor& handle_tensor = ctx->input(0); - OP_REQUIRES( - ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()), - errors::Internal("computation input should be an int64 scalar")); - int64_t allocation_handle = handle_tensor.scalar()(); - - const Tensor& literal_info = ctx->input(1); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()), - errors::Internal("literal input should be a string scalar")); - xla::LiteralProto literal_proto; - OP_REQUIRES( - ctx, ParseFromTString(literal_info.scalar()(), &literal_proto), - errors::InvalidArgument( - "Unable to parse allocation input to LiteralProto")); - xla::Literal literal; - OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal)); - - ResourceMgr* rm; - OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - - RefPtr memory_manager = XRTMemoryManager::Get(rm); - RefPtr allocation; - OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation)); - - // We are guaranteed that the underlying device object won't be deleted out - // from under us, while the ScopedRef is live. 
- typename DeviceAccessor::ScopedRef device_ref; - OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( - ctx, allocation->device_ordinal(), &device_ref)); - OP_REQUIRES_OK(ctx, - allocation->WriteLiteral(device_ref.backend(), literal)); - - Tensor output(DT_INT64, TensorShape({})); - output.scalar()() = allocation_handle; - ctx->set_output(0, output); - } -}; - -// Op that discards a handle to device memory. -template -class XRTReleaseAllocationOp : public OpKernel { - public: - explicit XRTReleaseAllocationOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - ~XRTReleaseAllocationOp() override = default; - XRTReleaseAllocationOp(const XRTReleaseAllocationOp&) = delete; - XRTReleaseAllocationOp& operator=(const XRTReleaseAllocationOp&) = delete; - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "XRTReleaseAllocationOp::Compute"; - auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseAllocationCell()); - - ResourceMgr* rm; - OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - - RefPtr memory_manager = XRTMemoryManager::Get(rm); - const Tensor& allocation_handle = ctx->input(0); - auto flat_keys = allocation_handle.flat(); - for (int64_t i = 0; i < flat_keys.size(); ++i) { - int64_t key = flat_keys(i); - OP_REQUIRES_OK(ctx, memory_manager->Release(key)); - VLOG(2) << "Released allocation handle " << key; - } - } -}; - -// Op that discards a handle to device memory. -template -class XRTReleaseAllAllocationsOp : public OpKernel { - public: - explicit XRTReleaseAllAllocationsOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} - ~XRTReleaseAllAllocationsOp() override = default; - XRTReleaseAllAllocationsOp(const XRTReleaseAllAllocationsOp&) = delete; - XRTReleaseAllAllocationsOp& operator=(const XRTReleaseAllAllocationsOp&) = - delete; - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "XRTReleaseAllAllocationsOp::Compute"; - auto timed = - monitoring::MakeTimed(xrt_metrics::GetReleaseAllAllocationsCell()); - - ResourceMgr* rm; - OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - XRTMemoryManager::Get(rm)->ReleaseAllAllocations(); - } -}; - -template -class XRTCompactAllocationsOp : public OpKernel { - public: - explicit XRTCompactAllocationsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - ~XRTCompactAllocationsOp() override = default; - XRTCompactAllocationsOp(const XRTCompactAllocationsOp&) = delete; - XRTCompactAllocationsOp& operator=(const XRTCompactAllocationsOp&) = delete; - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "XRTCompactAllocationsOp::Compute"; - auto timed = - monitoring::MakeTimed(xrt_metrics::GetCompactAllocationsCell()); - - ResourceMgr* rm; - OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - RefPtr memory_manager = XRTMemoryManager::Get(rm); - class DeviceAccessor::ScopedRef device_ref; - OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); - OP_REQUIRES_OK(ctx, memory_manager->CompactAllocations( - device_ref.backend(), device_ref.device_ordinal(), - device_ref.allocator())); - } -}; - -template -class XRTMemoryInfoOp : public OpKernel { - public: - explicit XRTMemoryInfoOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - ~XRTMemoryInfoOp() override = default; - XRTMemoryInfoOp(const XRTMemoryInfoOp&) = delete; - XRTMemoryInfoOp& operator=(const XRTMemoryInfoOp&) = delete; - - void Compute(OpKernelContext* ctx) override { - auto kernel_fn = [&]() -> Status { - VLOG(1) << "XRTMemoryInfoOp::Compute"; - - class DeviceAccessor::ScopedRef device_ref; - 
TF_RETURN_IF_ERROR(DeviceAccessor::InitScopedRef(ctx, &device_ref)); - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * stream_executor, - device_ref.backend()->stream_executor(device_ref.device_ordinal())); - int64_t mem_free = -1; - int64_t mem_total = -1; - if (!stream_executor->DeviceMemoryUsage(&mem_free, &mem_total)) { - VLOG(2) << "Device " << ctx->device()->name() - << " does not expose memory information"; - } - xrt::MemoryInfo mem_info; - mem_info.set_kb_total((mem_total >= 0) ? mem_total / 1024 : -1); - mem_info.set_kb_free((mem_free >= 0) ? mem_free / 1024 : -1); - - Tensor output(DT_STRING, TensorShape({})); - output.scalar()() = mem_info.SerializeAsString(); - ctx->set_output(0, output); - return OkStatus(); - }; - OP_REQUIRES_OK(ctx, kernel_fn()); - } -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_ diff --git a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc deleted file mode 100644 index fffb703dd84c2d..00000000000000 --- a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { - -REGISTER_OP("XRTCompile") - .Input("computation: string") - .Output("handle: int64") - .Output("program_shape: string") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->Scalar()); - c->set_output(1, c->UnknownShapeOfRank(1)); - return OkStatus(); - }) - .Doc( - R"( -Reads a computation proto, compiles it, and places it in the global compilation -cache. - -'computation' is a serialized xrt::XLAComputation proto. -'handle' is an identifier that can be used in other ops to refer to the -computation. -)"); - -REGISTER_OP("XRTReleaseCompilationHandle") - .Input("handle: int64") - .SetShapeFn(tensorflow::shape_inference::NoOutputs) - .Doc( - R"( -Discards one or more computation handles from the compilation cache. -The handle(s) cannot be subsequently used. - -'handle' is an ID (or vector of IDs) returned from a XRTCompile Op. -)"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc deleted file mode 100644 index 6f485d82cbecc4..00000000000000 --- a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { - -REGISTER_OP("XRTExecute") - .Attr("Ninputs: int >= 0") - .Input("computation_handle: int64") - .Input("execution_config: string") - .Input("input_handles: Ninputs * int64") - .Output("output_handle: int64") - .SetShapeFn([](shape_inference::InferenceContext* c) { - std::vector input_handle_shapes; - TF_RETURN_IF_ERROR(c->input("input_handles", &input_handle_shapes)); - for (size_t i = 0; i < input_handle_shapes.size(); ++i) { - shape_inference::ShapeHandle unused; - TF_RETURN_IF_ERROR( - c->WithRankAtMost(input_handle_shapes[i], 1, &unused)); - } - return tensorflow::shape_inference::ScalarShape(c); - }) - .Doc( - R"( -Runs a previously-compiled computation on a core. If -execution_config.release_input_handles is true, the input handles are invalid -after this op runs. - -'computation_handle' is an id returned by XRTCompile. -'execution_config' is a serialized xrt::TPUExecutionConfig proto. -'input_handles' is a list of ids of allocations, one per input to the compiled -computation. -'output_handle' is an identifier for the result of the compiled computation. -'Ninputs' is the number of input handles. -)"); - -REGISTER_OP("XRTExecuteChained") - .Input("execution_plan: string") - .Input("execution_config: string") - .Output("output_handle: int64") - .SetShapeFn([](shape_inference::InferenceContext* c) { - return tensorflow::shape_inference::ScalarShape(c); - }) - .Doc( - R"( -Runs a sequence of previously-compiled computations on a core. -The 'execution_plan' input is a serialized xrt::XRTChainedExecutePlan proto -describing the post-order of the chained execution. -The 'execution_config' input is a serialized xrt::XRTChainedExecuteConfig -proto describing the configuration for the chained execution operation. -Returns one of more int64 handles to the XRT device data generated by the -chained execution. -)"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc deleted file mode 100644 index 5a831d14284633..00000000000000 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ /dev/null @@ -1,247 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { - -static bool Initialized = [] { - tensorflow::RequestXlaDevicesCreation(); - return true; -}(); - -REGISTER_OP("XRTAllocate") - .Input("allocation: string") - .Output("handle: int64") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) - .Doc( - R"( -Reads a literal proto and transfers it to device memory. - -'allocation' is a serialized xrt::XLAAllocation proto. -'handle' is an id that can be used in other ops to refer to the allocation. -)"); - -REGISTER_OP("XRTAllocateUninitialized") - .Output("handle: int64") - .Attr("dtype: type") - .Attr("shape: shape") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) - .Doc( - R"( -Allocates a tensor to hold the specified shape in device memory. The values -in the tensor are left uninitialized. - -shape: The shapes which the tensor should have on device. - -handle: An id that can be used in other ops to refer to the allocation. -)"); - -REGISTER_OP("XRTAllocateFromTensor") - .Input("inputs: dtypes") - .Output("handle: int64") - .Attr("dtypes: list(type)") - .Attr("shapes: list(shape)") - .Attr("layouts: list(int) = []") - .Attr("make_tuple: bool = false") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) - .Doc( - R"( -Reads a list of tensors with optional layouts, and transfers it to device -memory. - -inputs: The tensors holding the input data. -shapes: The shapes which the tensors should have on device. The i-th shape -corresponds to the i-th input. The shapes, together with the (optional) -layouts, helps creating the fully qualified shape of the data on the device. -The shapes can differ from the corresponding input one, as long as the total -number of elements matches. In other words, it is possible to feed an input -tensor with shape {8} and have a corresponding shape {2,2,2}. -layouts: A vector holding the requested layout in minor-to-major sequence. -If empty, the default layout will be used. -For a tuple, the layouts vector holds a linearized minor-to-major numbers -for all the tuple leaves, in the order they appear within the tuple. -The elements within the layouts sequence corresponding to a given tuple -subshape can be set to -1, to leave such subshape to the default shape. -handle: An id that can be used in other ops to refer to the allocation. -)"); - -REGISTER_OP("XRTSubTuple") - .Input("base_handle: int64") - .Input("shape_index: int32") - .Output("output_handle: int64") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) - .Doc( - R"( -Returns a handle to a sub-tuple of an allocated tuple. - -'base_handle' is the id of the on-device allocation. -'shape_index' is a vector of integers describing an XLA ShapeIndex. -'output_handle' is an id that can be used in other ops to refer to the -sub-tuple. -)"); - -REGISTER_OP("XRTSubTupleAndRelease") - .Input("base_handle: int64") - .Input("shape_index: int32") - .Output("output_handle: int64") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) - .Doc( - R"( -Returns a handle to a sub-tuple of an allocated tuple, and releases the handle -of the input tuple. - -'base_handle' is the id of the on-device allocation. -'shape_index' is a vector of integers describing an XLA ShapeIndex. 
-'output_handle' is an id that can be used by other ops to refer to the -sub-tuple. -)"); - -REGISTER_OP("XRTMakeTuple") - .Attr("Ninputs: int") - .Input("tuple_description: string") - .Input("input_handles: Ninputs * int64") - .Output("output_handle: int64") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) - .Doc( - R"( -Returns a handle to a new allocation constructed by assembling existing -allocations in a tuple. - -'tuple_description' is a serialized xrt::XLATupleNode proto describing the -shape of the output tuple, and whether each input handle should be aliased or -released. -'input_handles' is a list of input handles to assemble into the output tuple. -'output_handle' is an id that can be used by other ops to refer to the new -tuple. -'Ninputs' is the number of input handles. -)"); - -REGISTER_OP("XRTReadLiteral") - .Input("handle: int64") - .Output("literal: string") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) - .Doc( - R"( -Copies an allocated tuple from device memory and returns it as a literal. - -'handle' is the id returned from the Op that produced the on-device allocation. -'literal' is a serialized xla::LiteralProto proto. -)"); - -REGISTER_OP("XRTWriteLiteral") - .Input("handle: int64") - .Input("literal: string") - .Output("output_handle: int64") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) - .Doc( - R"( -Copies the input literal into the device memory pointed to by handle. -Returns the handle itself. - -'handle' is the id returned from the Op that produced the on-device allocation. -'literal' is a serialized xla::LiteralProto proto to be written to device memory. -)"); - -REGISTER_OP("XRTReadLiteralAndRelease") - .Input("handle: int64") - .Output("literal: string") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) - .Doc( - R"( -Copies an allocated tuple from device memory, and returns it as a literal, and -releases the handle. - -'handle' is the id returned from the Op that produced the on-device allocation. -'literal' is a serialized xla::LiteralProto proto. -)"); - -REGISTER_OP("XRTReadToTensor") - .Input("handles: int64") - .Attr("release_handles: bool = False") - .Attr("dtypes: list(type)") - .Output("tensors: dtypes") - .SetShapeFn(tensorflow::shape_inference::UnknownShape) - .Doc( - R"( -Copies allocated values from device memory and returns them as zero or more -Tensors. If a handle refers to a non-tuple buffer, a single tensor is returned. -In general, the tensors returned for a handle correspond to an in-order traversal -of a the tuple-tree value referenced by the handle. - -'handles' contains ids returned from Ops that produced on-device allocations. -At present, only a single (scalar) handle is supported. -'dtypes' are the expected types for each `Tensor` to be returned. If the -expected and actual tensor types do not match, an error is returned. -'release_handles': if True, `handles` are released. -'tensors' are the output Tensors. -)"); - -REGISTER_OP("XRTReleaseAllocationHandle") - .Input("handle: int64") - .SetShapeFn(tensorflow::shape_inference::NoOutputs) - .Doc( - R"( -Discards one or more device memory handles. The handle(s) cannot be subsequently -used. - -'handle' is the ID (or a vector of IDs) returned from the Op that produced the -on-device allocation. -)"); - -REGISTER_OP("XRTReleaseAllAllocations") - .SetShapeFn(tensorflow::shape_inference::NoOutputs) - .Doc( - R"( -Discards all the XRT allocations. All the client held handles will be invalid. 
-)"); - -REGISTER_OP("XRTCompactAllocations") - .SetShapeFn(tensorflow::shape_inference::NoOutputs) - .Doc( - R"( -Runs a device memory compaction cycle. This copies the device data behind the -currently alive allocation handles into host memory, releases the device memory -backing the handles, and re-allocate and send back the data to the device. -This operation helps with device memory fragmentation. -)"); - -REGISTER_OP("XRTMetricsCollect") - .Input("request: string") - .Output("result: string") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) - .Doc( - R"( -Reads the selected metric values from the metrics collection registry. - -'request' is a serialized xrt::XRTMetricsCollect proto. -'result' is a serialized xrt::MetricsReport proto. -)"); - -REGISTER_OP("XRTMemoryInfo") - .Output("result: string") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) - .Doc( - R"( -Returns the memory information of the device this op executes on/ - -'result' is a serialized xrt::MemoryInfo proto. -)"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD deleted file mode 100644 index 0139bd1fc6a076..00000000000000 --- a/tensorflow/compiler/xrt/tests/BUILD +++ /dev/null @@ -1,88 +0,0 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_cc_test") -load( - "//tensorflow/core/platform:build_config_root.bzl", - "tf_cuda_tests_tags", -) -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//learning/brain:__subpackages__"], - licenses = ["notice"], -) - -cc_library( - name = "raw_api_test_lib", - testonly = 1, - srcs = [ - "raw_api_test.cc", - ], - deps = [ - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xrt:xrt_proto_cc", - "//tensorflow/compiler/xrt:xrt_server", - "//tensorflow/compiler/xrt/cc:xrt_ops", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:tensorflow_opensource", - "//tensorflow/core:test", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:status", - "@local_xla//xla:literal", - "@local_xla//xla:literal_util", - "@local_xla//xla:shape_util", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:client_library", - "@local_xla//xla/client:executable_build_options", - "@local_xla//xla/client:local_client", - "@local_xla//xla/client:padding", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/service:platform_util", - "@local_xla//xla/stream_executor:platform", - ], -) - -tf_cc_test( - name = "raw_api_test_cpu", - size = "medium", - srcs = [], - args = [ - "--xla_test_device=XLA_CPU", - "--xla_platform=CPU", - ], - deps = [ - ":raw_api_test_lib", - "//tensorflow/compiler/jit:xla_cpu_device", - ], -) - -tf_cuda_cc_test( - name = "raw_api_test_gpu", - size = "medium", - srcs = [], - args = [ - "--xla_test_device=XLA_GPU", - "--xla_platform=GPU", - ], - tags = tf_cuda_tests_tags() + [ - "no_cuda_asan", # 
TODO(b/171319142): re-enable. - ], - deps = [ - ":raw_api_test_lib", - "//tensorflow/compiler/jit:xla_gpu_device", - ], -) diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc deleted file mode 100644 index 10f32f44aa2236..00000000000000 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ /dev/null @@ -1,2291 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "tensorflow/cc/client/client_session.h" -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/array_ops.h" -#include "tensorflow/cc/ops/const_op.h" -#include "tensorflow/compiler/tf2xla/literal_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "xla/client/client_library.h" -#include "xla/client/executable_build_options.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/local_client.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/layout.h" -#include "xla/layout_util.h" -#include "xla/literal.h" -#include "xla/literal_util.h" -#include "xla/service/platform_util.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/platform.h" -#include "xla/xla_data.pb.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/tstring.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/public/session_options.h" -#include "tensorflow/core/util/command_line_flags.h" -#include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" - -namespace tensorflow { -namespace { - -xla::XlaComputation ReturnDynamicR1() { - xla::XlaBuilder builder("ReturnDynamicR1"); - auto p0 = xla::Parameter(&builder, 0, - xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0"); - auto p1 = xla::Parameter(&builder, 1, - xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1"); - auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), - "P2"); - auto sum = xla::Add(p0, p1); - auto pad_sum = xla::SetDimensionSize(sum, p2, 0); - return builder.Build(pad_sum).value(); -} - 
-xla::XlaComputation ReturnDynamicR2() { - xla::XlaBuilder builder("ReturnDynamicR2"); - auto p0 = xla::Parameter(&builder, 0, - xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P0"); - auto p1 = xla::Parameter(&builder, 1, - xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P1"); - auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), - "P2"); - auto sum = xla::Add(p0, p1); - auto pad_sum_dim0 = xla::SetDimensionSize(sum, p2, 0); - auto pad_sum_dim1 = xla::SetDimensionSize(pad_sum_dim0, p2, 1); - return builder.Build(pad_sum_dim1).value(); -} - -xla::XlaComputation AcceptDynamicR1() { - xla::XlaBuilder builder("AcceptDynamicR1"); - xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); - dyn_shape.set_dynamic_dimension(0, true); - auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0"); - auto p1 = xla::Parameter(&builder, 1, dyn_shape, "P1"); - auto sum = xla::Add(p0, p1); - return builder.Build(sum).value(); -} - -xla::XlaComputation AcceptDynamicR2() { - xla::XlaBuilder builder("AcceptDynamicR2"); - xla::Shape dyn_shape; - dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4}); - dyn_shape.set_dynamic_dimension(1, true); - auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0"); - auto negate = xla::Neg(p0); - return builder.Build(negate).value(); -} - -xla::XlaComputation ReturnDynamicR1Tuple() { - xla::XlaBuilder builder("ReturnDynamicR1Tuple"); - auto p0 = xla::Parameter(&builder, 0, - xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0"); - auto p1 = xla::Parameter(&builder, 1, - xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1"); - auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), - "P2"); - auto sum = xla::Add(p0, p1); - auto sub = xla::Sub(p0, p1); - auto one = xla::One(&builder, xla::S32); - auto pad_sum = xla::SetDimensionSize(sum, p2, 0); - auto pad_sub = xla::SetDimensionSize(sub, p2 + one, 0); - auto tuple = xla::Tuple(&builder, {pad_sum, sum, pad_sub}); - return builder.Build(tuple).value(); -} - -xla::XlaComputation AcceptDynamicR1Tuple() { - xla::XlaBuilder builder("AcceptDynamicR1"); - xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); - dyn_shape.set_dynamic_dimension(0, true); - xla::Shape tuple_shape = - xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape}); - xla::Shape nest_tuple_shape = - xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape}); - auto p = xla::Parameter(&builder, 0, tuple_shape, "P0"); - auto p0 = xla::GetTupleElement(p, 0); - auto p1 = xla::GetTupleElement(p, 1); - auto sum = xla::Add(p0, p1); - return builder.Build(sum).value(); -} - -template -xla::LiteralProto CreateR0(T v) { - auto array = xla::LiteralUtil::CreateR0(v); - return array.ToProto(); -} - -tensorflow::SessionOptions GetSessionOptions() { - tensorflow::SessionOptions options; - // Disable optimizations for static graph to allow calls to Session::Extend. 
- options.config.mutable_experimental()->set_disable_optimize_for_static_graph( - true); - return options; -} - -class XrtClientSession : public ClientSession { - public: - explicit XrtClientSession(const Scope& scope) - : ClientSession(scope, GetSessionOptions()) { - auto clear_all = ops::XRTReleaseAllAllocations(scope); - std::vector outputs; - TF_CHECK_OK(Run(ClientSession::FeedType(), {}, {clear_all}, &outputs)); - } -}; - -string* xla_test_device_ptr; // initial value set in main() -string* xla_platform_ptr; // initial value set in main() - -string DeviceFromFlag() { - string xla_test_device = *xla_test_device_ptr; - return absl::StrCat("/device:", xla_test_device, ":0"); -} - -std::vector GetAttrLayout(absl::Span minor_to_mayor) { - std::vector layout; - for (auto dim : minor_to_mayor) { - layout.push_back(static_cast(dim)); - } - return layout; -} - -xla::LiteralProto TwoElementTuple() { - auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); - auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); - auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); - return tuple.ToProto(); -} - -xla::LiteralProto BasedTwoElementTuple(float base) { - auto array = xla::LiteralUtil::CreateR1({base, base + 1}); - auto matrix = xla::LiteralUtil::CreateR2( - {{base + 2, base + 3}, {base + 4, base + 5}}); - auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); - return tuple.ToProto(); -} - -xla::LiteralProto ScalarLiteral() { - auto scalar = xla::LiteralUtil::CreateR0(12.0f); - return scalar.ToProto(); -} - -xla::LiteralProto NestedTuple() { - auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); - auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); - auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); - auto scalar = xla::LiteralUtil::CreateR0(12.0f); - auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar}); - return nested.ToProto(); -} - -xla::LiteralProto MakeTuple0() { - auto scalar = xla::LiteralUtil::CreateR0(12.0f); - auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); - auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); - auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); - auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple}); - auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0}); - return nested1.ToProto(); -} - -xla::LiteralProto FloatVector(absl::Span v) { - auto array = xla::LiteralUtil::CreateR1(v); - return array.ToProto(); -} - -xla::LiteralProto FloatMatrix( - std::initializer_list> v, - const xla::Layout& layout) { - auto array = xla::LiteralUtil::CreateR2WithLayout(v, layout); - return array.ToProto(); -} - -xla::Literal ReadOutputLiteral(const std::vector& outputs, size_t idx) { - xla::LiteralProto response; - CHECK(ParseFromTString(outputs[idx].scalar()(), &response)); - return xla::Literal::CreateFromProto(response).value(); -} - -bool CompareLiteralProtos(const xla::LiteralProto& a, - const xla::LiteralProto& b) { - auto l_a = xla::Literal::CreateFromProto(a).value(); - auto l_b = xla::Literal::CreateFromProto(b).value(); - bool equal = l_a == l_b; - if (!equal) { - LOG(INFO) << "LiteralProtos don't match:\n" - << a.DebugString() << "\n!=\n" - << b.DebugString(); - } - return equal; -} - -bool CompareLiteralToLiteralProto(const xla::Literal& a, - const xla::LiteralProto& b) { - auto l_b = xla::Literal::CreateFromProto(b).value(); - bool equal = a == l_b; - if (!equal) { - LOG(INFO) << "Literal and LiteralProto don't match:\n" - << a.ToProto().DebugString() << "\n!=\n" - << b.DebugString(); - 
} - return equal; -} - -bool CompareLiterals(const xla::Literal& a, const xla::Literal& b) { - bool equal = a == b; - if (!equal) { - LOG(INFO) << "Literals don't match:\n" - << a.ToProto().DebugString() << "\n!=\n" - << b.ToProto().DebugString(); - } - return equal; -} - -xla::XlaComputation OnePlusTwo() { - xla::XlaBuilder builder("OnePlusTwo"); - auto c0 = xla::ConstantR0(&builder, 1.0f); - auto c1 = xla::ConstantR0(&builder, 2.0f); - xla::Add(c0, c1); - return builder.Build().value(); -} - -xla::XlaComputation AddAndScale() { - xla::XlaBuilder builder("AddAndScale"); - auto p0 = xla::Parameter(&builder, 0, - xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0"); - auto p1 = xla::Parameter(&builder, 1, - xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1"); - auto sum = xla::Add(p0, p1); - auto c = xla::ConstantR0(&builder, 3.0f); - xla::Mul(sum, c); - return builder.Build().value(); -} - -xla::XlaComputation SubAndScale() { - xla::XlaBuilder builder("SubAndScale"); - auto p0 = xla::Parameter(&builder, 0, - xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0"); - auto p1 = xla::Parameter(&builder, 1, - xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1"); - auto sum = xla::Sub(p0, p1); - auto c = xla::ConstantR0(&builder, 11.0f); - xla::Mul(sum, c); - return builder.Build().value(); -} - -xla::XlaComputation Dot() { - xla::XlaBuilder builder("Dot"); - auto p0 = xla::Parameter( - &builder, 0, - xla::ShapeUtil::MakeShapeWithDenseLayout(xla::F32, {2, 2}, {0, 1}), "P0"); - auto p1 = xla::Parameter( - &builder, 1, - xla::ShapeUtil::MakeShapeWithDenseLayout(xla::F32, {2, 1}, {0, 1}), "P1"); - xla::DotDimensionNumbers ddn; - ddn.add_lhs_contracting_dimensions(1); - ddn.add_rhs_contracting_dimensions(0); - xla::DotGeneral(p0, p1, ddn); - return builder.Build().value(); -} - -xla::XlaComputation AddS64() { - xla::XlaBuilder builder("AddS64"); - auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::S64, {}), - "P0"); - auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::S64, {}), - "P1"); - xla::Add(p0, p1); - return builder.Build().value(); -} - -xla::XlaComputation AddAndTuple() { - xla::XlaBuilder builder("AddAndTuple"); - auto p0 = xla::Parameter(&builder, 0, - xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0"); - auto p1 = xla::Parameter(&builder, 1, - xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1"); - auto sum = xla::Add(p0, p1); - xla::Tuple(&builder, {sum}); - return builder.Build().value(); -} - -xla::XlaComputation AddAndSubTuple() { - xla::XlaBuilder builder("AddAndSubTuple"); - auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}), - "P0"); - auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {}), - "P1"); - auto sum = xla::Add(p0, p1); - auto sub = xla::Sub(p0, p1); - xla::Tuple(&builder, {sum, sub}); - return builder.Build().value(); -} - -xla::XlaComputation BroadcastComputation(const xla::Shape& shape, - absl::Span dimensions) { - xla::XlaBuilder builder("BroadcastComputation"); - auto p0 = xla::Parameter(&builder, 0, shape, "P0"); - xla::Broadcast(p0, dimensions); - return builder.Build().value(); -} - -xla::XlaComputation IsEqualComputation(const xla::Shape& shape) { - xla::XlaBuilder builder("IsEqualComputation"); - auto p0 = xla::Parameter(&builder, 0, shape, "P0"); - auto p1 = xla::Parameter(&builder, 1, shape, "P1"); - auto cmp = - xla::Ne(xla::Sub(p0, p1), xla::Zero(&builder, shape.element_type())); - auto icmp = xla::ConvertElementType(cmp, xla::S32); - xla::ReduceAll(icmp, xla::Zero(&builder, xla::S32), - 
xla::CreateScalarAddComputation(xla::S32, &builder)); - return builder.Build().value(); -} - -void StoreComputationSnapshot(const xla::XlaComputation& computation, - xla::HloSnapshot* dst) { - auto snapshot = computation.Snapshot().value(); - *dst = *snapshot; -} - -xla::ProgramShape XlaCompiledProgramShape( - const xla::XlaComputation& computation, - const xla::ProgramShape& input_program_shape) { - se::Platform* platform = - xla::PlatformUtil::GetPlatform(*xla_platform_ptr).value(); - xla::LocalClient* client = - xla::ClientLibrary::GetOrCreateLocalClient(platform).value(); - xla::ExecutableBuildOptions exec_options; - exec_options.set_result_layout(input_program_shape.result()); - std::vector parameters_shapes; - for (int64_t i = 0; i < input_program_shape.parameters_size(); ++i) { - parameters_shapes.push_back(&input_program_shape.parameters(i)); - } - std::vector> local_executables = - client->Compile(computation, parameters_shapes, exec_options).value(); - EXPECT_EQ(local_executables.size(), 1); - std::unique_ptr local_executable = - std::move(local_executables[0]); - return local_executable->executable() - ->module() - .entry_computation() - ->ComputeProgramShape(); -} - -TEST(RawApiTest, AllocFromTensor) { - xla::Literal literal = - xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); - Tensor tensor; - TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor)); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - std::vector layout = - GetAttrLayout(literal.shape().layout().minor_to_major()); - ops::XRTAllocateFromTensor::Attrs alloc_attrs = - ops::XRTAllocateFromTensor::Layouts(layout); - auto handle = - ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs); - auto read_back = ops::XRTReadLiteralAndRelease(root, handle); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); -} - -TEST(RawApiTest, AllocUninitialized) { - xla::Literal literal = - xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); - Tensor tensor; - TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor)); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - std::vector layout = - GetAttrLayout(literal.shape().layout().minor_to_major()); - - auto allocate_op = - ops::XRTAllocateUninitialized(root, DT_FLOAT, tensor.shape()); - - Tensor handle; - std::vector outputs; - XrtClientSession session(root); - // Allocate the tensor - { - TF_EXPECT_OK(session.Run({allocate_op}, &outputs)); - handle = outputs[0]; - } - - // Make sure it has the expected shape - { - auto read_back_op = ops::XRTReadLiteral(root, handle); - TF_ASSERT_OK(root.status()); - - TF_EXPECT_OK(session.Run({read_back_op}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - xla::LiteralProto read_back_literal; - EXPECT_TRUE( - ParseFromTString(outputs[0].scalar()(), &read_back_literal)); - Tensor read_back_tensor; - TF_ASSERT_OK(LiteralToHostTensor( - xla::Literal::CreateFromProto(read_back_literal).value(), DT_FLOAT, - &read_back_tensor)); - - // The shape should be the same as 'tensor', but we don't have any - // expectation about the value of the tensors yet since it is uninitialized - EXPECT_EQ(tensor.shape(), read_back_tensor.shape()); - } - - // Make sure we can write to it - xla::LiteralProto 
new_literal = - xla::LiteralUtil::CreateR2({{9.0f, 2.0f}, {4.0f, 1.0f}}).ToProto(); - { - auto new_value = ops::Const(root.WithDevice("/device:CPU:0"), - new_literal.SerializeAsString()); - auto write_op = ops::XRTWriteLiteral(root, Input(handle), new_value); - TF_ASSERT_OK(root.status()); - TF_EXPECT_OK(session.Run({write_op}, &outputs)); - } - - // Now read it back - { - auto read_back_op = ops::XRTReadLiteralAndRelease(root, handle); - TF_ASSERT_OK(root.status()); - TF_EXPECT_OK(session.Run({read_back_op}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - EXPECT_TRUE(CompareLiteralProtos(response, new_literal)); - } -} - -TEST(RawApiTest, AllocFromTensorTuple) { - xla::Literal literal0 = - xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); - xla::Literal literal1 = - xla::LiteralUtil::CreateR2({{14.0f, -5.0f}, {16.0f, 17.0f}}); - xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1}); - Tensor tensor0; - TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0)); - Tensor tensor1; - TF_ASSERT_OK(LiteralToHostTensor(literal1, DT_FLOAT, &tensor1)); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - std::vector layout = GetShapeLayoutVector(literal.shape()).value(); - ops::XRTAllocateFromTensor::Attrs alloc_attrs = - ops::XRTAllocateFromTensor::Layouts(layout); - auto handle = ops::XRTAllocateFromTensor(root, {tensor0, tensor1}, - {tensor0.shape(), tensor1.shape()}, - alloc_attrs); - auto read_back = ops::XRTReadLiteralAndRelease(root, handle); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); -} - -TEST(RawApiTest, AllocFromTensorTupleSingle) { - xla::Literal literal0 = - xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); - xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0}); - Tensor tensor0; - TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0)); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - std::vector layout = GetShapeLayoutVector(literal.shape()).value(); - ops::XRTAllocateFromTensor::Attrs alloc_attrs = - ops::XRTAllocateFromTensor::Layouts(layout).MakeTuple(true); - auto handle = ops::XRTAllocateFromTensor(root, {tensor0}, {tensor0.shape()}, - alloc_attrs); - auto read_back = ops::XRTReadLiteralAndRelease(root, handle); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); -} - -TEST(RawApiTest, AllocFromTensorRelayout) { - xla::Literal literal = - xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); - Tensor tensor; - TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor)); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - // Use inverse array layout with the tensor data above. 
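// The layout attribute built below is a minor-to-major ordering: for this
// 2x2 array, {1, 0} would be the usual row-major order, while {0, 1} makes
// dimension 0 the fastest-varying one (column-major). A rough sketch of the
// index mapping, in plain C++ with made-up names, just to illustrate:
//
//   // Linear offset of element (i, j) in a 2x2 buffer laid out as {0, 1}.
//   int offset_minor_to_major_01(int i, int j) { return i + 2 * j; }
//
// Interpreting the row-major host data {4, 5, 6, 7} through that mapping
// gives the logical array {{4, 6}, {5, 7}}, which is the expected_literal
// this test checks after reading the allocation back.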
- std::vector layout({0, 1}); - ops::XRTAllocateFromTensor::Attrs alloc_attrs = - ops::XRTAllocateFromTensor::Layouts(layout); - auto handle = - ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs); - auto read_back = ops::XRTReadLiteralAndRelease(root, handle); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - // We have sent literal's data (in array layout) with a attribute layout - // {0,1}, so the expected literal read from device needs to be changed - // accordingly. - xla::Literal expected_literal = - xla::LiteralUtil::CreateR2({{4.0f, 6.0f}, {5.0f, 7.0f}}); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected_literal, response)); -} - -TEST(RawApiTest, AllocAndRewrite) { - xrt::XLAAllocation alloc; - *alloc.mutable_value() = - xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto value = - ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); - auto handle = ops::XRTAllocate(root, value); - auto read_back = ops::XRTReadLiteral(root, handle); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back, handle}, &outputs)); - EXPECT_EQ(outputs.size(), 2); - - int64_t allocation_handle = outputs[1].scalar()(); - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); - - xla::LiteralProto new_literal = - xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto(); - auto new_value = ops::Const(root.WithDevice("/device:CPU:0"), - new_literal.SerializeAsString()); - auto write_op = - ops::XRTWriteLiteral(root, Input(allocation_handle), new_value); - TF_ASSERT_OK(root.status()); - TF_EXPECT_OK(session.Run({write_op}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - EXPECT_EQ(allocation_handle, outputs[0].scalar()()); - - auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle)); - TF_EXPECT_OK(session.Run({read_after_write}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - - xla::LiteralProto new_response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &new_response)); - EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response)); - - Tensor release_tensor(DT_INT64, TensorShape({1})); - release_tensor.flat()(0) = allocation_handle; - - auto release = ops::XRTReleaseAllocationHandle(root, release_tensor); - TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs)); -} - -TEST(RawApiTest, AllocReleaseMany) { - xrt::XLAAllocation alloc1; - *alloc1.mutable_value() = - xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); - xrt::XLAAllocation alloc2; - *alloc2.mutable_value() = - xla::LiteralUtil::CreateR2({{6, 7}, {4, 5}}).ToProto(); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto value1 = - ops::Const(root.WithDevice("/device:CPU:0"), alloc1.SerializeAsString()); - auto value2 = - ops::Const(root.WithDevice("/device:CPU:0"), alloc2.SerializeAsString()); - auto handle1 = ops::XRTAllocate(root, value1); - auto handle2 = ops::XRTAllocate(root, value2); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({handle1, handle2}, &outputs)); - EXPECT_EQ(outputs.size(), 2); 
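// Both release ops used in this file take a DT_INT64 tensor of handles, so
// several allocations (or, in the next test, several compilations) can be
// released by a single XRTReleaseAllocationHandle / XRTReleaseCompilationHandle
// node; the TensorShape({2}) release tensor assembled below does exactly that.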
- - int64_t allocation_handle1 = outputs[0].scalar()(); - int64_t allocation_handle2 = outputs[1].scalar()(); - - Tensor release_tensor(DT_INT64, TensorShape({2})); - release_tensor.flat()(0) = allocation_handle1; - release_tensor.flat()(1) = allocation_handle2; - - auto release = ops::XRTReleaseAllocationHandle(root, release_tensor); - TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs)); -} - -TEST(RawApiTest, CompileAndReleaseMany) { - xrt::XLAComputation c1; - auto config1 = c1.mutable_config(); - auto shapes1 = config1->mutable_program_shape(); - *shapes1->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes1->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes1->mutable_result() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - StoreComputationSnapshot(AddAndScale(), c1.mutable_hlo_snapshot()); - - xrt::XLAComputation c2; - auto config2 = c2.mutable_config(); - auto shapes2 = config2->mutable_program_shape(); - *shapes2->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes2->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes2->mutable_result() = - xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) - .ToProto(); - StoreComputationSnapshot(AddAndTuple(), c2.mutable_hlo_snapshot()); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto computation1 = - ops::Const(root.WithDevice("/device:CPU:0"), c1.SerializeAsString()); - auto c_handle1 = ops::XRTCompile(root, computation1); - auto computation2 = - ops::Const(root.WithDevice("/device:CPU:0"), c2.SerializeAsString()); - auto c_handle2 = ops::XRTCompile(root, computation2); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({c_handle1.handle, c_handle2.handle}, &outputs)); - EXPECT_EQ(outputs.size(), 2); - - int64_t compilation_handle1 = outputs[0].scalar()(); - int64_t compilation_handle2 = outputs[1].scalar()(); - - Tensor release_tensor(DT_INT64, TensorShape({2})); - release_tensor.flat()(0) = compilation_handle1; - release_tensor.flat()(1) = compilation_handle2; - - auto release = ops::XRTReleaseCompilationHandle(root, release_tensor); - TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs)); -} - -TEST(RawApiTest, AllocAndClearAll) { - xrt::XLAAllocation alloc; - *alloc.mutable_value() = - xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto value = - ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); - auto handle = ops::XRTAllocate(root, value); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({handle}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - - int64_t allocation_handle = outputs[0].scalar()(); - - auto clear_all = ops::XRTReleaseAllAllocations(root); - - TF_EXPECT_OK( - session.Run(ClientSession::FeedType(), {}, {clear_all}, &outputs)); - EXPECT_EQ(outputs.size(), 0); - - auto read_after_clear = ops::XRTReadLiteral(root, Input(allocation_handle)); - EXPECT_EQ(session.Run({read_after_clear}, &outputs).code(), - error::Code::NOT_FOUND); -} - -TEST(RawApiTest, ReadAndWriteState) { - xrt::XLAAllocation alloc; - *alloc.mutable_value() = TwoElementTuple(); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto value = - 
ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); - auto handle = ops::XRTAllocate(root, value); - auto read_back = ops::XRTReadLiteral(root, handle); - auto release = ops::XRTReleaseAllocationHandle( - root.WithControlDependencies(read_back), handle); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK( - session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - - EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); -} - -TEST(RawApiTest, ReadAndWriteStateAutoFree) { - xrt::XLAAllocation alloc; - *alloc.mutable_value() = TwoElementTuple(); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto value = - ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); - auto handle = ops::XRTAllocate(root, value); - auto read_back = ops::XRTReadLiteralAndRelease(root, handle); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); -} - -TEST(RawApiTest, SubBuffer) { - xrt::XLAAllocation alloc; - *alloc.mutable_value() = NestedTuple(); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto value = - ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); - auto base_handle = ops::XRTAllocate(root, value); - auto index_0 = ops::Const(root.WithDevice("/device:CPU:0"), {0}); - auto index_1 = ops::Const(root.WithDevice("/device:CPU:0"), {1}); - auto index_00 = ops::Const(root.WithDevice("/device:CPU:0"), {0, 0}); - auto sub_0 = ops::XRTSubTuple(root, base_handle, index_0); - auto sub_1 = ops::XRTSubTuple(root, base_handle, index_1); - auto sub_00 = ops::XRTSubTupleAndRelease( - root.WithControlDependencies( - {sub_0.output_handle.op(), sub_1.output_handle.op()}), - base_handle, index_00); - auto value_0 = ops::XRTReadLiteralAndRelease(root, sub_0); - auto value_1 = ops::XRTReadLiteralAndRelease(root, sub_1); - auto value_00 = ops::XRTReadLiteralAndRelease(root, sub_00); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs)); - - auto base_literal = xla::Literal::CreateFromProto(alloc.value()).value(); - auto base_elements = base_literal.DecomposeTuple(); - auto nested_0_elements = base_elements[0].Clone().DecomposeTuple(); - xla::LiteralProto response_0; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response_0)); - EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[0], response_0)); - xla::LiteralProto response_1; - EXPECT_TRUE(ParseFromTString(outputs[1].scalar()(), &response_1)); - EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[1], response_1)); - xla::LiteralProto response_00; - EXPECT_TRUE(ParseFromTString(outputs[2].scalar()(), &response_00)); - EXPECT_TRUE(CompareLiteralToLiteralProto(nested_0_elements[0], response_00)); -} - -TEST(RawApiTest, MakeTuple) { - xrt::XLAAllocation alloc_0; - *alloc_0.mutable_value() = TwoElementTuple(); - xrt::XLAAllocation alloc_1; - *alloc_1.mutable_value() = ScalarLiteral(); - - // The trivial tuple that just forwards its input and releases it. 
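// Reading the three descriptors built below: an XLATupleNode either names an
// input_index (a leaf referring to one of the handles passed to XRTMakeTuple)
// or carries child "tuples" nodes, and release_input_handle frees the referenced
// input once the new tuple has been assembled. Under that reading, which is
// inferred from this test rather than from separate documentation:
//   desc_0  -> forwards input 0 unchanged (and releases it),
//   desc_1  -> builds (input0, (input0, input1)),
//   desc_2  -> builds (input1, input0), releasing both inputs.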
- xrt::XLATupleNode desc_0; - desc_0.set_input_index(0); - desc_0.set_release_input_handle(true); - - xrt::XLATupleNode desc_1; - auto subdesc_10 = desc_1.add_tuples(); - auto subdesc_11 = desc_1.add_tuples(); - subdesc_10->set_input_index(0); - auto subdesc_110 = subdesc_11->add_tuples(); - subdesc_110->set_input_index(0); - auto subdesc_111 = subdesc_11->add_tuples(); - subdesc_111->set_input_index(1); - - xrt::XLATupleNode desc_2; - auto subdesc_20 = desc_2.add_tuples(); - auto subdesc_21 = desc_2.add_tuples(); - subdesc_20->set_input_index(1); - subdesc_20->set_release_input_handle(true); - subdesc_21->set_input_index(0); - subdesc_21->set_release_input_handle(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto value_0 = - ops::Const(root.WithDevice("/device:CPU:0"), alloc_0.SerializeAsString()); - auto handle_0 = ops::XRTAllocate(root, value_0); - auto value_1 = - ops::Const(root.WithDevice("/device:CPU:0"), alloc_1.SerializeAsString()); - auto handle_1 = ops::XRTAllocate(root, value_1); - auto tuple_0 = - ops::Const(root.WithDevice("/device:CPU:0"), desc_0.SerializeAsString()); - auto handle_2 = - ops::XRTMakeTuple(root, tuple_0, {static_cast(handle_0)}); - // handle_0 has now been released. - auto tuple_1 = - ops::Const(root.WithDevice("/device:CPU:0"), desc_1.SerializeAsString()); - auto handle_3 = ops::XRTMakeTuple( - root, tuple_1, - {static_cast(handle_1), static_cast(handle_2)}); - auto tuple_2 = - ops::Const(root.WithDevice("/device:CPU:0"), desc_2.SerializeAsString()); - // Make sure this runs after handle_3 has completed, since it will free - // handle_1 and handle_2. - auto handle_4 = ops::XRTMakeTuple( - root.WithControlDependencies(handle_3), tuple_2, - {static_cast(handle_1), static_cast(handle_2)}); - // handle_1 and handle_2 have now been released. 
- - auto res_0 = ops::XRTReadLiteralAndRelease(root, handle_3); - auto res_1 = ops::XRTReadLiteralAndRelease(root, handle_4); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs)); - xla::LiteralProto response_0; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response_0)); - xla::LiteralProto response_1; - EXPECT_TRUE(ParseFromTString(outputs[1].scalar()(), &response_1)); - - auto expected_0 = MakeTuple0(); - EXPECT_TRUE(CompareLiteralProtos(response_0, expected_0)); - auto expected_1 = NestedTuple(); - EXPECT_TRUE(CompareLiteralProtos(response_1, expected_1)); -} - -TEST(RawApiTest, ExecuteChainedOpByOp) { - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - - auto make_computation = [](const std::function& fn) { - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes->mutable_result() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - StoreComputationSnapshot(fn(), c.mutable_hlo_snapshot()); - return c.SerializeAsString(); - }; - - auto c_add_scale = make_computation(AddAndScale); - auto c_sub_scale = make_computation(SubAndScale); - - auto c_add_scale_op = ops::XRTCompile( - root, ops::Const(root.WithDevice("/device:CPU:0"), c_add_scale)); - auto c_sub_scale_op = ops::XRTCompile( - root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale)); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK( - session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs)); - EXPECT_EQ(outputs.size(), 2); - - int64_t c_add_scale_handle = outputs[0].scalar()(); - int64_t c_sub_scale_handle = outputs[1].scalar()(); - - xrt::XLAAllocation p0; - *p0.mutable_value() = FloatVector({1.0f, 2.0f}); - xrt::XLAAllocation p1; - *p1.mutable_value() = FloatVector({8.0f, 5.0f}); - - auto p0_handle = ops::XRTAllocate( - root, - ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString())); - auto p1_handle = ops::XRTAllocate( - root, - ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString())); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(false); - e.set_release_compilation_handle(false); - auto e_config = - ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); - auto result0 = ops::XRTExecute(root, Input(c_add_scale_handle), e_config, - {Output(p0_handle), Output(p1_handle)}); - auto result1 = ops::XRTExecute(root, Input(c_sub_scale_handle), e_config, - {Output(p0_handle), Output(p1_handle)}); - auto result = ops::XRTExecute(root, Input(c_add_scale_handle), e_config, - {result0.output_handle, result1.output_handle}); - auto read_back = ops::XRTReadLiteralAndRelease(root, result); - TF_ASSERT_OK(root.status()); - - TF_EXPECT_OK(session.Run({read_back}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - - auto expected = xla::LiteralUtil::CreateR1({-150.0f, -36.0f}); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); -} - -TEST(RawApiTest, ExecuteChained) { - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - - auto make_computation = [](const std::function& fn) { - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = 
config->mutable_program_shape(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes->mutable_result() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - StoreComputationSnapshot(fn(), c.mutable_hlo_snapshot()); - return c.SerializeAsString(); - }; - - auto c_add_scale = make_computation(AddAndScale); - auto c_sub_scale = make_computation(SubAndScale); - - auto c_add_scale_op = ops::XRTCompile( - root, ops::Const(root.WithDevice("/device:CPU:0"), c_add_scale)); - auto c_sub_scale_op = ops::XRTCompile( - root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale)); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK( - session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs)); - EXPECT_EQ(outputs.size(), 2); - - int64_t c_add_scale_handle = outputs[0].scalar()(); - int64_t c_sub_scale_handle = outputs[1].scalar()(); - - xrt::XLAAllocation p0; - *p0.mutable_value() = FloatVector({1.0f, 2.0f}); - xrt::XLAAllocation p1; - *p1.mutable_value() = FloatVector({8.0f, 5.0f}); - - auto p0_handle_op = ops::XRTAllocate( - root, - ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString())); - auto p1_handle_op = ops::XRTAllocate( - root, - ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString())); - - TF_EXPECT_OK(session.Run({p0_handle_op, p1_handle_op}, &outputs)); - EXPECT_EQ(outputs.size(), 2); - - int64_t p0_handle = outputs[0].scalar()(); - int64_t p1_handle = outputs[1].scalar()(); - - xrt::XRTChainedExecuteConfig config; - auto config_const = - ops::Const(root.WithDevice("/device:CPU:0"), config.SerializeAsString()); - - xrt::XRTChainedExecutePlan plan; - xrt::XRTChainedExecuteOp* op; - xrt::XRTChainedExecuteOp::Input* input; - xrt::XRTChainedExecuteOp::Output* output; - - // Index 0 - op = plan.add_ops(); - op->set_data_handle(p0_handle); - - // Index 1 - op = plan.add_ops(); - op->set_data_handle(p1_handle); - - // Index 2 - op = plan.add_ops(); - op->set_computation_handle(c_add_scale_handle); - input = op->add_inputs(); - input->set_op_index(0); - input = op->add_inputs(); - input->set_op_index(1); - - // Index 3 - op = plan.add_ops(); - op->set_computation_handle(c_sub_scale_handle); - input = op->add_inputs(); - input->set_op_index(0); - input = op->add_inputs(); - input->set_op_index(1); - - // Index 4 - op = plan.add_ops(); - op->set_computation_handle(c_add_scale_handle); - input = op->add_inputs(); - input->set_op_index(2); - input = op->add_inputs(); - input->set_op_index(3); - output = op->add_outputs(); - output->set_result_index(0); - - auto plan_const = - ops::Const(root.WithDevice("/device:CPU:0"), plan.SerializeAsString()); - auto result = ops::XRTExecuteChained(root, plan_const, config_const); - TF_ASSERT_OK(root.status()); - - TF_EXPECT_OK(session.Run({result}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - - auto handles_vec = outputs[0].vec(); - EXPECT_EQ(handles_vec.size(), 1); - - auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(0))); - TF_ASSERT_OK(root.status()); - - TF_EXPECT_OK(session.Run({read_back}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - - auto expected = xla::LiteralUtil::CreateR1({-150.0f, -36.0f}); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); -} - -TEST(RawApiTest, CompileAndExecute) { - 
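// The expected output checked below follows from the AddAndScale() helper
// defined near the top of this file, which computes (P0 + P1) * 3. A host-side
// sanity check of the same arithmetic, kept inside a comment since it is
// illustrative only (plain standard C++, needs <array>, names made up):
//
//   std::array<float, 2> p0{1.0f, 2.0f}, p1{8.0f, 5.0f}, out{};
//   for (int i = 0; i < 2; ++i) out[i] = (p0[i] + p1[i]) * 3.0f;
//   // out == {27.0f, 21.0f}, matching the literal compared against read_back.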
xrt::XLAAllocation p0; - *p0.mutable_value() = FloatVector({1.0f, 2.0f}); - xrt::XLAAllocation p1; - *p1.mutable_value() = FloatVector({8.0f, 5.0f}); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes->mutable_result() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto e_config = - ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); - auto computation = - ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto p0_value = - ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); - auto p0_handle = ops::XRTAllocate(root, p0_value); - auto p1_value = - ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); - auto p1_handle = ops::XRTAllocate(root, p1_value); - auto result = ops::XRTExecute(root, c_handle.handle, e_config, - {Output(p0_handle), Output(p1_handle)}); - auto read_back = ops::XRTReadLiteralAndRelease(root, result); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - - auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); - - xla::ProgramShapeProto program_shape; - EXPECT_TRUE(ParseFromTString(outputs[1].vec()(0), &program_shape)); - EXPECT_EQ(program_shape.parameters_size(), 2); -} - -TEST(RawApiTest, DynamicR1Test) { - xrt::XLAAllocation p0; - *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f}); - xrt::XLAAllocation p1; - *p1.mutable_value() = FloatVector({1.0f, -1.0f, 2.5f, 1.17f}); - xrt::XLAAllocation p2; - *p2.mutable_value() = CreateR0(2); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); - xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); - dyn_shape.set_dynamic_dimension(0, true); - *shapes->mutable_result() = dyn_shape.ToProto(); - StoreComputationSnapshot(ReturnDynamicR1(), c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - Scope cpu_root = root.WithDevice("/device:CPU:0"); - auto e_config = ops::Const(cpu_root, e.SerializeAsString()); - auto computation = ops::Const(cpu_root, c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); - auto p0_handle = ops::XRTAllocate(root, p0_value); - auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); - auto p1_handle = ops::XRTAllocate(root, p1_value); - auto p2_value = 
ops::Const(cpu_root, p2.SerializeAsString()); - auto p2_handle = ops::XRTAllocate(root, p2_value); - auto result = ops::XRTExecute( - root, c_handle.handle, e_config, - {Output(p0_handle), Output(p1_handle), Output(p2_handle)}); - auto read_back = ops::XRTReadLiteralAndRelease(root, result); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); - auto expected = xla::LiteralUtil::CreateR1({2.0f, 1.0f}); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); -} - -TEST(RawApiTest, DynamicR2Test) { - xrt::XLAAllocation p0; - *p0.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, 2.0f, 0.5f, -1.0f}, - {1.5f, 2.5f, 3.0f, -2.0f}}) - .ToProto(); - xrt::XLAAllocation p1; - *p1.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, -1.0f, 2.5f, 1.17f}, - {1.2f, -1.6f, 2.8f, 1.24f}}) - .ToProto(); - xrt::XLAAllocation p2; - *p2.mutable_value() = CreateR0(2); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); - xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4}); - dyn_shape.set_dynamic_dimension(0, true); - dyn_shape.set_dynamic_dimension(1, true); - *shapes->mutable_result() = dyn_shape.ToProto(); - StoreComputationSnapshot(ReturnDynamicR2(), c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - Scope cpu_root = root.WithDevice("/device:CPU:0"); - auto e_config = ops::Const(cpu_root, e.SerializeAsString()); - auto computation = ops::Const(cpu_root, c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); - auto p0_handle = ops::XRTAllocate(root, p0_value); - auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); - auto p1_handle = ops::XRTAllocate(root, p1_value); - auto p2_value = ops::Const(cpu_root, p2.SerializeAsString()); - auto p2_handle = ops::XRTAllocate(root, p2_value); - auto result = ops::XRTExecute( - root, c_handle.handle, e_config, - {Output(p0_handle), Output(p1_handle), Output(p2_handle)}); - auto read_back = ops::XRTReadLiteralAndRelease(root, result); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); - auto expected = xla::LiteralUtil::CreateR2({{2.0f, 1.0f}, {2.7, 0.9}}); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); -} - -TEST(RawApiTest, DynamicR1TupleTest) { - xrt::XLAAllocation p0; - *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f}); - xrt::XLAAllocation p1; - *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f, 1.0f}); - xrt::XLAAllocation p2; - *p2.mutable_value() = CreateR0(2); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = - 
xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); - xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); - dyn_shape.set_dynamic_dimension(0, true); - *shapes->mutable_result() = - xla::ShapeUtil::MakeTupleShape( - {dyn_shape, xla::ShapeUtil::MakeShape(xla::F32, {4}), dyn_shape}) - .ToProto(); - StoreComputationSnapshot(ReturnDynamicR1Tuple(), c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - Scope cpu_root = root.WithDevice("/device:CPU:0"); - auto e_config = ops::Const(cpu_root, e.SerializeAsString()); - auto computation = ops::Const(cpu_root, c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); - auto p0_handle = ops::XRTAllocate(root, p0_value); - auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); - auto p1_handle = ops::XRTAllocate(root, p1_value); - auto p2_value = ops::Const(cpu_root, p2.SerializeAsString()); - auto p2_handle = ops::XRTAllocate(root, p2_value); - auto result = ops::XRTExecute( - root, c_handle.handle, e_config, - {Output(p0_handle), Output(p1_handle), Output(p2_handle)}); - auto read_back = ops::XRTReadLiteralAndRelease(root, result); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); - - auto expected0 = xla::LiteralUtil::CreateR1({2.0f, 1.0f}); - auto expected1 = xla::LiteralUtil::CreateR1({2.0f, 1.0f, 0.0f, 0.0f}); - auto expected2 = xla::LiteralUtil::CreateR1({0.0f, 3.0f, 1.0f}); - auto expected = - xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2}); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); -} - -TEST(RawApiTest, AcceptDynamicR1TupleTest) { - if (*xla_test_device_ptr == "XLA_CPU" || *xla_test_device_ptr == "XLA_GPU") { - // XLA_CPU and XLA_GPU has shape check set to kCompileTime. 
- return; - } - xrt::XLAAllocation p0; - *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f}); - xrt::XLAAllocation p1; - *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f}); - - xrt::XLATupleNode tuple_desc; - auto subdesc_10 = tuple_desc.add_tuples(); - auto subdesc_11 = tuple_desc.add_tuples(); - subdesc_10->set_input_index(0); - subdesc_10->set_release_input_handle(true); - subdesc_11->set_input_index(1); - subdesc_11->set_release_input_handle(true); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); - dyn_input_shape.set_dynamic_dimension(0, true); - xla::Shape dyn_tuple_shape = - xla::ShapeUtil::MakeTupleShape({dyn_input_shape, dyn_input_shape}); - *shapes->add_parameters() = dyn_tuple_shape.ToProto(); - xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); - dyn_shape.set_dynamic_dimension(0, true); - *shapes->mutable_result() = dyn_shape.ToProto(); - StoreComputationSnapshot(AcceptDynamicR1Tuple(), c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - Scope cpu_root = root.WithDevice("/device:CPU:0"); - auto e_config = ops::Const(cpu_root, e.SerializeAsString()); - auto computation = ops::Const(cpu_root, c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); - auto p0_handle = ops::XRTAllocate(root, p0_value); - auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); - auto p1_handle = ops::XRTAllocate(root, p1_value); - - auto tuple_0 = ops::Const(root.WithDevice("/device:CPU:0"), - tuple_desc.SerializeAsString()); - auto t0_handle = ops::XRTMakeTuple( - root, tuple_0, - {static_cast(p0_handle), static_cast(p1_handle)}); - auto result = ops::XRTExecute(root, c_handle.handle, e_config, - {static_cast(t0_handle)}); - auto read_back = ops::XRTReadLiteralAndRelease(root, result); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); - - auto expected = xla::LiteralUtil::CreateR1({2.0f, 1.0f, 0.0f}); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); -} - -TEST(RawApiTest, AcceptDynamicR1Test) { - if (*xla_test_device_ptr == "XLA_CPU" || *xla_test_device_ptr == "XLA_GPU") { - // XLA_CPU and XLA_GPU has shape check set to kCompileTime. 
- return; - } - xrt::XLAAllocation p0; - *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f}); - xrt::XLAAllocation p1; - *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f}); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); - dyn_input_shape.set_dynamic_dimension(0, true); - *shapes->add_parameters() = dyn_input_shape.ToProto(); - *shapes->add_parameters() = dyn_input_shape.ToProto(); - xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); - dyn_shape.set_dynamic_dimension(0, true); - *shapes->mutable_result() = dyn_shape.ToProto(); - StoreComputationSnapshot(AcceptDynamicR1(), c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - Scope cpu_root = root.WithDevice("/device:CPU:0"); - auto e_config = ops::Const(cpu_root, e.SerializeAsString()); - auto computation = ops::Const(cpu_root, c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); - auto allocate_op_0 = ops::XRTAllocate(root, p0_value); - auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); - auto allocate_op_1 = ops::XRTAllocate(root, p1_value); - auto result = ops::XRTExecute(root, c_handle.handle, e_config, - {Output(allocate_op_0), Output(allocate_op_1)}); - auto read_back = ops::XRTReadLiteralAndRelease(root, result); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); - - auto expected = xla::LiteralUtil::CreateR1({2.0f, 1.0f, 0.0f}); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); -} - -TEST(RawApiTest, AcceptDynamicR2Test) { - xrt::XLAAllocation p0; - *p0.mutable_value() = - xla::LiteralUtil::CreateR2({{-1.0f, 2.0f, 3.0f}, {-4.0f, -5.0f, 6.0f}}) - .ToProto(); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - // Compile time expects ascending layout. 
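// In these dynamic-shape tests, set_dynamic_dimension(...) marks a dimension
// whose logical size is only known at run time; the static extent (4, or
// {2, 4}) acts as the padded upper bound, and reading the result back returns
// only the logical extent. In DynamicR1Test above, p0 + p1 would be
// {2.0, 1.0, 3.0, 0.17}, and with the runtime size parameter set to 2 the
// expected literal is just the first two elements, {2.0, 1.0}. That reading of
// ReturnDynamicR1() is inferred from the test's inputs and expectations; the
// helper itself is defined elsewhere in this file.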
- xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4}); - dyn_shape.set_dynamic_dimension(1, true); - *shapes->add_parameters() = dyn_shape.ToProto(); - - *shapes->mutable_result() = dyn_shape.ToProto(); - StoreComputationSnapshot(AcceptDynamicR2(), c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - Scope cpu_root = root.WithDevice("/device:CPU:0"); - auto e_config = ops::Const(cpu_root, e.SerializeAsString()); - auto computation = ops::Const(cpu_root, c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); - auto p0_handle = ops::XRTAllocate(root, p0_value); - auto result = - ops::XRTExecute(root, c_handle.handle, e_config, {Output(p0_handle)}); - auto read_back = ops::XRTReadLiteralAndRelease(root, result); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); - - auto expected = xla::LiteralUtil::CreateR2( - {{1.0f, -2.0f, -3.0f}, {4.0f, 5.0f, -6.0f}}); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); -} - -TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { - xrt::XLAAllocation p0; - *p0.mutable_value() = FloatVector({1.0f, 2.0f}); - xrt::XLAAllocation p1; - *p1.mutable_value() = FloatVector({8.0f, 5.0f}); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes->mutable_result() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto e_config = - ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); - auto computation = - ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto p0_value = - ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); - auto p0_handle = ops::XRTAllocate(root, p0_value); - auto p1_value = - ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); - auto p1_handle = ops::XRTAllocate(root, p1_value); - auto packed_args = ops::Stack(root.WithDevice("/device:CPU:0"), - {Output(p0_handle), Output(p1_handle)}); - auto result = - ops::XRTExecute(root, c_handle.handle, e_config, {Output(packed_args)}); - auto read_back = ops::XRTReadLiteralAndRelease(root, result); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - - auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); - - xla::ProgramShapeProto program_shape; - EXPECT_TRUE(ParseFromTString(outputs[1].vec()(0), &program_shape)); - 
EXPECT_EQ(program_shape.parameters_size(), 2); -} - -TEST(RawApiTest, CompileWithXlaReturnShapes) { - xla::XlaBuilder builder("XrtXlaShapes"); - auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128}); - auto kernel_shape = xla::ShapeUtil::MakeShape(xla::BF16, {3, 3, 5, 5}); - // Clear layouts to signal XLA we are ready to get whatever are coming out of - // the compilation process. - xla::LayoutUtil::ClearLayout(&input_shape); - xla::LayoutUtil::ClearLayout(&kernel_shape); - auto param_shape = - xla::ShapeUtil::MakeTupleShape({input_shape, kernel_shape}); - auto param = xla::Parameter(&builder, 0, param_shape, "param"); - auto input = xla::GetTupleElement(param, 0); - auto kernel = xla::GetTupleElement(param, 1); - xla::Conv(input, kernel, {1, 1}, xla::Padding::kSame); - TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation xla_computation, builder.Build()); - - auto result_shape = xla_computation.GetProgramShape().value().result(); - // Clear the result shape layout to tell XLA we are accepting whatever are - // coming out of the compilation process. - xla::LayoutUtil::ClearLayout(&result_shape); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = param_shape.ToProto(); - *shapes->mutable_result() = result_shape.ToProto(); - StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot()); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto computation = - ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto release = ops::XRTReleaseCompilationHandle(root, c_handle.handle); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {c_handle.program_shape}, - {release}, &outputs)); - - xla::ProgramShapeProto program_shape_proto; - EXPECT_TRUE( - ParseFromTString(outputs[0].vec()(0), &program_shape_proto)); - xla::ProgramShape program_shape(program_shape_proto); - EXPECT_EQ(program_shape.parameters_size(), 1); - - VLOG(2) << "Param: " - << xla::ShapeUtil::HumanStringWithLayout(program_shape.parameters(0)); - VLOG(2) << "Result: " - << xla::ShapeUtil::HumanStringWithLayout(program_shape.result()); - - xla::ProgramShape xla_program_shape = - XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes)); - EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()( - xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(), - xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0}) - .layout())); - EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()( - xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(), - xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1}) - .layout())); - EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()( - program_shape.result().layout(), xla_program_shape.result().layout())); -} - -TEST(RawApiTest, DotGeneralWithLayoutTest) { - auto layout = xla::LayoutUtil::MakeLayout({0, 1}); - - xrt::XLAAllocation p0; - *p0.mutable_value() = FloatMatrix({{1.0f, 2.0f}, {3.0f, 4.0f}}, layout); - xrt::XLAAllocation p1; - *p1.mutable_value() = FloatMatrix({{8.0f}, {5.0f}}, layout); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShapeWithDenseLayout(xla::F32, {2, 2}, {0, 1}) - .ToProto(); - *shapes->add_parameters() = - 
xla::ShapeUtil::MakeShapeWithDenseLayout(xla::F32, {2, 1}, {0, 1}) - .ToProto(); - *shapes->mutable_result() = - xla::ShapeUtil::MakeShapeWithDenseLayout(xla::F32, {2, 1}, {0, 1}) - .ToProto(); - StoreComputationSnapshot(Dot(), c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto e_config = - ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); - auto computation = - ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto p0_value = - ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); - auto p0_handle = ops::XRTAllocate(root, p0_value); - auto p1_value = - ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); - auto p1_handle = ops::XRTAllocate(root, p1_value); - auto result = ops::XRTExecute(root, c_handle.handle, e_config, - {Output(p0_handle), Output(p1_handle)}); - auto read_back = ops::XRTReadLiteralAndRelease(root, result); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - - auto expected = - xla::LiteralUtil::CreateR2WithLayout({{18.0f}, {44.0f}}, layout); - - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); -} - -TEST(RawApiTest, CompileAndExecuteZeroArg) { - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(true); - StoreComputationSnapshot(OnePlusTwo(), c.mutable_hlo_snapshot()); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto e_config = - ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); - auto computation = - ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto result = ops::XRTExecute(root, c_handle.handle, e_config, - std::initializer_list({})); - auto read_back = ops::XRTReadLiteralAndRelease(root, result); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - - auto expected = xla::LiteralUtil::CreateR0(3.0f); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); -} - -TEST(RawApiTest, CompileAndExecuteReturnTuple) { - xrt::XLAAllocation p0; - *p0.mutable_value() = FloatVector({1.0f, 2.0f}); - xrt::XLAAllocation p1; - *p1.mutable_value() = FloatVector({8.0f, 5.0f}); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes->mutable_result() = - xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) - .ToProto(); - StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(true); - 
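// On the XRTExecutionConfig flags exercised here and in the surrounding tests:
// release_input_handles frees the argument allocations once the execution has
// consumed them, and release_compilation_handle frees the compiled executable
// after this single run (ExecuteChainedOpByOp sets both to false precisely so
// it can keep reusing its handles). return_exploded_tuple, used in the
// following test, makes XRTExecute return one handle per top-level tuple
// element instead of a single tuple handle. The first two readings follow from
// how the tests reuse or drop handles; the last is visible directly in the
// exploded-tuple test's output vector.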
- Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto e_config = - ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); - auto computation = - ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto p0_value = - ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); - auto p0_handle = ops::XRTAllocate(root, p0_value); - auto p1_value = - ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); - auto p1_handle = ops::XRTAllocate(root, p1_value); - auto result = ops::XRTExecute(root, c_handle.handle, e_config, - {Output(p0_handle), Output(p1_handle)}); - auto read_back = ops::XRTReadLiteralAndRelease(root, result); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - - auto sum = xla::LiteralUtil::CreateR1({9.0f, 7.0f}); - auto expected = xla::LiteralUtil::MakeTuple({&sum}); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); -} - -TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) { - xrt::XLAAllocation p0; - *p0.mutable_value() = xla::LiteralUtil::CreateR0(12.0f).ToProto(); - - xrt::XLAAllocation p1; - *p1.mutable_value() = xla::LiteralUtil::CreateR0(3.0f).ToProto(); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); - *shapes->mutable_result() = - xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}), - xla::ShapeUtil::MakeShape(xla::F32, {})}) - .ToProto(); - StoreComputationSnapshot(AddAndSubTuple(), c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(true); - e.set_return_exploded_tuple(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto e_config = - ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); - auto computation = - ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto p0_value = - ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); - auto p0_handle = ops::XRTAllocate(root, p0_value); - auto p1_value = - ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); - auto p1_handle = ops::XRTAllocate(root, p1_value); - auto result = ops::XRTExecute(root, c_handle.handle, e_config, - {Output(p0_handle), Output(p1_handle)}); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({result}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - - auto handles_vec = outputs.front().vec(); - EXPECT_EQ(handles_vec.size(), 2); - - const float kResults[2] = {15.0f, 9.0f}; - for (int64_t i = 0; i < handles_vec.size(); ++i) { - auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(i))); - std::vector voutputs; - TF_EXPECT_OK(session.Run({read_back}, &voutputs)); - EXPECT_EQ(voutputs.size(), 1); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(voutputs[0].scalar()(), &response)); - - auto expected = xla::LiteralUtil::CreateR0(kResults[i]); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); - } 
-} - -TEST(RawApiTest, LeakCompilationReference) { - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes->add_parameters() = - xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); - *shapes->mutable_result() = - xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) - .ToProto(); - StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot()); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto computation = - ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({c_handle.handle}, &outputs)); -} - -TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) { - xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::F32, {2}); - xla::Shape shape = - xla::ShapeUtil::MakeTupleShape({element_shape, element_shape}); - xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape( - {element_shape, element_shape, element_shape, element_shape}); - xla::XlaBuilder builder("ReuseBuffer"); - auto param = xla::Parameter(&builder, 0, shape, "param"); - auto p0 = xla::GetTupleElement(param, 0); - auto p1 = xla::GetTupleElement(param, 1); - auto add = xla::Add(p0, p1); - auto sub = xla::Sub(p0, p1); - xla::Tuple(&builder, {add, sub, p0, p1}); - - // Flip the tuple literals in the input handle. - builder.SetUpAlias({1}, 0, {0}); - builder.SetUpAlias({0}, 0, {1}); - - auto computation = builder.Build().value(); - - auto literal0 = xla::LiteralUtil::CreateR1({1.0f, 2.0f}); - auto literal1 = xla::LiteralUtil::CreateR1({5.0f, 9.0f}); - auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1}); - - xrt::XLAAllocation param_alloc; - *param_alloc.mutable_value() = literal.ToProto(); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = shape.ToProto(); - *shapes->mutable_result() = return_shape.ToProto(); - StoreComputationSnapshot(computation, c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(false); - e.set_release_compilation_handle(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - XrtClientSession session(root); - auto e_config = - ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); - auto c_data = - ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, c_data); - auto param_value = ops::Const(root.WithDevice("/device:CPU:0"), - param_alloc.SerializeAsString()); - auto param_handle = ops::XRTAllocate(root, param_value); - TF_ASSERT_OK(root.status()); - - std::vector outputs; - TF_EXPECT_OK(session.Run({param_handle}, &outputs)); - - int64_t alloc_handle = outputs[0].scalar()(); - - // Note that we release the result handle immediately, but since we aliased - // the output buffers onto the input allocation ones (held in alloc_handle), - // we can fetch the result from there. 
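// The aliasing configured in the builder above is, spelled out,
// SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}) and
// SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}):
// output element 1 (the sub result) reuses the buffer of input element 0, and
// output element 0 (the add result) reuses the buffer of input element 1. With
// inputs {1, 2} and {5, 9}, the add is {6, 11} and the sub is {-4, -7}, so the
// donated input handle should read back as ({-4, -7}, {6, 11}), which is what
// the expected_literal assembled at the end of this test verifies.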
- auto result = - ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)}); - auto read_back = ops::XRTReadLiteral(root, result); - auto release = ops::XRTReleaseAllocationHandle( - root.WithControlDependencies(read_back), result); - TF_ASSERT_OK(root.status()); - - TF_EXPECT_OK( - session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs)); - - xla::Literal exec_literal = ReadOutputLiteral(outputs, 0); - auto exec_literal_parts = exec_literal.DecomposeTuple(); - ASSERT_EQ(exec_literal_parts.size(), 4); - - EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0)); - EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1)); - - // Now we read back the original input handle values, which at this point - // should contain the result of the XLA computation. - auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle)); - TF_ASSERT_OK(root.status()); - auto release_handle = ops::XRTReleaseAllocationHandle( - root.WithControlDependencies(read_handle), Input(alloc_handle)); - TF_ASSERT_OK(root.status()); - - TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle}, - {release_handle}, &outputs)); - - xla::Literal return_literal = ReadOutputLiteral(outputs, 0); - - auto expected_literal0 = xla::LiteralUtil::CreateR1({6.0f, 11.0f}); - auto expected_literal1 = xla::LiteralUtil::CreateR1({-4.0f, -7.0f}); - // The first element of the computation returned tuple would be the add - // (expected_literal0), but since we flipped the buffers, the sub - // (expected_literal1) should come first. - auto expected_literal = - xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0}); - - EXPECT_TRUE(CompareLiterals(return_literal, expected_literal)); -} - -TEST(RawApiTest, CompileAndExecuteWithReusedBuffersS64) { - xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::S64, {2}); - xla::Shape shape = - xla::ShapeUtil::MakeTupleShape({element_shape, element_shape}); - xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape( - {element_shape, element_shape, element_shape, element_shape}); - xla::XlaBuilder builder("ReuseBuffer"); - auto param = xla::Parameter(&builder, 0, shape, "param"); - auto p0 = xla::GetTupleElement(param, 0); - auto p1 = xla::GetTupleElement(param, 1); - auto add = xla::Add(p0, p1); - auto sub = xla::Sub(p0, p1); - xla::Tuple(&builder, {add, sub, p0, p1}); - - // Flip the tuple literals in the input handle. 
- builder.SetUpAlias({1}, 0, {0}); - builder.SetUpAlias({0}, 0, {1}); - - auto computation = builder.Build().value(); - - auto literal0 = xla::LiteralUtil::CreateR1({1, 2}); - auto literal1 = xla::LiteralUtil::CreateR1({5, 9}); - auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1}); - - xrt::XLAAllocation param_alloc; - *param_alloc.mutable_value() = literal.ToProto(); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = shape.ToProto(); - *shapes->mutable_result() = return_shape.ToProto(); - StoreComputationSnapshot(computation, c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(false); - e.set_release_compilation_handle(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - XrtClientSession session(root); - auto e_config = - ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); - auto c_data = - ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, c_data); - auto param_value = ops::Const(root.WithDevice("/device:CPU:0"), - param_alloc.SerializeAsString()); - auto param_handle = ops::XRTAllocate(root, param_value); - TF_ASSERT_OK(root.status()); - - std::vector outputs; - TF_EXPECT_OK(session.Run({param_handle}, &outputs)); - - int64_t alloc_handle = outputs[0].scalar()(); - - // Note that we release the result handle immediately, but since we aliased - // the output buffers onto the input allocation ones (held in alloc_handle), - // we can fetch the result from there. - auto result = - ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)}); - auto read_back = ops::XRTReadLiteral(root, result); - auto release = ops::XRTReleaseAllocationHandle( - root.WithControlDependencies(read_back), result); - TF_ASSERT_OK(root.status()); - - TF_EXPECT_OK( - session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs)); - - xla::Literal exec_literal = ReadOutputLiteral(outputs, 0); - auto exec_literal_parts = exec_literal.DecomposeTuple(); - ASSERT_EQ(exec_literal_parts.size(), 4); - - EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0)); - EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1)); - - // Now we read back the original input handle values, which at this point - // should contain the result of the XLA computation. - auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle)); - TF_ASSERT_OK(root.status()); - auto release_handle = ops::XRTReleaseAllocationHandle( - root.WithControlDependencies(read_handle), Input(alloc_handle)); - TF_ASSERT_OK(root.status()); - - TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle}, - {release_handle}, &outputs)); - - xla::Literal return_literal = ReadOutputLiteral(outputs, 0); - - auto expected_literal0 = xla::LiteralUtil::CreateR1({6, 11}); - auto expected_literal1 = xla::LiteralUtil::CreateR1({-4, -7}); - // The first element of the computation returned tuple would be the add - // (expected_literal0), but since we flipped the buffers, the sub - // (expected_literal1) should come first. 
- auto expected_literal = - xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0}); - - EXPECT_TRUE(CompareLiterals(return_literal, expected_literal)); -} - -TEST(RawApiTest, CompileAndExecuteWithS64Argument) { - xrt::XLAAllocation p0; - *p0.mutable_value() = xla::LiteralUtil::CreateR0(11031965).ToProto(); - xrt::XLAAllocation p1; - *p1.mutable_value() = xla::LiteralUtil::CreateR0(4091934).ToProto(); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto(); - *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto(); - StoreComputationSnapshot(AddS64(), c.mutable_hlo_snapshot()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(true); - e.set_return_exploded_tuple(true); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto e_config = - ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); - auto computation = - ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); - auto c_handle = ops::XRTCompile(root, computation); - auto p0_value = - ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); - auto p0_handle = ops::XRTAllocate(root, p0_value); - auto p1_value = - ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); - auto p1_handle = ops::XRTAllocate(root, p1_value); - auto result = ops::XRTExecute(root, c_handle.handle, e_config, - {Output(p0_handle), Output(p1_handle)}); - auto read_back = ops::XRTReadLiteralAndRelease(root, result); - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - - auto expected = xla::LiteralUtil::CreateR0(15123899); - EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); - - xla::ProgramShapeProto program_shape; - EXPECT_TRUE(ParseFromTString(outputs[1].vec()(0), &program_shape)); - EXPECT_EQ(program_shape.parameters_size(), 2); - EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType( - xla::Shape(program_shape.result()), xla::S64)); -} - -// Tests the XRT device memory compaction API (XRTCompactAllocations). -TEST(RawApiTest, TestDeviceMemoryCompaction) { - static const int kNumAllocs = 32; - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - - std::vector allocs(kNumAllocs); - std::vector handle_outputs; - for (int i = 0; i < kNumAllocs; ++i) { - *allocs[i].mutable_value() = BasedTwoElementTuple(i * 4.0f); - auto value = ops::Const(root.WithDevice("/device:CPU:0"), - allocs[i].SerializeAsString()); - handle_outputs.push_back(ops::XRTAllocate(root, value)); - } - TF_ASSERT_OK(root.status()); - - XrtClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run(handle_outputs, &outputs)); - EXPECT_EQ(outputs.size(), handle_outputs.size()); - - std::vector handles; - for (auto& output : outputs) { - handles.push_back(output.scalar()()); - } - // Create holes by releasing even allocations. 
- std::vector handle_releases; - for (size_t i = 0; i < handles.size(); i += 2) { - handle_releases.push_back( - ops::XRTReleaseAllocationHandle(root, Input(handles[i]))); - } - TF_ASSERT_OK(root.status()); - - TF_EXPECT_OK( - session.Run(ClientSession::FeedType(), {}, handle_releases, &outputs)); - - // Run the compaction API. - auto compact_op = ops::XRTCompactAllocations(root); - TF_EXPECT_OK( - session.Run(ClientSession::FeedType(), {}, {compact_op}, &outputs)); - - // Read back the allocation left at odd indices. - std::vector read_outputs; - for (size_t i = 1; i < handles.size(); i += 2) { - read_outputs.push_back(ops::XRTReadLiteral(root, Input(handles[i]))); - } - TF_ASSERT_OK(root.status()); - - TF_EXPECT_OK(session.Run(read_outputs, &outputs)); - EXPECT_EQ(outputs.size(), read_outputs.size()); - - // Verify that everything got moved correctly and the device data matches what - // we have on record. - for (size_t i = 1, j = 0; i < handles.size(); i += 2, ++j) { - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[j].scalar()(), &response)); - EXPECT_TRUE(CompareLiteralProtos(allocs[i].value(), response)); - } -} - -TEST(RawApiTest, TestDeviceMemorySwap) { - const xla::Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {}); - // 100MB F32 tensor. - const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {5000, 5000}); - const int64_t tensor_size = xla::ShapeUtil::ByteSizeOf(shape); - // On CPU we cannot trigger OOM/swap. For TPU and GPU we select 16GB as - // maximum memory. - int64_t device_memory_size = 8LL * 1024 * 1024 * 1024; - if (*xla_test_device_ptr == "TPU" || *xla_test_device_ptr == "XLA_GPU") { - device_memory_size = 16LL * 1024 * 1024 * 1024; - } - - xrt::XLAAllocation p0; - *p0.mutable_value() = xla::LiteralUtil::CreateR0(0.90434).ToProto(); - - // Create a computation which broadcasts a scalar to a big tensor. - xrt::XLAComputation c_bcast; - { - auto shapes = c_bcast.mutable_config()->mutable_program_shape(); - *shapes->add_parameters() = scalar_shape.ToProto(); - *shapes->mutable_result() = shape.ToProto(); - StoreComputationSnapshot( - BroadcastComputation(scalar_shape, shape.dimensions()), - c_bcast.mutable_hlo_snapshot()); - } - - // Create a computation which compares two tensors. 
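The IsEqualComputation helper used below is defined earlier in the deleted test file, outside this excerpt. As a rough, hypothetical sketch of a computation with the same signature (two tensors of shape in, one s32[] out that is 0 exactly when the inputs match), it could be built along these lines:

  xla::XlaBuilder b("is_equal_sketch");
  auto p0 = xla::Parameter(&b, 0, shape, "p0");
  auto p1 = xla::Parameter(&b, 1, shape, "p1");
  // Count the mismatching elements; the reduction is 0 iff the tensors match.
  auto mismatches = xla::ConvertElementType(xla::Ne(p0, p1), xla::S32);
  xla::ReduceAll(mismatches, xla::Zero(&b, xla::S32),
                 xla::CreateScalarAddComputation(xla::S32, &b));
  auto is_equal = b.Build().value();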
- xrt::XLAComputation c_equal; - { - auto shapes = c_equal.mutable_config()->mutable_program_shape(); - *shapes->add_parameters() = shape.ToProto(); - *shapes->add_parameters() = shape.ToProto(); - *shapes->mutable_result() = - xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); - StoreComputationSnapshot(IsEqualComputation(shape), - c_equal.mutable_hlo_snapshot()); - } - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(false); - e.set_release_compilation_handle(false); - - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - XrtClientSession session(root); - auto e_config = - ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); - auto bcast_computation = - ops::Const(root.WithDevice("/device:CPU:0"), c_bcast.SerializeAsString()); - auto c_bcast_handle = ops::XRTCompile(root, bcast_computation); - auto equal_computation = - ops::Const(root.WithDevice("/device:CPU:0"), c_equal.SerializeAsString()); - auto c_equal_handle = ops::XRTCompile(root, equal_computation); - auto p0_value = - ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); - auto p0_handle = ops::XRTAllocate(root, p0_value); - std::vector<Tensor> outputs; - std::vector<int64_t> device_handles; - - // Create more data than the device can take using the broadcast computation. - int64_t num_tensors = 8 + device_memory_size / tensor_size; - for (int64_t i = 0; i < num_tensors; ++i) { - auto result = ops::XRTExecute(root, c_bcast_handle.handle, e_config, - {Output(p0_handle)}); - TF_ASSERT_OK(root.status()); - TF_ASSERT_OK(session.Run({result}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - device_handles.push_back(outputs[0].scalar<int64_t>()()); - } - - // Trigger computations on XRT handles to verify the swap-out/swap-in logic, - // by comparing sequential pairs of tensors.
- auto zero_literal = xla::LiteralUtil::CreateR0(0); - for (size_t i = 0; i + 1 < device_handles.size(); ++i) { - auto exec_op = ops::XRTExecute( - root, c_equal_handle.handle, e_config, - {Input(device_handles[i]), Input(device_handles[i + 1])}); - auto read_back = ops::XRTReadLiteral(root, exec_op); - - TF_ASSERT_OK(root.status()); - TF_ASSERT_OK(session.Run({read_back}, &outputs)); - EXPECT_EQ(outputs.size(), 1); - - xla::LiteralProto response; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &response)); - auto literal = xla::Literal::CreateFromProto(response).value(); - EXPECT_EQ(literal, zero_literal); - } -} - -TEST(RawApiTest, TestMetricsFetch) { - xrt::XRTMetricsCollect metrics; - metrics.add_metrics_regex("/tensorflow/xrt/.*"); - - Scope root = Scope::NewRootScope().WithDevice("/device:CPU:0"); - auto metrics_value = ops::Const(root, metrics.SerializeAsString()); - Output result = ops::XRTMetricsCollect(root, metrics_value); - TF_ASSERT_OK(root.status()); - - ClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({result}, &outputs)); - ASSERT_EQ(outputs.size(), 1); - - xrt::MetricsReport report; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &report)); - for (auto& metric : report.metrics()) { - EXPECT_EQ(metric.name().compare(0, 16, "/tensorflow/xrt/"), 0); - } -} - -TEST(RawApiTest, TestMemoryInfo) { - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - Output result = ops::XRTMemoryInfo(root); - TF_ASSERT_OK(root.status()); - - ClientSession session(root); - std::vector outputs; - TF_EXPECT_OK(session.Run({result}, &outputs)); - ASSERT_EQ(outputs.size(), 1); - - xrt::MemoryInfo mem_info; - EXPECT_TRUE(ParseFromTString(outputs[0].scalar()(), &mem_info)); - EXPECT_GT(mem_info.kb_total(), 0); - EXPECT_GT(mem_info.kb_free(), 0); -} - -} // namespace - -} // namespace tensorflow - -int main(int argc, char** argv) { - tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU"); - tensorflow::xla_platform_ptr = new tensorflow::string("CPU"); - std::vector flag_list = { - tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr, - "Tensorflow device type to use for test, e.g., XLA_CPU"), - tensorflow::Flag("xla_platform", tensorflow::xla_platform_ptr, - "The XLA platform to select for the device"), - }; - tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_result) { - LOG(ERROR) << "\n" << usage; - return 2; - } - testing::InitGoogleTest(&argc, argv); - if (argc > 1) { - LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; - return 2; - } - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto deleted file mode 100644 index 826ecafc8a9273..00000000000000 --- a/tensorflow/compiler/xrt/xrt.proto +++ /dev/null @@ -1,277 +0,0 @@ -syntax = "proto3"; - -package xrt; - -import "tensorflow/compiler/tf2xla/host_compute_metadata.proto"; -import "xla/service/hlo.proto"; -import "xla/xla.proto"; -import "xla/xla_data.proto"; - -message DeviceAssignment { - message ComputationDevice { - message DeviceMeshCoordinates { - // The mesh coordinates for the device. Usually (X, Y, Z, Core), in the - // order in which they are returned in the TopologyProto. - // X = value(0) - // Y = value(1) - // Z = value(2) - // Core = value(3) - repeated int32 value = 1; - } - // As many replicas as there are in the replicated computation. 
- repeated DeviceMeshCoordinates replica_devices = 1; - } - // As many ComputationDevice as many there are computations (number - // of cores per replica). - repeated ComputationDevice computation_devices = 1; -} - -// Options for an XLA compilation. -message XLAComputationConfig { - // The number of replicas the computation will be run on. If this is - // default (0) it is interpreted as 1. - int32 num_replicas = 1; - // The number of "model-parallel" cores per replica. If this is - // default (0) it is interpreted as 1. - int32 num_cores_per_replica = 2; - // Optional metadata about host sends and recvs. - tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3; - - // The arg/result shapes for the whole computation. - xla.ProgramShapeProto program_shape = 4; - // The arg/result shapes for each core of a model-parallel - // computation. per_core_args_and_result_shapes is optional for a - // single-core computation. - repeated xla.ProgramShapeProto per_core_program_shape = 5; - // Describes how replicated computation instances should be assigned to - // devices. There are num_cores_per_replica computations, and each one will be - // sent and executed to the set of replica device numbers described in the - // DeviceAssignment proto. - DeviceAssignment device_assignment = 6; - // The debugging options to be passed to the XLA compilation process. - xla.DebugOptions debug_options = 7; - - // Everything inside Experimental is subject to change and is not subject - // to API stability guarantees in - // https://www.tensorflow.org/guide/version_compat. - message Experimental { - message UpdateIndexPair { - int32 index = 1; - bool updated = 2; - } - - // stateful_input_indices is only useful when using XRT-compiled - // programs together with standard TensorFlow TPU execution ops, so should - // be ignored by most clients. - // - // Optionally the client can pass information about which inputs - // to the computation are updates to "stateful" quantities. Each - // element of stateful_input_indices includes an index indicating - // which input argument it corresponds to, and a bool indicating - // whether the value is updated or not. If the XRT computation is - // going to be used with a TensorFlow TPU execution op then an - // input index must be present for each input that will correspond - // to a resource variable in the execution op, and may not be - // present for any other input. - repeated UpdateIndexPair stateful_input_indices = 1; - } - - Experimental experimental = 8; -} - -// Options and XLA computation for a compilation. -message XLAComputation { - XLAComputationConfig config = 1; - xla.HloSnapshot hlo_snapshot = 2; -} - -// Literal to allocate space for, and transfer to, device memory. -message XLAAllocation { - reserved 1; - xla.LiteralProto value = 2; -} - -// Node in a tree describing a tuple constructed from input handles. A -// node is an internal node if tuples is non-empty, in which case -// input_index and release_input_handle are ignored. Otherwise a node -// is a leaf node. Each leaf XLATupleNode is the index of an input -// which corresponds to a handle that will be grafted onto the output -// tuple at that location. If release_input_handle is true that input -// handle will be released and become invalid. Inputs may be repeated -// in which case leaves of the output tuple will alias. If an input is -// repeated, release_input_handle must be false for every leaf where -// that input appears. 
-// -// For example, if input 0 has shape {} and input 1 has shape {2,3} -// then the XLATupleNode with structure {1,{0,1}} corresponds to a -// tuple with shape {{2,3},{{},{2,3}}}. -message XLATupleNode { - int32 input_index = 1; - bool release_input_handle = 2; - repeated XLATupleNode tuples = 3; -} - -message CommonExecutionConfig { - // The replica index this execute is driving. - int32 replica_id = 1; - // Mapping local device ordinals to global replica IDs. - // local_replica_mapping[LOCAL_DEVICE_ORDINAL] = GLOBAL_REPLICA_ID - repeated int32 local_replica_mapping = 2; - // The execution run ID used to correlate different XRT execute operations - // happeining in parallel from different threads. - int64 run_id = 3; -} - -// Options for an XLA execution. -message XRTExecutionConfig { - // Local device to run on. This is present because the execute Op - // may be placed on a device such as CPU or TPU_SYSTEM that - // logically manages multiple cores. - int32 device_ordinal = 1; - // Which model-parallel computation to run from the compiled bundle. - int32 core_index_in_replica = 2; - // Optional key to disambiguate between executions. This is only - // needed if multiple host send/recvs may be outstanding - // concurrently with executions. - string execution_instance_key = 3; - // If non-zero, rng_seed to reset the core with. - uint32 rng_seed = 4; - // If true, release allocation handles on the inputs after running. - bool release_input_handles = 5; - // If true, release the handle to the computation after running. - bool release_compilation_handle = 6; - // If set to true, and the result shape is a tuple, then instead of returning - // a single tuple allocation the execution will return a vector of - // allocations, one for each of the first-level elements of the result tuple. - bool return_exploded_tuple = 7; - reserved 8; - // The common configuration for XRT execute operations. - CommonExecutionConfig common_config = 9; -} - -message XRTChainedExecuteConfig { - // If non-zero, rng_seed to reset the core with. - uint32 rng_seed = 1; - // Which model-parallel computation to run from the compiled bundle. - int32 core_index_in_replica = 2; - // Optional key to disambiguate between executions. This is only needed if - // multiple host send/recvs may be outstanding concurrently with executions. - string execution_instance_key = 3; - reserved 4; - // The common configuration for XRT execute operations. - CommonExecutionConfig common_config = 5; -} - -// A single chained execute operation. An operation can either be a device data -// load, or an existing (as in, previously compiled and accessible via its int64 -// handle) XLA computation execution. -message XRTChainedExecuteOp { - // Represents an input for this operation. - message Input { - // The index within the XRTChainedExecutePlan.ops post-order of the source - // operation for this input. - int64 op_index = 1; - // The output index of the value generated by the operation at op_index. - // Zero (default value) means no index ({}) while if an indexing is - // required, output_index needs to be set to index+1. - // Thanks proto3! - int64 output_index = 2; - } - // Represents an output of the XRTChainedExecute operation, which should - // originate by the output of this operation. - message Output { - // The index in the value generated by this operation, which should be - // forwarded as XRTChainedExecute output. If output_index is zero (default - // value) the whole output will be used as result. 
This means that if the - // output shape is a tuple, the result will be the full tuple. Otherwise the - // real sub-tuple index will be output_index - 1. - int64 output_index = 1; - // The index in the vector of the results returned by the XRTChainedExecute - // operation, where this output should be forwarded. - int64 result_index = 2; - } - - oneof op_oneof { - // The handle to an existing XRT device data. - int64 data_handle = 1; - // The handle to an existing XRT compiled computation. - int64 computation_handle = 2; - } - // The outputs of this XRTChainedExecuteOp operation. - repeated Output outputs = 3; - // The inputs of this XRTChainedExecuteOp operation. If data_handle is set, - // there are no inputs. - repeated Input inputs = 4; -} - -// Execution plan for the XRTChainedExecute operation. -message XRTChainedExecutePlan { - // The post order with the XRT computations to be executed. - repeated XRTChainedExecuteOp ops = 1; -} - -// The message used to encode the options for the XRTMetricsCollect operation. -message XRTMetricsCollect { - // A list of regular expressions to match the metric names. Empty means to - // return all the metrics reported by the collection registry. - repeated string metrics_regex = 1; -} - -message Percentiles { - message Point { - // In the [0, 100] range. - double percentile = 1; - double value = 2; - } - - // The time (in nanoseconds) of the first sample within the samples buffer. - uint64 start_nstime = 1; - // The time (in nanoseconds) of the last sample within the samples buffer. - uint64 end_nstime = 2; - // The minimum value of the samples within the samples buffer. - double min_value = 3; - // The maximum value of the samples within the samples buffer. - double max_value = 4; - // The mean value of the samples within the samples buffer. - double mean = 5; - // The stndard deviation of the samples within the samples buffer. - double stddev = 6; - // The number samples within the samples buffer. - uint64 num_samples = 7; - // The total number of times this metrics has been posted a value to. - uint64 total_samples = 8; - // The sum of all the posted values. - double accumulator = 9; - // The percentile points reported by the metric. - repeated Point points = 10; -} - -message MetricValues { - enum UnitOfMeasure { - INVALID = 0; - NUMBER = 1; - TIME = 2; - BYTES = 3; - } - - // The metric name. - string name = 1; - - oneof values_oneof { - Percentiles percentiles_value = 2; - int64 int64_value = 3; - } - - UnitOfMeasure unit_of_measure = 4; -} - -message MetricsReport { - repeated MetricValues metrics = 1; -} - -message MemoryInfo { - // The total memory on a device, in KB. - int64 kb_total = 1; - // The free memory on a device, in KB. - int64 kb_free = 2; -} diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.cc b/tensorflow/compiler/xrt/xrt_compilation_cache.cc deleted file mode 100644 index 7c88bad0b22bff..00000000000000 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.cc +++ /dev/null @@ -1,307 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
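To make the indexing conventions of the XRTChainedExecuteOp and XRTChainedExecutePlan messages above concrete, here is a hypothetical sketch of a two-op plan using the generated C++ proto API (allocation_handle and computation_handle are illustrative values, such as would be obtained from XRTAllocate and XRTCompile):

  xrt::XRTChainedExecutePlan plan;
  // Op 0 (post-order index 0): an existing device allocation.
  auto* op0 = plan.add_ops();
  op0->set_data_handle(allocation_handle);
  // Op 1: run an already compiled computation on the result of op 0.
  auto* op1 = plan.add_ops();
  op1->set_computation_handle(computation_handle);
  auto* in = op1->add_inputs();
  in->set_op_index(0);      // consume op 0's value
  in->set_output_index(0);  // 0 = whole output; i = sub-tuple index i - 1
  // Forward op 1's whole result as result 0 of the XRTChainedExecute op.
  auto* out = op1->add_outputs();
  out->set_output_index(0);
  out->set_result_index(0);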
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xrt/xrt_compilation_cache.h" - -#include - -#include -#include -#include -#include - -#include "absl/synchronization/mutex.h" -#include "xla/client/local_client.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/random/random.h" - -namespace tensorflow { - -namespace { - -int64_t get_uid() { - uint64 unsigned_rand = random::New64() & INT64_MAX; - return static_cast(unsigned_rand); -} - -int64_t GetCompilationCacheSizeFromEnv() { - const char* env = getenv("TF_XRT_COMPILATION_CACHE_SIZE"); - return env == nullptr ? 1024 : std::stol(env); -} - -} // namespace - -const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache"; - -XRTCompilationCache::EntryRefImpl::EntryRefImpl(XRTCompilationCache* parent, - CompiledSubgraph* entry) - : parent_(parent), entry_(entry) { - entry_->Ref(); -} - -XRTCompilationCache::EntryRefImpl::~EntryRefImpl() { - parent_->DiscardEntryRef(entry_); -} - -XRTCompilationCacheEntry XRTCompilationCache::EntryRefImpl::get() { - return XRTCompilationCacheEntry(entry_->program.get()); -} - -XRTCompilationCache::XRTCompilationCache(int max_number_of_entries) - : max_cache_entries_(max_number_of_entries) { - CHECK_GE(max_cache_entries_, 0); - VLOG(1) << "Created compilation cache max " << max_cache_entries_ - << " entries."; -} - -XRTCompilationCache::~XRTCompilationCache() { - VLOG(1) << "XRTCompilationCache::~XRTCompilationCache()"; - // A buggy client may be holding onto a reference, or a client might have - // crashed while holding onto a reference. In either case, discard all - // outstanding client references to avoid leaking storage. - for (const auto& entry : entries_by_uid_) { - while (!entry.second->RefCountIsOne()) { - entry.second->Unref(); - } - } - while (!entries_by_last_use_.empty()) { - MarkOldestEntryForEviction(); - } - CHECK_EQ(cache_.size(), 0); - CHECK_EQ(entries_by_uid_.size(), 0); - CHECK_EQ(cache_entries_, 0); - CHECK_EQ(marked_for_eviction_entries_, 0); -} - -Status XRTCompilationCache::Release(int64_t uid) { - absl::MutexLock lock(&mu_); - auto iter = entries_by_uid_.find(uid); - - if (iter == entries_by_uid_.end()) { - return errors::NotFound("No cache entry found for uid ", uid); - } - - DiscardEntryRefLocked(iter->second); - - VLOG(1) << "After releasing entry " << uid << " refs cache is " - << cache_.size() << " entries (" - << cache_entries_ + marked_for_eviction_entries_ - << "), marked for eviction " - << (cache_.size() - entries_by_last_use_.size()) << " entries (" - << marked_for_eviction_entries_ << ")."; - - return OkStatus(); -} - -void XRTCompilationCache::DiscardEntryRef(CompiledSubgraph* entry) { - absl::MutexLock lock(&mu_); - DiscardEntryRefLocked(entry); -} - -void XRTCompilationCache::DiscardEntryRefLocked(CompiledSubgraph* entry) { - if (entry->RefCountIsOne()) { - // The last reference to this entry is going away, so really delete it from - // the cache in such a way that it can't be restored by being looked up - // again. - - // Sanity-check that it has been marked for eviction. - CHECK(entries_by_last_use_.find(entry->last_use) == - entries_by_last_use_.end()); - // Update the counter tracking how much space is taken up by entries that - // are marked for eviction. - --marked_for_eviction_entries_; - - // Remove the entry from the cache. 
- auto erased = cache_.erase(entry->key); - if (erased == 0) { - LOG(FATAL) << "Tried to discard nonexistent cache entry"; - } - erased = entries_by_uid_.erase(entry->uid); - CHECK_EQ(erased, 1); - } - entry->Unref(); -} - -void XRTCompilationCache::MarkOldestEntryForEviction() { - CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second; - VLOG(1) << "Marking " << entry_to_mark->key << " for eviction"; - entries_by_last_use_.erase(entry_to_mark->last_use); - --cache_entries_; - ++marked_for_eviction_entries_; - // Discard the cache's reference to entry. If steps are holding onto - // references to entry it won't be deleted until the last step holding it - // completes. It stays in the cache in the meantime and can be resurrected - // by a call to CompileIfKeyAbsent if that occurs before the last reference - // expires. - DiscardEntryRefLocked(entry_to_mark); -} - -void XRTCompilationCache::LookupEntryMarkedForEviction( - CompiledSubgraph* entry) { - // The entry was previously marked for eviction (or is newly created) so - // unmark it. Add a reference (owned by the cache), update the cache size, and - // mark something old for eviction if necessary. - entry->Ref(); - --marked_for_eviction_entries_; - ++cache_entries_; - - // Mark the least-recently-used non-marked entry for eviction. Never mark the - // most-recently used entry (i.e., do nothing if entries_by_last_use_ == 1 - // which means there's only one entry not already marked for eviction), so - // that an entry persists in the cache even if it is larger than the allocated - // cache size. - while (entries_by_last_use_.size() > 1 && - cache_entries_ > max_cache_entries_) { - MarkOldestEntryForEviction(); - } -} - -XRTCompilationCache::CompiledSubgraph* XRTCompilationCache::InitializeEntry( - const string& key, - const std::function*)>& - initialize_program) { - CompiledSubgraph* entry = new CompiledSubgraph(); - entry->parent = this; - entry->key = key; - entry->uid = get_uid(); - // Add the entry to the cache. Once the computation has been compiled, - // UpdateEntryAfterCompilation will be called to potentially mark old entries - // that don't fit any more for eviction. - // - // At this point there is one reference to entry, which is owned by the caller - // who created the entry. A second reference, owned by the cache, will be - // added below since we leave the entry in the 'marked for eviction' state - // here. - auto cache_inserted = - cache_.insert(std::pair(key, entry)); - CHECK(cache_inserted.second); - - // Initialize the program outside the lock so that other cache operations - // can proceed during the (potentially lengthy) initialization. - Status s; - std::unique_ptr program; - { - mu_.Unlock(); - { s = initialize_program(&program); } - mu_.Lock(); - } - - // Add the entry to the uid index. - auto uid_inserted = entries_by_uid_.insert( - std::pair(entry->uid, entry)); - CHECK(uid_inserted.second); - - entry->initialized = true; - entry->initialization_status = s; - if (s.ok()) { - entry->program = std::move(program); - } - // Add the entry to marked_for_eviction_entries_ since it will be adjusted - // down again when the newly-created entry gets unmarked. 
- ++marked_for_eviction_entries_; - return entry; -} - -Status XRTCompilationCache::CompileIfKeyAbsent( - const string& key, int64_t* uid, - const std::function*)>& - compile_function) { - CompiledSubgraph* entry = nullptr; - - absl::MutexLock lock(&mu_); - auto iter = cache_.find(key); - - if (iter == cache_.end()) { - // The single ref on the newly-created entry is owned by the caller. - VLOG(1) << "Before adding new entry for key " << key << " cache is " - << cache_.size() << " entries (" - << cache_entries_ + marked_for_eviction_entries_ << "), " - << " marked for eviction " - << (cache_.size() - entries_by_last_use_.size()) << " entries (" - << marked_for_eviction_entries_ << ")."; - entry = InitializeEntry(key, compile_function); - } else { - VLOG(1) << "Before refreshing entry for key " << key << " cache is " - << cache_.size() << " entries (" - << cache_entries_ + marked_for_eviction_entries_ << "), " - << " marked for eviction " - << (cache_.size() - entries_by_last_use_.size()) << " entries (" - << marked_for_eviction_entries_ << ")."; - entry = iter->second; - // Make a new reference that is owned by the caller. - entry->Ref(); - // Block if necessary until the subgraph has been initialized. - mu_.Await(absl::Condition( - +[](CompiledSubgraph* e) { return e->initialized; }, entry)); - } - - // Let the caller know the uid of the entry. - *uid = entry->uid; - - // Remove the old LRU-table entry if it wasn't already marked for eviction. - auto erased = entries_by_last_use_.erase(entry->last_use); - // Update the LRU table indicating this entry is the most recently used. - entry->last_use = use_counter_++; - entries_by_last_use_[entry->last_use] = entry; - if (erased == 0) { - // The entry had been marked for eviction, or is newly created. - LookupEntryMarkedForEviction(entry); - } - - VLOG(1) << "After refreshing entry for key " << key << " cache is " - << cache_.size() << " entries (" - << cache_entries_ + marked_for_eviction_entries_ << "), " - << " marked for eviction " - << (cache_.size() - entries_by_last_use_.size()) << " entries (" - << marked_for_eviction_entries_ << ")."; - - return entry->initialization_status; -} - -Status XRTCompilationCache::Lookup( - int64_t uid, std::unique_ptr* entry) { - entry->reset(); - - absl::MutexLock lock(&mu_); - const auto iter = entries_by_uid_.find(uid); - if (iter == entries_by_uid_.end()) { - return errors::NotFound("No executable found for uid ", uid); - } - CompiledSubgraph* cache_entry = iter->second; - *entry = std::unique_ptr( - new EntryRefImpl(this, cache_entry)); - return OkStatus(); -} - -string XRTCompilationCache::DebugString() const { - return "XRTCompilationCache"; -} - -xla::StatusOr> GetOrCreateCompilationCache( - ResourceMgr* rm, int64_t max_number_of_entries) { - if (max_number_of_entries == 0) { - max_number_of_entries = GetCompilationCacheSizeFromEnv(); - } - XRTCompilationCache* cache; - TF_RETURN_IF_ERROR(rm->LookupOrCreate( - rm->default_container(), kXRTCompilationCacheResourceName, &cache, - [&](XRTCompilationCache** new_cache) { - *new_cache = new XRTCompilationCache(max_number_of_entries); - return OkStatus(); - })); - return RefPtr(cache); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.h b/tensorflow/compiler/xrt/xrt_compilation_cache.h deleted file mode 100644 index 7c89bcc5a1ecea..00000000000000 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.h +++ /dev/null @@ -1,252 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_ -#define TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_ - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "xla/client/local_client.h" -#include "xla/statusor.h" -#include "tensorflow/compiler/xrt/xrt_refptr.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/lib/core/refcount.h" - -namespace tensorflow { - -extern const char* kXRTCompilationCacheResourceName; - -struct XRTCompilationCacheEntry { - explicit XRTCompilationCacheEntry(xla::LocalExecutable* executable) - : executable(executable) {} - - // Returns a non-owned pointer to an immutable executable. - xla::LocalExecutable* get_executable() const { return executable; } - - private: - xla::LocalExecutable* executable; -}; - -// Base class for a reference to a cached executable. A unique_ptr to a -// XRTCompilationCacheEntryRef is returned by the cache Lookup methods below, -// and ensures the underlying executable is not garbage-collected until the -// client discards the ptr. -class XRTCompilationCacheEntryRef { - public: - virtual ~XRTCompilationCacheEntryRef() = default; - - // Returns a XRTCompilationCacheEntry that should not be used beyond the - // lifetime of the XRTCompilationCacheEntryRef. - virtual XRTCompilationCacheEntry get() = 0; -}; - -// Cache for compiled XLA executables. -// TODO(b/112646171) rationalize this with the other compilation caches. -// -// Each key identifies a unique XLA computation, and the value is executable -// generated by compiling the computation. -// -// When a computation is considered for compilation, the client calls -// -// auto key = ; -// auto compile_function = ; -// int64 uid; -// CompileIfKeyAbsent(computation_key, &uid, compile_function); -// -// where computation_key is the key computed for the computation. On success, -// uid contains an identifier that can be used to look up the executable. If the -// compiled executable were not present in the cache, compile_function would be -// called to generate it. -// -// The caller is responsible for calling Release(uid) once for every -// call to CompileIfKeyAbsent(key, ...) to discard the reference to the -// compilation results, after the caller is sure it will not look up the -// compiled executables again. -// -// Subsequently the client can call -// -// std::unique_ptr entry; -// Lookup(uid, &entry); -// auto proto = entry->get(); -// -// to access a cached executable. -class XRTCompilationCache : public ResourceBase { - public: - // There is no way in general to discover the size taken by an XLA executable, - // so the cache defaults to a specific number of entries to determine when to - // start evicting programs. TODO(b/112592410) change this if the XLA API gets - // a mechanism to query size. 
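A minimal sketch of the call pattern documented in the class comment above, assuming the compile callback takes a std::unique_ptr<xla::LocalExecutable>* as the CompileIfKeyAbsent comment describes (the Compile helper and computation_key below are illustrative, not part of the API):

  int64_t uid;
  TF_RETURN_IF_ERROR(cache->CompileIfKeyAbsent(
      computation_key, &uid,
      [&](std::unique_ptr<xla::LocalExecutable>* program) {
        // Build the executable for this key and hand it to the cache.
        return Compile(computation, program);
      }));
  {
    std::unique_ptr<XRTCompilationCacheEntryRef> entry;
    TF_RETURN_IF_ERROR(cache->Lookup(uid, &entry));
    xla::LocalExecutable* executable = entry->get().get_executable();
    // ... run executable; it stays valid while entry is alive ...
  }
  // Each successful CompileIfKeyAbsent must eventually be balanced by Release.
  TF_RETURN_IF_ERROR(cache->Release(uid));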
- explicit XRTCompilationCache(int max_number_of_entries); - ~XRTCompilationCache() override; - - // Ensures there is an entry for key present in the cache. By the time - // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache - // for key, and that entry will remain valid at least until Release is called - // on the returned uid. The first call to CompileIfKeyAbsent with a key that - // is not in the cache will evaluate compile_function to compute the value to - // use in the entry. Subsequent calls with the same key will block until - // compile_function completes. Other cache reads and inserts may proceed on - // other threads while compile_function is executing. The caller is - // responsible for calling Release(uid) to manually discard its reference to - // the compiled program, once the caller will not look up the compiled program - // again. - // - // compile_function should compile the computation represented by key and fill - // the xla::LocalExecutable into its passed argument. It should return OK - // if and only if compilation succeeds. The executable will be discarded on - // non-OK status. - Status CompileIfKeyAbsent( - const string& key, int64_t* uid, - const std::function*)>& - compile_function); - - Status Release(int64_t uid); - - // Looks up an executable corresponding to uid. On success a pointer to an - // EntryRef holding the program is returned in entry. - Status Lookup(int64_t uid, - std::unique_ptr* entry); - - string DebugString() const override; - - private: - // An entry in the compilation cache. The entry is deleted once it has been - // marked for eviction from the cache _and_ all looked-up entries have been - // released. When the entry is first created, it is uninitialized and a - // client-supplied compilation function is run outside the cache's lock to - // generate the program to be stored in the entry. Any other client that - // requests the entry will block until it has been initialized. Each entry has - // a last_use value that set from a monotonically-increasing counter in the - // cache whenever the entry is referenced. When the cache becomes full, - // entries are marked for eviction in LRU order. - struct CompiledSubgraph : public core::RefCounted { - ~CompiledSubgraph() override = default; - - XRTCompilationCache* parent = nullptr; // Not owned. - bool initialized = false; - // The Status returned by the compilation function when the entry is - // initialized. This status will be returned to any client that requests the - // entry. - Status initialization_status; - // Counter to keep track of LRU entries for the eviction policy. - int64_t last_use = -1; - // The unique key describing this entry. - string key; - // The uid describing this entry. - int64_t uid; - // The compiled payload corresponding to the key. - std::unique_ptr program; - }; - - // Wrapper for a cache entry that holds a reference to the entry until the - // wrapper is deleted. This wrapper is the concrete type of - // XRTCompilationCacheEntryRef returned by Lookup. - class EntryRefImpl : public XRTCompilationCacheEntryRef { - public: - EntryRefImpl(XRTCompilationCache* parent, CompiledSubgraph* entry); - ~EntryRefImpl() override; - - XRTCompilationCacheEntry get() override; - - private: - XRTCompilationCache* parent_; // Not owned. - // A reference to entry_ is acquired in the contructor and released via - // parent->DiscardEntryRef in the destructor. - CompiledSubgraph* entry_; - }; - - // Releases one reference to entry. 
This is called by the cache when entry is - // marked for eviction; or by an EntryRefImpl when it is destroyed. Before the - // last reference to entry is released, entry is removed from cache_. - void DiscardEntryRef(CompiledSubgraph* entry); - void DiscardEntryRefLocked(CompiledSubgraph* entry) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Marks the oldest unmarked entry for eviction. Requires that there is at - // least one such entry. - void MarkOldestEntryForEviction() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Updates datastructures to indicate that entry, which had been marked for - // eviction, has been looked up. This is called by CompileIfKeyAbsent when an - // entry is newly created, or an entry that has been marked for eviction but - // not yet evicted is looked up. - // - // First the entry is unmarked for eviction, i.e. the cache gains a reference - // to entry, entry's last_use field is set to be the most recent value of - // use_counter_ and entries_by_last_use_ is updated accordingly. - // - // Next, the size of the cache is examined to see if any other entries need to - // be marked for eviction now that entry has been unmarked. While the total - // number of unmarked cached entries is greater than max_cache_entries_, - // entries are marked for eviction in LRU order. The most recently used entry - // is never marked for eviction, so an entry larger than the max cache entries - // will remain in the cache until it is replaced by something else. - void LookupEntryMarkedForEviction(CompiledSubgraph* entry) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Creates a new entry by running initialize_program and places it in the - // cache to be looked up by key. The new entry is in the 'marked for eviction' - // state (not present in entries_by_last_use_) and the caller is expected to - // call LookupEntryMarkedForEviction after InitializeEntry. - // - // **InitializeEntry releases mu_ during the call to initialize_program.** - CompiledSubgraph* InitializeEntry( - const string& key, - const std::function*)>& - initialize_program) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // The maximum number of entries that are stored in the cache before entries - // are marked for eviction. - const int max_cache_entries_; - - mutable absl::Mutex mu_; - // The total number of entries that are stored and not marked for eviction. - int cache_entries_ TF_GUARDED_BY(mu_) = 0; - // The total number of entries that are marked for eviction. - int marked_for_eviction_entries_ TF_GUARDED_BY(mu_) = 0; - // The value to assign to the last_use field of the next entry that is looked - // up. - int64_t use_counter_ TF_GUARDED_BY(mu_) = 0; - // All the executables that can be looked up in the cache index by key. An - // entry is marked for eviction iff it is present in cache_ and not in - // entries_by_last_use_. - std::unordered_map cache_ TF_GUARDED_BY(mu_); - // All the executable entries that can be looked up in the cache indexed by - // uid. - absl::flat_hash_map entries_by_uid_ - TF_GUARDED_BY(mu_); - // Map from last_use to entry, used to mark entries for eviction in LRU - // order. If an entry's last_use counter is not present as a key in - // entries_by_last_use_ then the entry has been marked for eviction. - std::map entries_by_last_use_ TF_GUARDED_BY(mu_); -}; - -// Looks up or create an XRTCompilationCache object within the given resource -// manager, under the default container. The max_number_of_entries sets the -// maximum number of entries within the cache (which will be LRU-evicted). 
-// If max_number_of_entries is set to sero, the size of the cache will be -// configured using the TF_XRT_COMPILATION_CACHE_SIZE environment variable. -xla::StatusOr> GetOrCreateCompilationCache( - ResourceMgr* rm, int64_t max_number_of_entries); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/xrt/xrt_device.cc b/tensorflow/compiler/xrt/xrt_device.cc deleted file mode 100644 index 9e1d929f429194..00000000000000 --- a/tensorflow/compiler/xrt/xrt_device.cc +++ /dev/null @@ -1,133 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Classes for managing access to XLA resources. - -#include "tensorflow/compiler/xrt/xrt_device.h" - -#include -#include -#include - -#include "absl/container/node_hash_map.h" -#include "tensorflow/compiler/jit/xla_device.h" -#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/lib/core/status.h" -#include "tsl/framework/device_id.h" - -namespace tensorflow { -namespace { - -class ResourceMgrArena { - public: - static ResourceMgrArena* Get() { - static ResourceMgrArena* arena = new ResourceMgrArena(); - return arena; - } - - ResourceMgr* GetResourceMgr(const std::string& platform_name) { - mutex_lock lock(mutex_); - auto it = resource_managers_.find(platform_name); - if (it == resource_managers_.end()) { - it = resource_managers_.emplace(platform_name, new ResourceMgr()).first; - } - return it->second; - } - - private: - mutex mutex_; - std::map resource_managers_; -}; - -} // namespace - -/*static*/ Status XRTGenericDeviceAccessor::GetResourceManager( - OpKernelContext* ctx, ResourceMgr** rm) { - const XlaDevice::Metadata* metadata; - TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata)); - *rm = ResourceMgrArena::Get()->GetResourceMgr(metadata->platform()->Name()); - return OkStatus(); -} - -/* static */ xla::StatusOr> -XRTGenericDeviceAccessor::GetOrCreateCompilationCache( - OpKernelContext* ctx, int64_t max_number_of_entries) { - ResourceMgr* rm; - TF_RETURN_IF_ERROR(GetResourceManager(ctx, &rm)); - return tensorflow::GetOrCreateCompilationCache(rm, max_number_of_entries); -} - -/*static*/ Status XRTGenericDeviceAccessor::InitScopedRef( - OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref) { - const XlaDevice::Metadata* metadata; - TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata)); - if (device_ordinal != metadata->device_ordinal()) { - return errors::Internal("XRT device ordinal requested ", device_ordinal, - " on device with ordinal ", - metadata->device_ordinal()); - } - scoped_ref->Acquire(metadata->client(), device_ordinal, - metadata->platform()->Name(), ctx); - return OkStatus(); -} - -/*static*/ Status XRTGenericDeviceAccessor::InitScopedRef( - OpKernelContext* ctx, ScopedRef* 
scoped_ref) { - const XlaDevice::Metadata* metadata; - TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata)); - scoped_ref->Acquire(metadata->client(), metadata->device_ordinal(), - metadata->platform()->Name(), ctx); - return OkStatus(); -} - -/* static */ tensorflow::mutex - XRTGenericDeviceAccessor::ScopedRef::cuda_allocator_mutex_( - tensorflow::LINKER_INITIALIZED); -/* static */ absl::flat_hash_map>* - XRTGenericDeviceAccessor::ScopedRef::cuda_allocators_ = - new absl::flat_hash_map>; - -void XRTGenericDeviceAccessor::ScopedRef::Acquire( - xla::LocalClient* client, int ordinal, const std::string& platform_name, - OpKernelContext* ctx) { - client_ = client; - ordinal_ = ordinal; - allocator_ = client_->mutable_backend()->memory_allocator(); -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ - (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) - if (platform_name == "CUDA") { - // Use BfcAllocator for the CUDA. - auto stream = ctx->op_device_context()->stream(); - if (!cuda_allocators_->count(stream)) { - mutex_lock lock(cuda_allocator_mutex_); - if (!cuda_allocators_->count(stream)) { - GPUOptions gpu_options; - Allocator* raw_allocator = - GPUProcessState::singleton()->GetGPUAllocator( - tsl::TfDeviceId(ordinal_)); - (*cuda_allocators_)[stream] = - std::make_unique(raw_allocator, stream); - } - } - allocator_ = static_cast( - (*cuda_allocators_)[stream].get()); - } -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -} -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_device.h b/tensorflow/compiler/xrt/xrt_device.h deleted file mode 100644 index de9f2c589a8bcc..00000000000000 --- a/tensorflow/compiler/xrt/xrt_device.h +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Classes for keeping track of on-device state. - -#ifndef TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_ -#define TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_ - -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "xla/client/local_client.h" -#include "xla/stream_executor/integrations/tf_allocator_adapter.h" -#include "tensorflow/compiler/xrt/xrt_compilation_cache.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { - -// This accessor is used for XLA CPU/GPU. It uses the device resource manager, -// so e.g., on multi-GPU setups the compilation cache will not be shared across -// devices. -class XRTGenericDeviceAccessor { - public: - static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm); - - static xla::StatusOr> GetOrCreateCompilationCache( - OpKernelContext* ctx, int64_t max_number_of_entries); - - // We use a ScopedRef pattern here even though it's not strictly necessary, - // just so that templated uses of this and the TPU accessor class will be as - // similar as possible. 
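A minimal sketch of how an XRT op kernel is expected to use this accessor, based on the ScopedRef interface declared below (ctx is the kernel's OpKernelContext*):

  XRTGenericDeviceAccessor::ScopedRef device_ref;
  TF_RETURN_IF_ERROR(
      XRTGenericDeviceAccessor::InitScopedRef(ctx, &device_ref));
  xla::LocalClient* client = device_ref.client();
  se::DeviceMemoryAllocator* allocator = device_ref.allocator();
  int ordinal = device_ref.device_ordinal();
  // client and allocator remain usable for the lifetime of device_ref.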
- class ScopedRef { - public: - ScopedRef() = default; - ~ScopedRef() = default; - - ScopedRef(const ScopedRef&) = delete; - ScopedRef& operator=(const ScopedRef&) = delete; - - // Returns the XLA device protected by this ScopedRef. - xla::LocalClient* client() const { return client_; } - xla::Backend* backend() { return client_->mutable_backend(); } - int device_ordinal() const { return ordinal_; } - se::DeviceMemoryAllocator* allocator() { return allocator_; } - - private: - // XRTGenericDeviceAccessor::InitScopedRef is the only way to initialize - // ScopedRef. - friend class XRTGenericDeviceAccessor; - - void Acquire(xla::LocalClient* client, int ordinal, - const std::string& platform_name, OpKernelContext* ctx); - - xla::LocalClient* client_ = nullptr; - int ordinal_ = 0; - se::DeviceMemoryAllocator* allocator_ = nullptr; - static tensorflow::mutex cuda_allocator_mutex_; - static absl::flat_hash_map>* - cuda_allocators_; - }; - - static Status InitScopedRef(OpKernelContext* ctx, int device_ordinal, - ScopedRef* scoped_ref); - - static Status InitScopedRef(OpKernelContext* ctx, ScopedRef* scoped_ref); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_ diff --git a/tensorflow/compiler/xrt/xrt_memory_manager.cc b/tensorflow/compiler/xrt/xrt_memory_manager.cc deleted file mode 100644 index 05325a822d9d22..00000000000000 --- a/tensorflow/compiler/xrt/xrt_memory_manager.cc +++ /dev/null @@ -1,370 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xrt/xrt_memory_manager.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/memory/memory.h" -#include "tensorflow/compiler/xrt/xrt_metrics.h" -#include "tensorflow/core/lib/monitoring/timed.h" -#include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/profiler/lib/traceme.h" - -namespace tensorflow { -namespace { - -// We use kDeviceBits to store the device ordinal in the handle. We store the -// device in the upper part of the int64 handle to make sure the random bits are -// in the lower part which is better when storing the handle as a key for -// unordered maps. 
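Worked example of the handle layout described above, using the MakeDeviceHandle and GetDeviceFromHandle helpers defined next (kDeviceBits == 12, so the random uid occupies the low 52 bits):

  int64_t h = MakeDeviceHandle(/*device_ordinal=*/3, /*rnd_value=*/0x1234);
  // h == (3LL << 52) | 0x1234
  int device = GetDeviceFromHandle(h);  // == 3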
-const int kDeviceBits = 12; - -int64_t MakeDeviceHandle(int64_t device_ordinal, int64_t rnd_value) { - const int64_t kUidMask = (static_cast(1) << (64 - kDeviceBits)) - 1; - return (device_ordinal << (64 - kDeviceBits)) | (rnd_value & kUidMask); -} - -int GetDeviceFromHandle(int64_t handle) { - return (handle >> (64 - kDeviceBits)) & ((1 << kDeviceBits) - 1); -} - -} // namespace - -class XRTMemoryManager::DeviceContext { - struct Alloc { - explicit Alloc(RefPtr tuple) - : tuple(std::move(tuple)) {} - - RefPtr tuple; - }; - - using AllocList = std::list; - - public: - int64_t Register(RefPtr tuple) { - while (true) { - int64_t handle = MakeDeviceHandle(tuple->device_ordinal(), CreateUid()); - mutex_lock lock(lock_); - allocs_.emplace_front(tuple); - if (alloc_map_.emplace(handle, allocs_.begin()).second) { - return handle; - } - // The chances of hitting an existing handle are so remote, it is much - // more convenient to add to the list before, and eventually removing. - allocs_.erase(allocs_.begin()); - } - } - - bool Release(int64_t handle) { - mutex_lock lock(lock_); - auto it = alloc_map_.find(handle); - if (it == alloc_map_.end()) { - return false; - } - allocs_.erase(it->second); - alloc_map_.erase(it); - return true; - } - - RefPtr Lookup(int64_t handle) { - mutex_lock lock(lock_); - auto it = alloc_map_.find(handle); - if (it == alloc_map_.end()) { - return nullptr; - } - // LRU - allocs_.splice(allocs_.begin(), allocs_, it->second); - return it->second->tuple; - } - - void Clear() { - mutex_lock lock(lock_); - alloc_map_.clear(); - allocs_.clear(); - } - - Status CompactAllocations(XRTMemoryManager* memory_manager, - xla::Backend* backend, - se::DeviceMemoryAllocator* allocator) { - profiler::TraceMe trace_me("XRTMemoryManager::CompactAllocations", - /*level=*/2); - auto timed = monitoring::MakeTimed(xrt_metrics::GetMemoryCompactCell()); - VLOG(4) << "CompactAllocations started"; - mutex_lock lock(lock_); - Status status; - std::vector swapped; - // We are swapping out from the most recently used allocations. This is - // desirable since the most recently used will be finding themselves at the - // bottom of the allocation space. Since these are more likely to be pinned - // allocations, a further trim done by following TryFreeMemory() call will - // eventually drop the higher located allocations, with better chance of - // reducing fragmentation. - // Also, by swapping out the pinned allocations first, those will also be - // the first to be restored, and hence if we will ever find OOM on the way - // out, we would more likely be swapping in not pinned ones. - for (auto it = allocs_.begin(); it != allocs_.end(); ++it) { - // We are compacting all the allocations, so we will temporarily swap out - // even pinned allocations. - auto swap_result_or = it->tuple->SwapOut(backend, /*swap_pinned=*/true); - if (!swap_result_or.ok()) { - status = swap_result_or.status(); - break; - } - if (swap_result_or.value()) { - swapped.push_back(it); - } - } - // At this point we have released all the device memory we could release. - // Load back the tuple allocations we have swapped out above. - for (auto& it : swapped) { - auto swap_result_or = - it->tuple->SwapIn(memory_manager, backend, allocator); - if (!swap_result_or.ok()) { - // If we failed to restored a pinned allocation, better to CHECK here - // than wondering why XRTTupleAllocation calls fail with errors about - // missing buffers. 
- CHECK(!it->tuple->IsPinned()); // Crash OK - if (status.ok()) { - status = swap_result_or.status(); - } - } - } - VLOG(4) << "CompactAllocations finished: " << status; - return status; - } - - // Tries to free size bytes by freeing some unpinned device memory. Returns - // the amount of memory which was able to free. - xla::StatusOr TryFreeMemory(xla::Backend* backend, size_t size) { - profiler::TraceMe trace_me("XRTMemoryManager::TryFreeMemory", /*level=*/2); - auto timed = monitoring::MakeTimed(xrt_metrics::GetTryFreeMemoryCell()); - mutex_lock lock(lock_); - size_t swapped_size = 0; - for (auto it = allocs_.rbegin(); it != allocs_.rend(); ++it) { - TF_ASSIGN_OR_RETURN(bool swap_result, - it->tuple->SwapOut(backend, /*swap_pinned=*/false)); - if (swap_result) { - swapped_size += it->tuple->GetDeviceMemorySize(); - if (swapped_size >= size) { - break; - } - } - } - VLOG(3) << "Swapped out " << swapped_size << " bytes"; - return swapped_size; - } - - private: - static int64_t CreateUid() { - int64_t uid; - do { - uid = random::New64() & INT64_MAX; - } while (uid == InvalidKey()); - return uid; - } - - // We store Alloc records inside an std::list so we can LRU it, and - // store the list iterators within the handle map, as list iterators don't get - // invalidated by (other elements) removals or position swaps. - mutex lock_; - AllocList allocs_; - std::unordered_map alloc_map_; -}; - -XRTMemoryManager::WorkingSet::WorkingSet( - RefPtr memory_manager) - : memory_manager_(std::move(memory_manager)) {} - -XRTMemoryManager::WorkingSet::~WorkingSet() { - for (auto& tuple : pinned_tuples_) { - tuple->Unpin(); - } -} - -Status XRTMemoryManager::WorkingSet::LookupAndPin( - xla::Backend* backend, int64_t handle, - se::DeviceMemoryAllocator* allocator) { - TF_ASSIGN_OR_RETURN(auto tuple, memory_manager_->Lookup(handle)); - TF_RETURN_IF_ERROR( - tuple->PinAndSwapIn(memory_manager_.get(), backend, allocator).status()); - pinned_tuples_.push_back(std::move(tuple)); - return OkStatus(); -} - -/* static */ RefPtr XRTMemoryManager::Get(ResourceMgr* rm) { - static string* container = new string("XrtState"); - static string* name = new string("MemoryManager"); - XRTMemoryManager* memory_manager = nullptr; - TF_CHECK_OK(rm->LookupOrCreate( - *container, *name, &memory_manager, [](XRTMemoryManager** ret) { - *ret = new XRTMemoryManager(); - return OkStatus(); - })); - return memory_manager; -} - -int64_t XRTMemoryManager::Register(RefPtr tuple) { - DeviceContext* device_context = GetDeviceContext(tuple->device_ordinal(), - /*create_if_missing=*/true); - return device_context->Register(std::move(tuple)); -} - -xla::StatusOr> XRTMemoryManager::Lookup( - int64_t handle) { - int device_ordinal = GetDeviceFromHandle(handle); - DeviceContext* device_context = GetDeviceContext(device_ordinal, - /*create_if_missing=*/false); - if (device_context == nullptr) { - return errors::NotFound("XRT memory handle not found: ", handle); - } - RefPtr tuple = device_context->Lookup(handle); - if (tuple == nullptr) { - return errors::NotFound("XRT memory handle not found: ", handle); - } - return std::move(tuple); -} - -Status XRTMemoryManager::Release(int64_t handle) { - int device_ordinal = GetDeviceFromHandle(handle); - DeviceContext* device_context = GetDeviceContext(device_ordinal, - /*create_if_missing=*/false); - if (device_context == nullptr || !device_context->Release(handle)) { - return errors::NotFound("XRT memory handle not found: ", handle); - } - return OkStatus(); -} - -Status 
XRTMemoryManager::CompactAllocations( - xla::Backend* backend, int device_ordinal, - se::DeviceMemoryAllocator* allocator) { - DeviceContext* device_context = GetDeviceContext(device_ordinal, - /*create_if_missing=*/false); - return device_context != nullptr - ? device_context->CompactAllocations(this, backend, allocator) - : OkStatus(); -} - -void XRTMemoryManager::ReleaseAllAllocations() { - mutex_lock lock(lock_); - for (auto& device_context : device_contexts_) { - if (device_context != nullptr) { - device_context->Clear(); - } - } -} - -xla::StatusOr XRTMemoryManager::Allocate( - xla::Backend* backend, int device_ordinal, size_t size, - se::DeviceMemoryAllocator* allocator) { - auto memory_or = - allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false); - if (memory_or.status().code() == error::RESOURCE_EXHAUSTED) { - VLOG(4) << "Allocate of " << size << " bytes failed on device " - << device_ordinal; - - DeviceContext* device_context = - GetDeviceContext(device_ordinal, - /*create_if_missing=*/false); - if (device_context != nullptr) { - Status status = device_context->TryFreeMemory(backend, size).status(); - if (status.ok()) { - // As long as there is no error, we still try again the allocation, even - // if the TryFreeMemory() call ended up freeing less memory than the - // required size. Fragmentation could make the memory allocation succeed - // even if the freed memory is indeed lower. - memory_or = allocator->Allocate(device_ordinal, size, - /*retry_on_failure=*/false); - } else if (status.code() != error::RESOURCE_EXHAUSTED) { - VLOG(4) << "Allocate of " << size << " bytes on device " - << device_ordinal << ": " << status; - return status; - } - } - } - return memory_or; -} - -string XRTMemoryManager::DebugString() const { - // We might want to emit more detailed information here, like per device - // memory allocations. - return "XRTMemoryManager"; -} - -XRTMemoryManager::DeviceContext* XRTMemoryManager::GetDeviceContext( - int device_ordinal, bool create_if_missing) { - mutex_lock lock(lock_); - if (device_ordinal >= device_contexts_.size()) { - if (!create_if_missing) { - return nullptr; - } - device_contexts_.resize(device_ordinal + 1); - } - DeviceContext* device_context = device_contexts_[device_ordinal].get(); - if (device_context == nullptr && create_if_missing) { - device_contexts_[device_ordinal] = std::make_unique(); - device_context = device_contexts_[device_ordinal].get(); - } - return device_context; -} - -Status XRTMemoryManager::TryFreeMemoryStep(MemoryReclaimContext* mrctx, - const Status& status) { - DeviceContext* device_context = GetDeviceContext(mrctx->device_ordinal, - /*create_if_missing=*/false); - if (device_context == nullptr) { - return status; - } - if (!mrctx->done_freeing) { - // If the caller passed us a zero requested_free_size, we try to free chunks - // of kMaxFreeSize memory, until either the run function succeeds, or we run - // out of freeable memory. - const size_t kMaxFreeSize = 1000000000; - size_t free_size = - (mrctx->requested_free_size > 0) - ? 
std::min(mrctx->requested_free_size - mrctx->free_size, - kMaxFreeSize) - : kMaxFreeSize; - if (free_size > 0) { - auto free_size_or = - device_context->TryFreeMemory(mrctx->backend, free_size); - if (!free_size_or.ok()) { - return status; - } - size_t size = free_size_or.value(); - mrctx->free_size += size; - if (size > 0) { - return OkStatus(); - } - } - mrctx->done_freeing = true; - } - if (!mrctx->done_compacting) { - mrctx->done_compacting = true; - if (device_context - ->CompactAllocations(this, mrctx->backend, mrctx->allocator) - .ok()) { - return OkStatus(); - } - } - return status; -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_memory_manager.h b/tensorflow/compiler/xrt/xrt_memory_manager.h deleted file mode 100644 index 519938c525a18f..00000000000000 --- a/tensorflow/compiler/xrt/xrt_memory_manager.h +++ /dev/null @@ -1,186 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_ -#define TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_ - -#include -#include - -#include "xla/service/backend.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/xla_data.pb.h" -#include "tensorflow/compiler/xrt/xrt_refptr.h" -#include "tensorflow/compiler/xrt/xrt_state.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -// The XRTMemoryManager manages all the XRT allocations. It is a ResourceBase -// object which leaves within the ResourceMgr. This is only one XRT memory -// manager object within the ResourceMgr container. -class XRTMemoryManager : public ResourceBase { - // The DeviceContext class, defined and implemented locally inside the - // xrt_memory_manager.cc file, holds, for each device, all the information - // related to the XRT memory management for such device. - class DeviceContext; - - public: - // A working set is a set of tuple allocations which are the input of a given - // operation, and as such they must be pinned on the device memory. The tuple - // allocations added to the WorkingSet will be unpinned at object destruction. - class WorkingSet { - public: - explicit WorkingSet(RefPtr memory_manager); - - ~WorkingSet(); - - // Looks up the tuple handle within the memory manager, and pins it to the - // device (if not already pinned). 
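Editor's note: the WorkingSet described above is an RAII helper: everything looked up through it gets pinned, and the destructor unpins in bulk so callers cannot forget. Below is a minimal standalone sketch of that pin-on-add / unpin-on-destruction pattern; the Pinnable type and all names are illustrative stand-ins, not the deleted XRT classes.

#include <atomic>
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Stand-in for an object that must stay pinned while an operation uses it.
struct Pinnable {
  std::atomic<int64_t> pin_count{0};
  void Pin() { pin_count.fetch_add(1); }
  void Unpin() { pin_count.fetch_sub(1); }
  bool IsPinned() const { return pin_count.load() != 0; }
};

// Pins every object added to it, and unpins them all when it goes out of
// scope, mirroring the WorkingSet destructor behaviour described above.
class WorkingSetSketch {
 public:
  WorkingSetSketch() = default;
  ~WorkingSetSketch() {
    for (auto& obj : pinned_) obj->Unpin();
  }
  WorkingSetSketch(const WorkingSetSketch&) = delete;
  WorkingSetSketch& operator=(const WorkingSetSketch&) = delete;

  void Add(std::shared_ptr<Pinnable> obj) {
    obj->Pin();
    pinned_.push_back(std::move(obj));
  }

 private:
  std::vector<std::shared_ptr<Pinnable>> pinned_;
};

int main() {
  auto alloc = std::make_shared<Pinnable>();
  {
    WorkingSetSketch ws;
    ws.Add(alloc);
    assert(alloc->IsPinned());   // Pinned while the operation runs.
  }
  assert(!alloc->IsPinned());    // Unpinned once the working set is destroyed.
}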
- Status LookupAndPin(xla::Backend* backend, int64_t handle, - se::DeviceMemoryAllocator* allocator); - - const std::vector>& PinnedTuples() const { - return pinned_tuples_; - } - - const RefPtr& MemoryManager() const { - return memory_manager_; - } - - private: - RefPtr memory_manager_; - std::vector> pinned_tuples_; - }; - - // Retrieves the XRTMemoryManager singleton stored within the ResourceMgr. - static RefPtr Get(ResourceMgr* rm); - - // Registers an XRTTupleAllocation and returns the unique handle identifying - // it. - int64_t Register(RefPtr tuple); - - // Looks up an handle returned by the Register() API and returns the - // XRTTupleAllocation behind it. - xla::StatusOr> Lookup(int64_t handle); - - Status Lookup(int64_t handle, RefPtr* tuple) { - TF_ASSIGN_OR_RETURN(*tuple, Lookup(handle)); - return OkStatus(); - } - - // Releases an handle by dropping the references count held on the - // XRTTupleAllocation by the XRTMemoryManager. Existing XRTTupleAllocation - // references will continue to be valid. - Status Release(int64_t handle); - - // Tries to compact all the memory allocations on a given device. This is - // currently done by swapping-out all the existing allocation, and swapping - // them back in. - Status CompactAllocations(xla::Backend* backend, int device_ordinal, - se::DeviceMemoryAllocator* allocator); - - // Releases all the device memory allocated by XRT within the resource - // manager. - void ReleaseAllAllocations(); - - // Tries to allocate size bytes of device memory from the device_ordinal - // device. Might attempt to free some unpinned device memory, if the underline - // allocator call fails, and try the allocation again. - xla::StatusOr Allocate( - xla::Backend* backend, int device_ordinal, size_t size, - se::DeviceMemoryAllocator* allocator); - - // Runs the specified function and handling the error::RESOURCE_EXHAUSTED - // status code coming out of it. In such cases, we run different memory - // freeing operations trying to make runfn succeed. The requested_free_size - // argument represents an hint of the requested memory size which would make - // runfn succeed. - template - xla::StatusOr Run(const std::function()>& runfn, - xla::Backend* backend, int device_ordinal, - size_t requested_free_size, - se::DeviceMemoryAllocator* allocator); - - string DebugString() const override; - - // Returns the invalid key value, which will be never generated by the - // Intern() API. - static int64_t InvalidKey() { return 0; } - - private: - // Structure used to track the progress of a try-to-free operation. It is - // initialized and the passed to the TryFreeMemoryStep() API. - struct MemoryReclaimContext { - MemoryReclaimContext(xla::Backend* backend, int device_ordinal, - size_t requested_free_size, - se::DeviceMemoryAllocator* specific_allocator) - : backend(backend), - device_ordinal(device_ordinal), - requested_free_size(requested_free_size) { - allocator = specific_allocator; - } - - xla::Backend* const backend = nullptr; - se::DeviceMemoryAllocator* allocator = nullptr; - const int device_ordinal = 0; - const size_t requested_free_size = 0; - size_t free_size = 0; - bool done_freeing = false; - bool done_compacting = false; - }; - - DeviceContext* GetDeviceContext(int device_ordinal, bool create_if_missing); - - // Called multiple times while trying to make a memory consuming function call - // to fit. 
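Editor's note: the Run()/TryFreeMemoryStep() pair described here retries a memory-hungry callback, interleaving progressively more expensive reclamation steps until the callback stops failing with a resource-exhausted error or nothing more can be freed. Below is a standalone sketch of that control flow with a toy result type in place of the TensorFlow Status machinery; all names are illustrative.

#include <cstddef>
#include <functional>
#include <iostream>
#include <optional>
#include <vector>

// Toy result type: either a value or an "out of memory" failure.
struct Result {
  std::optional<int> value;
  bool resource_exhausted = false;
};

// Reclamation steps ordered from cheapest to most expensive. Each returns
// true if it freed something, i.e. a retry is worthwhile.
using ReclaimStep = std::function<bool()>;

std::optional<int> RunWithReclaim(const std::function<Result()>& runfn,
                                  const std::vector<ReclaimStep>& steps) {
  std::size_t next_step = 0;
  while (true) {
    Result r = runfn();
    if (!r.resource_exhausted) return r.value;  // Success (or a non-OOM failure).
    // Out of memory: try the next, more expensive, reclamation step.
    bool freed = false;
    while (next_step < steps.size() && !freed) {
      freed = steps[next_step++]();
    }
    if (!freed) return std::nullopt;  // Nothing left to free; give up.
  }
}

int main() {
  int attempts = 0;
  auto runfn = [&]() -> Result {
    ++attempts;
    if (attempts < 3) return Result{std::nullopt, true};  // Fail twice with OOM.
    return Result{42, false};                             // Then succeed.
  };
  std::vector<ReclaimStep> steps = {
      [] { std::cout << "swap out unpinned buffers\n"; return true; },
      [] { std::cout << "compact allocations\n"; return true; },
  };
  auto value = RunWithReclaim(runfn, steps);
  std::cout << "result: " << (value ? *value : -1) << "\n";
}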
Performs progressively more expensive memory reduction operations, - // until returning error::RESOURCE_EXHAUSTED when no further reductions are - // possible. - Status TryFreeMemoryStep(MemoryReclaimContext* mrctx, const Status& status); - - mutex lock_; - std::vector> device_contexts_; -}; - -template -xla::StatusOr XRTMemoryManager::Run( - const std::function()>& runfn, xla::Backend* backend, - int device_ordinal, size_t requested_free_size, - se::DeviceMemoryAllocator* allocator) { - MemoryReclaimContext mrctx(backend, device_ordinal, requested_free_size, - allocator); - while (true) { - // We assume that runfn is a relatively fast-fail function compared to the - // operations required to free up the required memory. Here we call into the - // TryFreeMemoryStep() API multiple times, which will run progressively more - // expensive operations. - auto result_or = runfn(); - if (result_or.status().code() != error::RESOURCE_EXHAUSTED) { - return result_or; - } - TF_RETURN_IF_ERROR(TryFreeMemoryStep(&mrctx, result_or.status())); - } -} - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_ diff --git a/tensorflow/compiler/xrt/xrt_metrics.cc b/tensorflow/compiler/xrt/xrt_metrics.cc deleted file mode 100644 index e6e4ca8c5fef69..00000000000000 --- a/tensorflow/compiler/xrt/xrt_metrics.cc +++ /dev/null @@ -1,292 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xrt/xrt_metrics.h" - -#include -#include - -#include "tensorflow/core/lib/monitoring/collection_registry.h" -#include "tensorflow/core/platform/regexp.h" - -namespace tensorflow { -namespace { - -static const size_t kMaxSamples = 1024; - -std::vector GetDefaultPercentiles() { - return {25.0, 50.0, 80.0, 90.0, 95.0, 99.0}; -} - -bool IsSelectedMetric(const xrt::XRTMetricsCollect& metrics, - const string& name) { - if (metrics.metrics_regex_size() == 0) { - return true; - } - for (auto& metric_regex : metrics.metrics_regex()) { - if (RE2::FullMatch(name, metric_regex)) { - return true; - } - } - return false; -} - -void SetUnitOfMeasure(xrt::MetricValues* metrics, - monitoring::UnitOfMeasure unit_of_measure) { - switch (unit_of_measure) { - case monitoring::UnitOfMeasure::kNumber: - metrics->set_unit_of_measure(xrt::MetricValues::NUMBER); - break; - case monitoring::UnitOfMeasure::kTime: - metrics->set_unit_of_measure(xrt::MetricValues::TIME); - break; - case monitoring::UnitOfMeasure::kBytes: - metrics->set_unit_of_measure(xrt::MetricValues::BYTES); - break; - } -} - -Status AddMetrics(xrt::MetricsReport* report, - const monitoring::PointSet& point_set) { - for (auto& point : point_set.points) { - xrt::MetricValues* metrics = report->add_metrics(); - metrics->set_name(point_set.metric_name); - if (point->value_type == monitoring::ValueType::kPercentiles) { - xrt::Percentiles* percentiles = metrics->mutable_percentiles_value(); - SetUnitOfMeasure(metrics, point->percentiles_value.unit_of_measure); - percentiles->set_start_nstime(point->percentiles_value.start_nstime); - percentiles->set_end_nstime(point->percentiles_value.end_nstime); - percentiles->set_min_value(point->percentiles_value.min_value); - percentiles->set_max_value(point->percentiles_value.max_value); - percentiles->set_mean(point->percentiles_value.mean); - percentiles->set_stddev(point->percentiles_value.stddev); - percentiles->set_num_samples(point->percentiles_value.num_samples); - percentiles->set_total_samples(point->percentiles_value.total_samples); - percentiles->set_accumulator(point->percentiles_value.accumulator); - for (auto& pct_point : point->percentiles_value.points) { - xrt::Percentiles::Point* xpoint = percentiles->add_points(); - xpoint->set_percentile(pct_point.percentile); - xpoint->set_value(pct_point.value); - } - } else if (point->value_type == monitoring::ValueType::kInt64) { - metrics->set_unit_of_measure(xrt::MetricValues::NUMBER); - metrics->set_int64_value(point->int64_value); - } - } - return OkStatus(); -} - -} // namespace - -namespace xrt_metrics { - -monitoring::PercentileSamplerCell* GetAllocateCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/allocate", "Tracks XRTAllocate times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetAllocateUninitializedCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/allocate_uninitialized", - "Tracks XRTAllocateUninitialized times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetAllocateFromTensorCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - 
{"/tensorflow/xrt/ops/allocate_from_tensor", - "Tracks XRTAllocateFromTensor times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetSubTupleCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/sub_tuple", "Tracks XRTSubTuple times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetMakeTupleCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/make_tuple", "Tracks XRTMakeTuple times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetReadLiteralCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/read_literal", "Tracks XRTReadLiteral times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetReadToTensorCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/read_tensor", "Tracks XRTReadToTensor times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetWriteLiteralCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/write_literal", "Tracks XRTWriteLiteral times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetReleaseAllocationCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/release_allocation", - "Tracks XRTReleaseAllocation times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetReleaseAllAllocationsCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/release_all_allocations", - "Tracks XRTReleaseAllAllocations times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetCompactAllocationsCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/compact_allocations", - "Tracks XRTCompactAllocations times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetCompileCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/compile", "Tracks XRTCompile times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetReleaseCompilationCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/release_compilation", - "Tracks XRTReleaseCompilationRef times"}, - GetDefaultPercentiles(), kMaxSamples, - 
monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetExecuteCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/execute", "Tracks XRTExecute times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetExecuteChainedCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/ops/execute_chained", - "Tracks XRTExecuteChained times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetMemoryCompactCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/memory_manager/compaction", - "Tracks XRT memory manager memory compaction times"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -monitoring::PercentileSamplerCell* GetTryFreeMemoryCell() { - static monitoring::PercentileSamplerCell* cell = - monitoring::PercentileSampler<0>::New( - {"/tensorflow/xrt/memory_manager/try_free_memory", - "Tracks XRT memory manager times in trying to " - "free memory by swpping device memory to host memory"}, - GetDefaultPercentiles(), kMaxSamples, - monitoring::UnitOfMeasure::kTime) - ->GetCell(); - return cell; -} - -} // namespace xrt_metrics - -xla::StatusOr CollectMetrics( - const xrt::XRTMetricsCollect& metrics) { - auto* collection_registry = monitoring::CollectionRegistry::Default(); - monitoring::CollectionRegistry::CollectMetricsOptions options; - options.collect_metric_descriptors = false; - auto collected_metrics = collection_registry->CollectMetrics(options); - xrt::MetricsReport report; - for (auto& name_pointset : collected_metrics->point_set_map) { - if (IsSelectedMetric(metrics, name_pointset.first)) { - TF_RETURN_IF_ERROR(AddMetrics(&report, *name_pointset.second)); - } - } - return std::move(report); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_metrics.h b/tensorflow/compiler/xrt/xrt_metrics.h deleted file mode 100644 index d6afdbd7e33ab9..00000000000000 --- a/tensorflow/compiler/xrt/xrt_metrics.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XRT_XRT_METRICS_H_ -#define TENSORFLOW_COMPILER_XRT_XRT_METRICS_H_ - -#include "xla/statusor.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/core/lib/monitoring/percentile_sampler.h" - -namespace tensorflow { -namespace xrt_metrics { - -// Defines the singletons of the metrics populated by the XRT op framework. 
-// Single of a single XRT op there can be many device specific versions (CPU, -// GPU, TPU), and since the monitoring subsystem does not allow multiple -// registrations of the same metric name, we define them all in this file. -monitoring::PercentileSamplerCell* GetAllocateCell(); -monitoring::PercentileSamplerCell* GetAllocateUninitializedCell(); -monitoring::PercentileSamplerCell* GetAllocateFromTensorCell(); -monitoring::PercentileSamplerCell* GetSubTupleCell(); -monitoring::PercentileSamplerCell* GetMakeTupleCell(); -monitoring::PercentileSamplerCell* GetReadLiteralCell(); -monitoring::PercentileSamplerCell* GetReadToTensorCell(); -monitoring::PercentileSamplerCell* GetWriteLiteralCell(); -monitoring::PercentileSamplerCell* GetReleaseAllocationCell(); -monitoring::PercentileSamplerCell* GetReleaseAllAllocationsCell(); -monitoring::PercentileSamplerCell* GetCompactAllocationsCell(); -monitoring::PercentileSamplerCell* GetCompileCell(); -monitoring::PercentileSamplerCell* GetReleaseCompilationCell(); -monitoring::PercentileSamplerCell* GetExecuteCell(); -monitoring::PercentileSamplerCell* GetExecuteChainedCell(); -monitoring::PercentileSamplerCell* GetMemoryCompactCell(); -monitoring::PercentileSamplerCell* GetTryFreeMemoryCell(); - -} // namespace xrt_metrics - -xla::StatusOr CollectMetrics( - const xrt::XRTMetricsCollect& metrics); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_XRT_XRT_METRICS_H_ diff --git a/tensorflow/compiler/xrt/xrt_refptr.h b/tensorflow/compiler/xrt/xrt_refptr.h deleted file mode 100644 index 2db20dd71ce5ed..00000000000000 --- a/tensorflow/compiler/xrt/xrt_refptr.h +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Utility functions in support of the XRT API. - -#ifndef TENSORFLOW_COMPILER_XRT_XRT_REFPTR_H_ -#define TENSORFLOW_COMPILER_XRT_XRT_REFPTR_H_ - -#include - -namespace tensorflow { - -// Reference counted smart pointer for XRT objects providing the standard -// Ref()/Unref() APIs. -template -class RefPtr { - public: - RefPtr() = default; - // Creates a RefPtr from a pointer. This is an ownership transfer operation, - // and the caller has to own a valid reference to ptr (unless ptr is nullptr). 
- RefPtr(T* ptr) : ptr_(ptr) {} // NOLINT - RefPtr(const RefPtr& other) : ptr_(other.ptr_) { Acquire(ptr_); } - RefPtr(RefPtr&& other) : ptr_(other.ptr_) { other.ptr_ = nullptr; } - - ~RefPtr() { Release(ptr_); } - - RefPtr& operator=(const RefPtr& other) { - if (this != &other) { - Acquire(other.ptr_); - Release(ptr_); - ptr_ = other.ptr_; - } - return *this; - } - - RefPtr& operator=(RefPtr&& other) { - if (this != &other) { - Release(ptr_); - ptr_ = other.ptr_; - other.ptr_ = nullptr; - } - return *this; - } - - operator bool() const { return ptr_ != nullptr; } // NOLINT - bool operator==(const RefPtr& rhs) const { return ptr_ == rhs.ptr_; } - bool operator!=(const RefPtr& rhs) const { return ptr_ != rhs.ptr_; } - bool operator==(const T* ptr) const { return ptr_ == ptr; } - bool operator!=(const T* ptr) const { return ptr_ != ptr; } - bool operator==(std::nullptr_t ptr) const { return ptr_ == ptr; } - bool operator!=(std::nullptr_t ptr) const { return ptr_ != ptr; } - - T* get() const { return ptr_; } - - T* operator->() const { - CHECK(ptr_ != nullptr); // Crash OK - return ptr_; - } - - T& operator*() const { - CHECK(ptr_ != nullptr); // Crash OK - return *ptr_; - } - - T* release() { - T* ptr = ptr_; - ptr_ = nullptr; - return ptr; - } - - // Resets the RefPtr from a pointer. This is an ownership transfer operation, - // and the caller has to own a valid reference to ptr (unless ptr is nullptr). - void reset(T* ptr = nullptr) { - Release(ptr_); - ptr_ = ptr; - } - - private: - static void Release(T* ptr) { - if (ptr != nullptr) { - ptr->Unref(); - } - } - - static void Acquire(T* ptr) { - if (ptr != nullptr) { - ptr->Ref(); - } - } - - T* ptr_ = nullptr; -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_XRT_XRT_REFPTR_H_ diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc deleted file mode 100644 index 4189c5e7bc1063..00000000000000 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ /dev/null @@ -1,679 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Classes for allocating XLA literals in device memory and managing handles -// that refer to them. - -#include "tensorflow/compiler/xrt/xrt_state.h" - -#include -#include -#include -#include -#include - -#include "absl/memory/memory.h" -#include "xla/service/backend.h" -#include "xla/status_macros.h" -#include "tensorflow/compiler/xrt/xrt_memory_manager.h" - -namespace tensorflow { -namespace { - -// Helper typedef to make ShapeTree ForEach helper lambda signatures more -// readable. They need a type of const T& where in this case T is the -// following pointer. 
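Editor's note: the key contract of the RefPtr class above is that constructing from a raw pointer transfers the caller's existing reference, while copying adds one. Below is a cut-down, self-contained illustration of those ownership rules; Counted and MiniRefPtr are illustrative stand-ins, not core::RefCounted or the deleted RefPtr.

#include <cassert>
#include <iostream>

// Minimal intrusive ref-counted object, standing in for core::RefCounted.
class Counted {
 public:
  Counted() : refs_(1) {}  // Starts with one reference, like core::RefCounted.
  void Ref() { ++refs_; }
  void Unref() {
    if (--refs_ == 0) delete this;
  }
  int refs() const { return refs_; }

 private:
  ~Counted() { std::cout << "destroyed\n"; }
  int refs_;
};

// Cut-down smart pointer with the same ownership rules as RefPtr above:
// construction from a raw pointer *transfers* a reference, copying adds one.
template <typename T>
class MiniRefPtr {
 public:
  MiniRefPtr(T* ptr = nullptr) : ptr_(ptr) {}  // Ownership transfer.
  MiniRefPtr(const MiniRefPtr& o) : ptr_(o.ptr_) { if (ptr_) ptr_->Ref(); }
  ~MiniRefPtr() { if (ptr_) ptr_->Unref(); }
  MiniRefPtr& operator=(const MiniRefPtr&) = delete;  // Kept minimal for the sketch.
  T* get() const { return ptr_; }

 private:
  T* ptr_;
};

int main() {
  MiniRefPtr<Counted> a(new Counted());  // The new object's initial reference moves to a.
  assert(a.get()->refs() == 1);
  {
    MiniRefPtr<Counted> b = a;           // Copy bumps the count.
    assert(a.get()->refs() == 2);
  }                                      // b released its reference.
  assert(a.get()->refs() == 1);
}                                        // a releases the last reference; "destroyed" is printed.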
-typedef XRTBufferAllocation* XRTBufferAllocationPtr; - -class BufferAllocStats { - public: - struct Stats { - int64_t count = 0; - int64_t size = 0; - }; - - Stats ReportAlloc(int64_t device, int64_t msize) { - mutex_lock lock(lock_); - Stats* device_stats = &stats_[device]; - device_stats->count += 1; - device_stats->size += msize; - return *device_stats; - } - - Stats ReportFree(int64_t device, int64_t msize) { - mutex_lock lock(lock_); - Stats* device_stats = &stats_[device]; - device_stats->count -= 1; - device_stats->size -= msize; - return *device_stats; - } - - private: - mutable mutex lock_; - std::map stats_; -}; - -BufferAllocStats* GetAllocStats() { - static BufferAllocStats* stats = new BufferAllocStats(); - return stats; -} - -Status AllocateScopedShapedBuffer( - XRTMemoryManager* memory_manager, xla::Backend* backend, int device_ordinal, - const xla::Shape& shape, std::unique_ptr* buffer, - se::DeviceMemoryAllocator* allocator) { - auto transfer_manager = backend->transfer_manager(); - TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); - - // XLA may use a different representation on device than the representation on - // the host. XLA does not document any contract for the relationship between - // these representations :/ Right now, the device shape is always a superset - // of the host shape, meaning that for any valid ShapeIndex in the host shape - // that ShapeIndex is also valid in the device shape, but not vice versa. In - // particular, some host-side types are rewritten to be tuples. We rely on - // this property when making sub-buffers, because we assume that if the client - // requests the host-shape sub-buffer at index i, that will correspond to the - // right device-shape sub-buffer at the same index. - xla::Shape on_device_shape = transfer_manager->HostShapeToDeviceShape(shape); - VLOG(3) << "Allocating literal buffer: host_shape=" - << xla::ShapeUtil::HumanStringWithLayout(shape) << " device_shape=" - << xla::ShapeUtil::HumanStringWithLayout(on_device_shape); - - // The ScopedShapedBuffer frees the buffers that have so far been allocated if - // it goes out of scope. That's useful if we return early as the result of an - // error allocating one of the later buffers. - *buffer = std::make_unique( - shape, on_device_shape, allocator, device_ordinal); - for (auto& index_to_buffer : (*buffer)->buffers()) { - const xla::Shape& subshape = - xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first); - uint64 size = transfer_manager->GetByteSizeRequirement(subshape); - TF_ASSIGN_OR_RETURN( - se::OwningDeviceMemory buffer, - memory_manager->Allocate(backend, device_ordinal, size, allocator)); - // Move our buffer into shaped_buffer, which takes ownership of it. 
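Editor's note: the allocation loop above leans on ScopedShapedBuffer for error safety: every sub-buffer allocated so far is owned by the scoped object, so an early return after a failed later allocation cannot leak. Below is a standalone sketch of that pattern with a plain byte allocator; FakeAllocator, ScopedBufferSet, and the other names are illustrative only.

#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <optional>
#include <vector>

// Toy allocator that can be told to fail after N successful allocations.
struct FakeAllocator {
  int allocations_before_failure = 2;
  void* Allocate(std::size_t size) {
    if (allocations_before_failure-- <= 0) return nullptr;
    return std::malloc(size);
  }
  void Deallocate(void* p) { std::free(p); }
};

// Owns a set of buffers; frees them on destruction unless Release() is called.
// This mirrors the role ScopedShapedBuffer plays in the loop above.
class ScopedBufferSet {
 public:
  explicit ScopedBufferSet(FakeAllocator* allocator) : allocator_(allocator) {}
  ~ScopedBufferSet() {
    for (void* p : buffers_) allocator_->Deallocate(p);
  }
  void Add(void* p) { buffers_.push_back(p); }
  std::vector<void*> Release() { return std::move(buffers_); }  // Hand ownership to the caller.

 private:
  FakeAllocator* allocator_;
  std::vector<void*> buffers_;
};

std::optional<std::vector<void*>> AllocateAll(FakeAllocator* allocator,
                                              const std::vector<std::size_t>& sizes) {
  ScopedBufferSet scoped(allocator);
  for (std::size_t size : sizes) {
    void* p = allocator->Allocate(size);
    if (p == nullptr) return std::nullopt;  // Early return: scoped frees what was allocated.
    scoped.Add(p);
  }
  return scoped.Release();                  // Success: ownership leaves the scoped set.
}

int main() {
  FakeAllocator allocator;  // Configured to fail on the third allocation.
  auto result = AllocateAll(&allocator, {16, 32, 64});
  std::cout << (result ? "allocated all buffers\n" : "failed, nothing leaked\n");
  if (result) {
    for (void* p : *result) allocator.Deallocate(p);  // On success the caller owns the buffers.
  }
}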
- index_to_buffer.second = buffer.Release(); - VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque() - << " index " << index_to_buffer.first.ToString() << " (" << size - << " bytes)"; - } - - TF_RETURN_IF_ERROR( - transfer_manager->WriteTupleIndexTables(stream.get(), *(buffer->get()))); - - return OkStatus(); -} - -} // namespace - -XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation, - int device_ordinal, - se::DeviceMemoryAllocator* allocator) - : allocation_(allocation), - device_ordinal_(device_ordinal), - allocator_(allocator) { - if (VLOG_IS_ON(2)) { - auto stats = - GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size()); - LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_ - << " count=" << stats.count << " size=" << stats.size; - } -} - -XRTBufferAllocation::~XRTBufferAllocation() { - if (VLOG_IS_ON(2)) { - GetAllocStats()->ReportFree(device_ordinal_, allocation_.size()); - } - // Deallocate explicitly allows allocation_ to be null. - TF_CHECK_OK(allocator_->Deallocate(device_ordinal_, allocation_)); - VLOG(2) << "Freed buffer at " << allocation_.opaque() << " (" - << allocation_.size() << " bytes)"; -} - -const se::DeviceMemoryBase& XRTBufferAllocation::allocation() { - return allocation_; -} - -XRTTupleAllocation::XRTTupleAllocation(int device_ordinal, - se::DeviceMemoryAllocator* allocator, - const xla::Shape& on_host_shape, - const xla::Shape& on_device_shape) - : device_ordinal_(device_ordinal), - allocator_(allocator), - on_host_shape_(on_host_shape), - on_device_shape_(on_device_shape), - buffers_(&on_device_shape_), - pin_count_(0) {} - -XRTTupleAllocation::~XRTTupleAllocation() { ReleaseBuffers(); } - -void XRTTupleAllocation::ReleaseBuffers() { - for (auto& index_buffer : buffers_) { - if (index_buffer.second != nullptr) { - index_buffer.second->Unref(); - index_buffer.second = nullptr; - } - } -} - -/*static*/ Status XRTTupleAllocation::CreateAndTransfer( - const xla::LiteralBase& literal, XRTMemoryManager* memory_manager, - xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation, - se::DeviceMemoryAllocator* allocator) { - auto transfer_manager = backend->transfer_manager(); - std::unique_ptr scoped_buffer; - TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(memory_manager, backend, - device_ordinal, literal.shape(), - &scoped_buffer, allocator)); - TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); - TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - stream.get(), literal, *scoped_buffer)); - - // By releasing the ScopedShapedBuffer we ensure that the underlying storage - // won't be freed when the buffer goes out of scope at the end of this - // call. To avoid a leak, there must be no error-case returns from here until - // the end of the method. 
- auto shaped_buffer = scoped_buffer->release(); - *allocation = new XRTTupleAllocation(device_ordinal, allocator, - shaped_buffer.on_host_shape(), - shaped_buffer.on_device_shape()); - (*allocation) - ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal); - (*allocation)->SetDeviceMemorySize(); - return OkStatus(); -} - -/*static*/ Status XRTTupleAllocation::CreateUninitialized( - const xla::Shape& shape, XRTMemoryManager* memory_manager, - xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation, - se::DeviceMemoryAllocator* allocator) { - std::unique_ptr scoped_buffer; - TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(memory_manager, backend, - device_ordinal, shape, - &scoped_buffer, allocator)); - - // By releasing the ScopedShapedBuffer we ensure that the underlying storage - // won't be freed when the buffer goes out of scope at the end of this - // call. To avoid a leak, there must be no error-case returns from here until - // the end of the method. - auto shaped_buffer = scoped_buffer->release(); - *allocation = new XRTTupleAllocation(device_ordinal, allocator, - shaped_buffer.on_host_shape(), - shaped_buffer.on_device_shape()); - (*allocation) - ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal); - (*allocation)->SetDeviceMemorySize(); - return OkStatus(); -} - -/*static*/ Status XRTTupleAllocation::CreateFromBuffer( - const xla::ShapedBuffer& shaped_buffer, const xla::Shape& on_host_shape, - const xla::Shape& on_device_shape, xla::Backend* backend, - int device_ordinal, XRTTupleAllocation** allocation, - se::DeviceMemoryAllocator* allocator) { - *allocation = new XRTTupleAllocation(device_ordinal, allocator, on_host_shape, - on_device_shape); - (*allocation) - ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal); - (*allocation)->SetDeviceMemorySize(); - return OkStatus(); -} - -/*static*/ Status XRTTupleAllocation::CreateFromBuffer( - const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend, - int device_ordinal, XRTTupleAllocation** allocation, - se::DeviceMemoryAllocator* allocator) { - return CreateFromBuffer(shaped_buffer, shaped_buffer.on_host_shape(), - shaped_buffer.on_device_shape(), backend, - device_ordinal, allocation, allocator); -} - -Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, - xla::MutableLiteralBase* literal) { - mutex_lock lock(lock_); - return literal_ == nullptr ? StoreToLiteral(backend, literal) - : literal->CopyFrom(*literal_); -} - -Status XRTTupleAllocation::StoreToLiteral(xla::Backend* backend, - xla::MutableLiteralBase* literal) { - auto transfer_manager = backend->transfer_manager(); - TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal())); - TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer()); - return transfer_manager->TransferLiteralFromDevice(stream.get(), - shaped_buffer, literal); -} - -Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend, - const xla::Literal& literal) { - if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) { - return errors::InvalidArgument( - "New literal shape not matching the existing one: literal=", - xla::ShapeUtil::HumanStringWithLayout(literal.shape()), - " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape())); - } - mutex_lock lock(lock_); - if (literal_ != nullptr) { - // The allocation is currently swapped out, and we have a host literal for - // its content. Just update the host literal with the new value. 
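Editor's note: WriteLiteral above distinguishes two states: the data either lives in device buffers or, after a swap-out, in a host-side literal, and pinned allocations normally refuse to be swapped. Below is a standalone sketch of that two-state swap machinery using a heap vector as the stand-in for device memory; all names are illustrative and the "swap" is a simple handover rather than the real device transfer.

#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

// Holds data either in "device" storage or, when swapped out, in a host copy.
class SwappableBuffer {
 public:
  explicit SwappableBuffer(std::vector<uint8_t> data)
      : device_(std::make_unique<std::vector<uint8_t>>(std::move(data))) {}

  void Pin() { ++pin_count_; }
  void Unpin() { --pin_count_; }
  bool IsPinned() const { return pin_count_ != 0; }
  bool IsSwappedOut() const { return host_ != nullptr; }

  // Hands the data over to the host-side copy and releases the "device" slot.
  // Pinned buffers are skipped unless swap_pinned is set, as in SwapOut above.
  bool SwapOut(bool swap_pinned = false) {
    if (IsSwappedOut() || (IsPinned() && !swap_pinned)) return false;
    host_ = std::move(device_);
    return true;
  }

  // Re-materialises the "device" buffer from the host copy.
  bool SwapIn() {
    if (!IsSwappedOut()) return false;
    device_ = std::move(host_);
    return true;
  }

 private:
  std::unique_ptr<std::vector<uint8_t>> device_;  // Present when resident.
  std::unique_ptr<std::vector<uint8_t>> host_;    // Present when swapped out.
  int64_t pin_count_ = 0;
};

int main() {
  SwappableBuffer buf({1, 2, 3});
  buf.Pin();
  std::cout << "swap while pinned: " << buf.SwapOut() << "\n";   // 0: refused.
  buf.Unpin();
  std::cout << "swap when unpinned: " << buf.SwapOut() << "\n";  // 1: swapped out.
  std::cout << "swap back in: " << buf.SwapIn() << "\n";         // 1: resident again.
}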
- return literal_->CopyFrom(literal); - } - TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer()); - auto transfer_manager = backend->transfer_manager(); - TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal())); - return transfer_manager->TransferLiteralToDevice(stream.get(), literal, - shaped_buffer); -} - -xla::StatusOr XRTTupleAllocation::SwapOut(xla::Backend* backend, - bool swap_pinned) { - mutex_lock lock(lock_); - if (literal_ == nullptr && (!IsPinned() || swap_pinned)) { - xla::Literal literal(on_host_shape()); - TF_RETURN_IF_ERROR(StoreToLiteral(backend, &literal)); - ReleaseBuffers(); - literal_ = std::make_unique(std::move(literal)); - return true; - } - return false; -} - -xla::StatusOr XRTTupleAllocation::SwapIn( - XRTMemoryManager* memory_manager, xla::Backend* backend, - se::DeviceMemoryAllocator* allocator) { - // We need to call AllocateScopedShapedBuffer() outside the locks, since the - // XRTMemoryManager might end up calling back into the SwapOut() API. - // So we do a quick check before using the IsSwapped() API, and it can happen - // that the allocation becomes swapped in after the check. This means which we - // will end up doing an allocation, and then releasing it soon after (via its - // scoped variables). This is an unlikely scenario (two threads calling - // SwapIn() on the same allocation) though. - if (!IsSwapped()) { - return false; - } - - auto transfer_manager = backend->transfer_manager(); - std::unique_ptr scoped_buffer; - TF_RETURN_IF_ERROR( - AllocateScopedShapedBuffer(memory_manager, backend, device_ordinal(), - on_host_shape(), &scoped_buffer, allocator)); - TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal())); - - mutex_lock lock(lock_); - if (literal_ != nullptr) { - TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - stream.get(), *literal_, *scoped_buffer)); - - auto shaped_buffer = scoped_buffer->release(); - InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal()); - literal_ = nullptr; - return true; - } - return false; -} - -xla::StatusOr XRTTupleAllocation::PinAndSwapIn( - XRTMemoryManager* memory_manager, xla::Backend* backend, - se::DeviceMemoryAllocator* allocator) { - Pin(); - return SwapIn(memory_manager, backend, allocator); -} - -bool XRTTupleAllocation::IsSwapped() const { - mutex_lock lock(lock_); - return literal_ != nullptr; -} - -int64_t XRTTupleAllocation::Pin() { return pin_count_.fetch_add(1); } - -int64_t XRTTupleAllocation::Unpin() { return pin_count_.fetch_sub(1); } - -bool XRTTupleAllocation::IsPinned() const { return pin_count_ != 0; } - -void XRTTupleAllocation::DiscardAllocation( - const xla::ShapeIndex& buffer_index) { - buffers_.element(buffer_index)->DiscardAllocation(); -} - -const xla::Shape& XRTTupleAllocation::on_host_shape() const { - return on_host_shape_; -} - -const xla::Shape& XRTTupleAllocation::on_device_shape() const { - return on_device_shape_; -} - -int XRTTupleAllocation::device_ordinal() const { return device_ordinal_; } - -const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() const { - return buffers_.element({})->allocation(); -} - -/*static*/ Status XRTTupleAllocation::MakeSubBuffer( - XRTTupleAllocation* parent, const xla::ShapeIndex& subshape, - XRTTupleAllocation** allocation, bool alias_parent_allocation) { - TF_ASSIGN_OR_RETURN( - const xla::Shape* host_sub_shape, - xla::ShapeUtil::TryGetSubshape(parent->on_host_shape(), subshape)); - TF_ASSIGN_OR_RETURN( - const xla::Shape* device_sub_shape, - 
xla::ShapeUtil::TryGetSubshape(parent->on_device_shape(), subshape)); - - *allocation = - new XRTTupleAllocation(parent->device_ordinal(), parent->allocator_, - *host_sub_shape, *device_sub_shape); - if (alias_parent_allocation) { - // Copy the subtree of allocations from the parent allocation. - (*allocation)->buffers_.CopySubtreeFrom(parent->buffers_, subshape, {}); - // Increment the refcount on each aliased buffer. - (*allocation) - ->buffers_.ForEachElement( - [](const xla::ShapeIndex& index, - const XRTBufferAllocationPtr& buffer) { buffer->Ref(); }); - } else { - // Find the buffers in the parent allocation that match the subtree, and - // move the parent allocation's buffer over to the new allocation. - (*allocation) - ->buffers_.ForEachMutableElement( - [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) { - // Extend the allocation's index to the parent's frame by adding - // subshape as a prefix. - xla::ShapeIndex parent_index = subshape; - for (int i = 0; i < index.size(); ++i) { - parent_index.push_back(index[i]); - } - *buffer = parent->buffers_.element(parent_index); - *parent->buffers_.mutable_element(parent_index) = nullptr; - }); - } - (*allocation)->SetDeviceMemorySize(); - return OkStatus(); -} - -void XRTTupleAllocation::SetDeviceMemorySize() { - size_t size = 0; - for (auto& index_buffer : buffers_) { - if (index_buffer.second != nullptr) { - size += index_buffer.second->allocation().size(); - } - } - device_memory_size_ = size; -} - -/* static */ Status XRTTupleAllocation::ExpandTreeOfTuples( - const xla::ShapeTree& elements, int device_ordinal, - se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape, - xla::Shape* device_shape) { - // Initialize both host and device shape to be the 'spine' of the new tuple - // shape, given by the shape of the tree of tuples. - *host_shape = elements.shape(); - *device_shape = elements.shape(); - // Now go over the leaves of the tree of tuples, and 'graft' the host/device - // shapes of the allocation at that leaf onto the expanded host/device shapes - // at the leaf position. 
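Editor's note: MakeSubBuffer above offers two modes: aliasing, where parent and child share the buffers and each shared buffer gains a reference, and moving, where the child steals the buffers and the parent's slots are nulled out. Below is a compact sketch of the two modes using shared_ptr as the stand-in for the Ref()/Unref()-counted buffers; the flat string-indexed map and all names are illustrative.

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Stand-in for a device buffer; shared_ptr plays the role of Ref()/Unref().
using Buffer = std::shared_ptr<std::vector<char>>;
// Stand-in for a ShapeTree: maps a flattened shape index to a buffer.
using BufferTree = std::map<std::string, Buffer>;

// Aliasing mode: the child shares every matching buffer with the parent.
BufferTree AliasSubtree(const BufferTree& parent, const std::string& prefix) {
  BufferTree child;
  for (const auto& [index, buffer] : parent) {
    if (index.rfind(prefix, 0) == 0) child[index] = buffer;  // Refcount bumps.
  }
  return child;
}

// Moving mode: the child takes the buffers and the parent slots become null.
BufferTree MoveSubtree(BufferTree& parent, const std::string& prefix) {
  BufferTree child;
  for (auto& [index, buffer] : parent) {
    if (index.rfind(prefix, 0) == 0) child[index] = std::move(buffer);  // Parent slot now null.
  }
  return child;
}

int main() {
  BufferTree parent{{"0/0", std::make_shared<std::vector<char>>(16)},
                    {"0/1", std::make_shared<std::vector<char>>(32)}};
  BufferTree aliased = AliasSubtree(parent, "0/");
  std::cout << "use_count after alias: " << parent["0/0"].use_count() << "\n";            // 2
  BufferTree moved = MoveSubtree(parent, "0/");
  std::cout << "parent slot empty after move: " << (parent["0/0"] == nullptr) << "\n";    // 1
}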
- TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus( - [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) { - if (elements.IsLeaf(index)) { - if (element.allocation == nullptr) { - return errors::InvalidArgument( - "MakeTuple elements has a null internal node at index ", - index.ToString()); - } - if (device_ordinal != element.allocation->device_ordinal() || - allocator != element.allocation->allocator_) { - return errors::InvalidArgument( - "MakeTuple elements must all be allocated on the same device " - "as the destination."); - } - *xla::ShapeUtil::GetMutableSubshape(host_shape, index) = - element.allocation->on_host_shape(); - *xla::ShapeUtil::GetMutableSubshape(device_shape, index) = - element.allocation->on_device_shape(); - } else { - if (element.allocation != nullptr) { - return errors::InvalidArgument( - "MakeTuple elements has a non-null internal node at index ", - index.ToString()); - } - } - return OkStatus(); - })); - return OkStatus(); -} - -/*static*/ Status XRTTupleAllocation::MakeTuple( - XRTMemoryManager* memory_manager, xla::Backend* backend, int device_ordinal, - const xla::ShapeTree& elements, - XRTTupleAllocation** allocation, se::DeviceMemoryAllocator* allocator) { - auto transfer_manager = backend->transfer_manager(); - TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); - - xla::Shape host_shape; - xla::Shape device_shape; - TF_RETURN_IF_ERROR(ExpandTreeOfTuples(elements, device_ordinal, allocator, - &host_shape, &device_shape)); - - // The aliasing is determined below based on whether or not all the inputs are - // released while being transferred. allocation_tmp is a local pointer that is - // copied to *allocation at the end only if the method succeeds. - XRTTupleAllocation* allocation_tmp = new XRTTupleAllocation( - device_ordinal, allocator, host_shape, device_shape); - core::ScopedUnref allocation_unref(allocation_tmp); - // First allocate device memory for the new tuple index tables, one at each - // internal node of the elements tree. Do this in a separate pass into a - // ScopedShapedBuffer so that it's easy to free the newly-allocated memory if - // an allocation fails. Make sure the shape has layout so that the code that - // writes index tables will be happy lower down. - xla::Shape spine_shape = elements.shape(); - xla::LayoutUtil::SetToDefaultLayout(&spine_shape); - auto new_tuple_buffers = std::make_unique( - spine_shape, spine_shape, allocator, device_ordinal); - TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus( - [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) { - if (!elements.IsLeaf(index)) { - const xla::Shape& subshape = - xla::ShapeUtil::GetSubshape(device_shape, index); - uint64 size = transfer_manager->GetByteSizeRequirement(subshape); - TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer, - memory_manager->Allocate(backend, device_ordinal, - size, allocator)); - VLOG(2) << "Allocated buffer at " << buffer->opaque() << " index " - << index.ToString(); - // Move the new buffer into new_tuple_buffers, which takes ownership - // of it. - new_tuple_buffers->set_buffer(std::move(buffer), index); - } - return OkStatus(); - })); - // Transfer from the ScopedShapedBuffer to a ShapedBuffer, which does not own - // the newly-allocated index tables. Right now there's no owner for the new - // index tables, so next we will transfer ownership to the new allocation, - // taking care not to return early on any errors in the meantime. 
- xla::ShapedBuffer tuple_buffers = new_tuple_buffers->release(); - // Now fill in the remaining datastructures. After this ForEachElement - // completes: - // 1) Every leaf element of tuple_buffers will be the root buffer of - // an existing allocation, and every internal element of tuple_buffers - // will be a newly-allocated index table. tuple_buffers does not own any - // of these. - // 2) Every element of allocation_tmp->buffers_ will be a correctly - // constructed - // XRTBufferAllocation wrapping the necessary allocations. For buffers in - // existing allocations there will be a new reference owned by the new - // allocation, and for newly-allocated index tables there will be a - // single reference owned by the new allocation. - elements.ForEachElement([&](const xla::ShapeIndex& index, - const ExpandedTupleInput& element) { - if (elements.IsLeaf(index)) { - allocation_tmp->buffers_.CopySubtreeFrom(element.allocation->buffers_, {}, - index); - tuple_buffers.set_buffer(element.allocation->root_allocation(), index); - if (element.release_allocation_after_use) { - // Transfer the references from element's buffers to the new allocation - // rather than incrementing the refcount. The caller should have - // validated that release_allocation_after_use is false if - // element.allocation appears in more than one leaf. - element.allocation->buffers_.ForEachMutableElement( - [&](const xla::ShapeIndex&, XRTBufferAllocationPtr* buffer) { - *buffer = nullptr; - }); - } else { - // Increment the refcount on each newly-aliased buffer. - element.allocation->buffers_.ForEachElement( - [](const xla::ShapeIndex& index, - const XRTBufferAllocationPtr& buffer) { buffer->Ref(); }); - } - } else { - // This is an internal node of the tuple tree so take ownership of the - // newly-created index table. - *allocation_tmp->buffers_.mutable_element(index) = - new XRTBufferAllocation(tuple_buffers.buffer(index), device_ordinal, - allocator); - } - }); - allocation_tmp->SetDeviceMemorySize(); - // Because the internal nodes of tuple_buffers are exactly the new index - // tables, WriteTupleIndexTables will write only the new index tables and not - // rewrite the index tables for the existing allocations. - TF_RETURN_IF_ERROR( - transfer_manager->WriteTupleIndexTables(stream.get(), tuple_buffers)); - - *allocation = allocation_tmp; - // Get another reference since allocation_tmp will be Unrefed automatically on - // exit. - (*allocation)->Ref(); - return OkStatus(); -} - -bool XRTTupleAllocation::IsExclusiveOwner() const { - for (const auto& index_buffer : buffers_) { - if (index_buffer.second != nullptr && - !index_buffer.second->RefCountIsOne()) { - return false; - } - } - return true; -} - -size_t XRTTupleAllocation::GetDeviceMemorySize() const { - return device_memory_size_; -} - -void XRTTupleAllocation::InitializeFromShapedBuffer( - const xla::ShapedBuffer& shaped_buffer, - se::DeviceMemoryAllocator* allocator, int device_ordinal) { - for (auto& index_buffer : buffers_) { - if (index_buffer.second != nullptr) { - index_buffer.second->Unref(); - } - // Make a reference-counted version of the allocated buffer. 
- index_buffer.second = new XRTBufferAllocation( - shaped_buffer.buffer(index_buffer.first), device_ordinal, allocator); - } -} - -xla::StatusOr XRTTupleAllocation::ToShapedBuffer() { - xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(), - device_ordinal_); - for (const auto& index_buffer : buffers_) { - if (index_buffer.second == nullptr || - (index_buffer.second->allocation().is_null() && - index_buffer.second->allocation().size() > 0)) { - return errors::InvalidArgument("Literal buffer at index ", - index_buffer.first.ToString(), - " has been released"); - } - shaped_buffer.set_buffer(index_buffer.second->allocation(), - index_buffer.first); - } - return std::move(shaped_buffer); -} - -Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source, - const xla::ShapeIndex& source_index, - const xla::ShapeIndex& dest_index) { - XRTBufferAllocation* source_buffer = source.buffers_.element(source_index); - XRTBufferAllocation* dest_buffer = buffers_.element(dest_index); - if (dest_buffer != nullptr) { - // We allow the destination size being zero, because there are cases where - // we are coming in later filling in null/uninitialized device buffers. In - // all other cases, the size of the new buffer must match. - if (source_buffer->allocation().size() != - dest_buffer->allocation().size() && - dest_buffer->allocation().size() != 0) { - return errors::InvalidArgument( - "Source buffer at index ", source_index.ToString(), - " does not match the size of destination buffer at index ", - dest_index.ToString(), ": ", source_buffer->allocation().size(), - " vs ", dest_buffer->allocation().size()); - } - } else { - const xla::Shape& source_subshape = - xla::ShapeUtil::GetSubshape(source.on_device_shape(), source_index); - const xla::Shape& dest_subshape = - xla::ShapeUtil::GetSubshape(on_device_shape(), dest_index); - if (!xla::ShapeUtil::Equal(source_subshape, dest_subshape)) { - return errors::InvalidArgument( - "Source and destination subshapes do not match: source=", - xla::ShapeUtil::HumanStringWithLayout(source_subshape), - " dest=", xla::ShapeUtil::HumanStringWithLayout(dest_subshape)); - } - } - *buffers_.mutable_element(dest_index) = source_buffer; - source_buffer->Ref(); - if (dest_buffer != nullptr) { - // If we handed over the ownership of a buffer in ToExecutionInput(), we - // will be called here on the way back from execution, to alias back the - // buffer at that index. In that case the buffers will be the same. So we - // need to discard the memory at the destination buffer, before releasing - // the reference. 
- if (dest_buffer->allocation().IsSameAs(source_buffer->allocation()) && - dest_buffer != source_buffer) { - dest_buffer->DiscardAllocation(); - } - dest_buffer->Unref(); - } - return OkStatus(); -} - -xla::StatusOr XRTTupleAllocation::ToExecutionInput( - const std::function(const xla::ShapeIndex&)>& - alias_checker) { - xla::ExecutionInput result(on_device_shape(), on_host_shape()); - for (const auto& index_buffer : buffers_) { - if (index_buffer.second == nullptr || - (index_buffer.second->allocation().is_null() && - index_buffer.second->allocation().size() > 0)) { - return errors::InvalidArgument("Literal buffer at index ", - index_buffer.first.ToString(), - " has been released"); - } - TF_ASSIGN_OR_RETURN(bool should_alias, alias_checker(index_buffer.first)); - if (!should_alias) { - result.SetBuffer( - index_buffer.first, - xla::MaybeOwningDeviceMemory(index_buffer.second->allocation())); - } else { - // We keep the ownership of the device memory here. - result.SetUnownedBuffer( - index_buffer.first, - xla::MaybeOwningDeviceMemory(se::OwningDeviceMemory( - index_buffer.second->allocation(), device_ordinal_, allocator_))); - } - } - return std::move(result); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h deleted file mode 100644 index 679071f27eb199..00000000000000 --- a/tensorflow/compiler/xrt/xrt_state.h +++ /dev/null @@ -1,306 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Classes for keeping track of on-device state. - -#ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_ -#define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_ - -#include -#include -#include -#include -#include - -#include "xla/literal.h" -#include "xla/service/backend.h" -#include "xla/service/executable.h" -#include "xla/service/shaped_buffer.h" -#include "xla/shape_util.h" -#include "xla/statusor.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/xla_data.pb.h" -#include "tensorflow/compiler/xrt/xrt_refptr.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -// Cannot include xrt_memory_manager.h here, as it needs to include this file. -class XRTMemoryManager; - -// TODO(misard) make this a Tensor if and when that makes sense. -// A reference-counted wrapper around a buffer allocation. This maps an XLA -// tuple index or a non-tuple XLA shape to a region of device memory. The device -// memory buffer is freed when the reference count drops to zero. 
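Editor's note: ToExecutionInput above wraps each buffer as maybe-owning: aliased inputs donate ownership to the execution, while the rest are passed as borrowed references. Below is a standalone sketch of that owned-versus-borrowed distinction built on std::variant; MaybeOwningMemory and the other names are illustrative, not the xla types used above.

#include <iostream>
#include <memory>
#include <variant>
#include <vector>

using DeviceMemory = std::vector<char>;

// Either owns the memory (and will free it) or merely borrows it.
class MaybeOwningMemory {
 public:
  static MaybeOwningMemory Owned(std::unique_ptr<DeviceMemory> mem) {
    return MaybeOwningMemory(std::move(mem));
  }
  static MaybeOwningMemory Borrowed(DeviceMemory* mem) {
    return MaybeOwningMemory(mem);
  }
  bool owns() const {
    return std::holds_alternative<std::unique_ptr<DeviceMemory>>(mem_);
  }
  DeviceMemory* get() {
    return owns() ? std::get<std::unique_ptr<DeviceMemory>>(mem_).get()
                  : std::get<DeviceMemory*>(mem_);
  }

 private:
  explicit MaybeOwningMemory(std::unique_ptr<DeviceMemory> mem) : mem_(std::move(mem)) {}
  explicit MaybeOwningMemory(DeviceMemory* mem) : mem_(mem) {}
  std::variant<std::unique_ptr<DeviceMemory>, DeviceMemory*> mem_;
};

int main() {
  DeviceMemory resident(64);                       // Stays owned by the allocation.
  auto donated = std::make_unique<DeviceMemory>(128);

  MaybeOwningMemory borrowed = MaybeOwningMemory::Borrowed(&resident);
  MaybeOwningMemory owned = MaybeOwningMemory::Owned(std::move(donated));

  std::cout << "borrowed owns: " << borrowed.owns() << "\n";  // 0
  std::cout << "donated owns:  " << owned.owns() << "\n";     // 1
}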
-class XRTBufferAllocation : public core::RefCounted { - public: - XRTBufferAllocation(const se::DeviceMemoryBase& allocation, - int device_ordinal, se::DeviceMemoryAllocator* allocator); - ~XRTBufferAllocation() override; - - // The region of device memory being wrapped. - const se::DeviceMemoryBase& allocation(); - - void DiscardAllocation() { allocation_ = se::DeviceMemoryBase(); } - - private: - se::DeviceMemoryBase allocation_; - int device_ordinal_; - se::DeviceMemoryAllocator* allocator_; -}; - -// A XRTTupleAllocation represents an allocated memory area on the device. -// New tuples can be created in three ways: by passing a literal in which case -// device memory is allocated and the literal is transferred to that memory; by -// aliasing a sub-shape of an existing tuple-shaped handle; or by aliasing a -// vector of existing handles to create a new tuple. The underlying storage is -// reference-counted. When a handle is released, the reference count of each -// storage buffer is decremented, and buffers with no outstanding references are -// freed. -class XRTTupleAllocation : public core::RefCounted { - public: - ~XRTTupleAllocation() override; - - // Allocates new device memory buffers sufficient to store literal, transfers - // literal to that memory, and returns a XRTTupleAllocation handle to the - // allocated buffers. - static Status CreateAndTransfer(const xla::LiteralBase& literal, - XRTMemoryManager* memory_manager, - xla::Backend* backend, int device_ordinal, - XRTTupleAllocation** allocation, - se::DeviceMemoryAllocator* allocator); - - // Allocates new device memory buffers sufficient to store a tensor of - // the specified shape, and returns a XRTTupleAllocation handle to the - // allocated buffers. The allocated buffers are not initialized. - static Status CreateUninitialized(const xla::Shape& shape, - XRTMemoryManager* memory_manager, - xla::Backend* backend, int device_ordinal, - XRTTupleAllocation** allocation, - se::DeviceMemoryAllocator* allocator); - - // Wraps an existing ShapeBuffer in a new XRTTupleAllocation handle. - static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer, - xla::Backend* backend, int device_ordinal, - XRTTupleAllocation** allocation, - se::DeviceMemoryAllocator* allocator); - - // Same as the CreateFromBuffer() API above, but with the shapes being passed - // as input. This API is used when creating tuple allocations with the output - // of XLA computations which emit dynamic shaped output via the output shape - // table. - static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer, - const xla::Shape& on_host_shape, - const xla::Shape& on_device_shape, - xla::Backend* backend, int device_ordinal, - XRTTupleAllocation** allocation, - se::DeviceMemoryAllocator* allocator); - - // Aliases a sub-shape of parent and returns a XRTTupleAllocation handle - // to the sub-shape. If alias_base_allocation is true, the buffers in the - // sub-shape will be shared between parent and the returned allocation, - // otherwise the overlapping buffers in parent will be replaced by - // nullptr. - static Status MakeSubBuffer(XRTTupleAllocation* parent, - const xla::ShapeIndex& subshape, - XRTTupleAllocation** allocation, - bool alias_parent_allocation); - - // A structure describing a leaf of a tree of tuples to expand. Each leaf - // contains an allocation and indicates whether or not the allocation's handle - // should be freed after incorporating its buffers into the expanded tree. 
- struct ExpandedTupleInput { - RefPtr allocation; - bool release_allocation_after_use; - }; - - // Returns a handle to a new tuple where the subtree of the new tuple at an - // index corresponding to a leaf of 'elements' is constructed from the - // allocation (i.e., a tuple or array) pointed to by that leaf. If - // release_allocation_after_use is false at a leaf, the new tuple will alias - // the input allocation at that leaf, otherwise the input allocation will be - // released. Input allocations may be repeated (appear in more than one leaf) - // in which case the corresponding buffers in the output tuple will alias. If - // an input is repeated, release_input_handle must be false for every leaf - // where that input appears. The latter property is not validated by MakeTuple - // and must be enforced by the caller. - static Status MakeTuple(XRTMemoryManager* memory_manager, - xla::Backend* backend, int device_ordinal, - const xla::ShapeTree& elements, - XRTTupleAllocation** allocation, - se::DeviceMemoryAllocator* allocator); - - // Copies the allocation from device to host and returns it in literal. - Status ToLiteral(xla::Backend* backend, xla::MutableLiteralBase* literal); - - // Write a new literal value to the allocation. - Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal); - - // Stores the content of the tuple allocation into the internal literal, and - // releases all the device buffers. The swap_pinned flag tells whether a - // pinned allocation should be swapped out. It should be false on all cases, - // but during the memory compaction operation from the XRTMemoryManager. - // Returns a boolean telling whether the allocation was swapped out. - xla::StatusOr SwapOut(xla::Backend* backend, bool swap_pinned); - - // Allocates the device memory required to store the tuple value held within - // the internal literal, and transfer the literal value into the device - // memory. Returns a boolean telling whether the allocation was swapped in. - xla::StatusOr SwapIn(XRTMemoryManager* memory_manager, - xla::Backend* backend, - se::DeviceMemoryAllocator* allocator); - - // Pins the allocation first, then swap it in (if it is not already). After - // this API returns, the allocation is pinned and its content on device - // memory. The caller is responsible for releasing the pin-count using the - // Unpin() API. - xla::StatusOr PinAndSwapIn(XRTMemoryManager* memory_manager, - xla::Backend* backend, - se::DeviceMemoryAllocator* allocator); - - // Checks whether the allocation is currently swapped out. - bool IsSwapped() const; - - // Increases the pin-count of this allocation. If the pin-count is greater - // than 0, the allocation cannot be swapped. Returned the pin-count value - // before the increase. - int64_t Pin(); - - // Decreases the pin-count of this allocation. Returned the pin-count value - // before the decrease. - int64_t Unpin(); - - // Checks whether the allocation is currently pinned. - bool IsPinned() const; - - // True if none of the buffers in the allocation are aliased by any other live - // handle. - bool IsExclusiveOwner() const; - - // Retrieves the footprint in terms of device memory, of this allocation. - size_t GetDeviceMemorySize() const; - - // The ordinal of the device holding this tuple. - int device_ordinal() const; - - // Returns the shape of the tuple as seen by the host. - const xla::Shape& on_host_shape() const; - - // Returns the shape of the tuple as stored on the device. 
- const xla::Shape& on_device_shape() const; - - // Returns the buffer pointed to by the root of the tuple. - const se::DeviceMemoryBase& root_allocation() const; - - // Stops managing the storage for the allocation at buffer_index, e.g., - // because it has been aliased to the output buffer of a computation. - void DiscardAllocation(const xla::ShapeIndex& buffer_index); - - // Returns the tree of allocations as a ShapedBuffer. This tree may not have - // the same shape as on_host_shape. - xla::StatusOr ToShapedBuffer(); - - // Aliases the source buffer at source_index into the current tuple allocation - // dest_index. - Status AliasBufferFrom(const XRTTupleAllocation& source, - const xla::ShapeIndex& source_index, - const xla::ShapeIndex& dest_index); - - // Returns the device memory tree of this allocation. If the alias_checker - // function returns true for a given index, an owned device memory is returned - // to the caller. But the tuple allocation cannot release the ownership in - // full, as the execute operation might fail. So we rely on a call to - // AliasBufferFrom() to re-alias back the buffers. This is not great (to say - // the least), but the current aliasing logic relies on - // MaybeOwningDeviceMemory being owned, to detect the fact that the user may - // want to alias a buffer. Unfortunately to do that, it needs to release the - // ownership, which is a problem if the execute will fail. - // This calls for a refactoring of the whole owning/maybe-owning interface to - // introduce a sharing concept (IOW shared_ptr model vs. unique_ptr). - // We'd need something similar to XRTTupleAllocation instead of - // ScopedShapedBuffer, which wants ownership and does not allow sharing. - xla::StatusOr ToExecutionInput( - const std::function(const xla::ShapeIndex&)>& - alias_checker); - - private: - // Creates a new handle with (tuple) shape. - XRTTupleAllocation(int device_ordinal, se::DeviceMemoryAllocator* allocator, - const xla::Shape& on_host_shape, - const xla::Shape& on_device_shape); - - // Inherits the allocations represented in buffer, which must have the same - // shape as buffers_. - void InitializeFromShapedBuffer(const xla::ShapedBuffer& shaped_buffer, - se::DeviceMemoryAllocator* allocator, - int device_ordinal); - - // Releases all the XRTBufferAllocation buffer references and set the - // corresponding shape tree entry to nullptr. - void ReleaseBuffers(); - - // Stores the content of the allocation from device memory to the target host - // literal. - Status StoreToLiteral(xla::Backend* backend, - xla::MutableLiteralBase* literal); - - // Sets the total size of the buffers held within this allocation buffers. - // This API should be called once when an XRTTupleAllocation object is - // created, as the XRTTupleAllocation shapes never change, and hence the - // device memory size. - void SetDeviceMemorySize(); - - // Takes a tree 'elements' where each leaf is an allocation, validates that - // they are all on device_ordinal managed by allocator, and returns in - // host_shape and device_shape the host/device shapes of the expanded tree, - // where at each leaf of elements the shape of the allocation at elements is - // grafted on. - static Status ExpandTreeOfTuples( - const xla::ShapeTree& elements, int device_ordinal, - se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape, - xla::Shape* device_shape); - - // The lock which protects the internal operations of the tuple allocation. Is - // mutable to allow const-like operations to be declared as such. 
- mutable mutex lock_; - - // Location of the memory that is being managed. - const int device_ordinal_; - se::DeviceMemoryAllocator* const allocator_; - - // The shape that the caller thinks the tuple has. - const xla::Shape on_host_shape_; - // The shape that the tuple has on device. Store this explicitly instead of - // using a shape stored in ShapeTree because ShapeTree discards the layout. - const xla::Shape on_device_shape_; - // The tree of reference-counted buffers, which uses on_device_shape_ as its - // shape. - xla::ShapeTree buffers_; - // The footprint of the allocation, when residing on device memory. - size_t device_memory_size_ = 0; - // If the allocation is swapped out, this is the literal storing its content. - std::unique_ptr literal_; - // A pinned allocation is one which cannot be swapped out. If pin_count_ > 0 - // then the allocation is pinned. - std::atomic pin_count_; -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_XRT_XRT_STATE_H_ diff --git a/tensorflow/compiler/xrt/xrt_tpu_device.cc b/tensorflow/compiler/xrt/xrt_tpu_device.cc deleted file mode 100644 index b747c5505e7aa1..00000000000000 --- a/tensorflow/compiler/xrt/xrt_tpu_device.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xrt/xrt_tpu_device.h" - -#include "tensorflow/compiler/jit/xla_device.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/tpu/tpu_configuration.h" - -namespace tensorflow { - -/*static*/ Status XRTTpuDeviceAccessor::GetResourceManager(OpKernelContext* ctx, - ResourceMgr** rm) { - // ctx is unused here, but maintained because XRTGenericDeviceAccessor uses - // it in its GetResourceManager. 
- *rm = GetTPUConfigResourceMgr(); - if (*rm == nullptr) { - return errors::Internal("No Tpu resource manager."); - } - return OkStatus(); -} - -Status XRTTpuDeviceAccessor::ScopedRef::Acquire(int device_ordinal) { - TF_ASSIGN_OR_RETURN(node_context_, - tpu::TpuNodeContext::Create(device_ordinal)); - ordinal_ = device_ordinal; - return OkStatus(); -} - -Status XRTTpuDeviceAccessor::ScopedRef::Acquire(OpKernelContext* ctx) { - const XlaDevice::Metadata* metadata; - TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata)); - return Acquire(metadata->device_ordinal()); -} - -/*static*/ Status XRTTpuDeviceAccessor::InitScopedRef( - OpKernelContext* /*unused ctx*/, int device_ordinal, - ScopedRef* scoped_ref) { - return scoped_ref->Acquire(device_ordinal); -} - -/*static*/ Status XRTTpuDeviceAccessor::InitScopedRef(OpKernelContext* ctx, - ScopedRef* scoped_ref) { - return scoped_ref->Acquire(ctx); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_tpu_device.h b/tensorflow/compiler/xrt/xrt_tpu_device.h deleted file mode 100644 index c2251e76be8f42..00000000000000 --- a/tensorflow/compiler/xrt/xrt_tpu_device.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Classes for keeping track of on-device state for TPUs. - -#ifndef TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_ -#define TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_ - -#include - -#include "xla/client/local_client.h" -#include "xla/stream_executor/tpu/tpu_node_context.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" - -namespace tensorflow { - -// This accessor is used for XLA TPU. It uses the distributed TPU compilation -// cache infrastructure which it accesses via the TPU_SYSTEM resource manager. -class XRTTpuDeviceAccessor { - public: - static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm); - - class ScopedRef { - public: - ScopedRef() = default; - ~ScopedRef() = default; - - ScopedRef(const ScopedRef&) = delete; - ScopedRef& operator=(const ScopedRef&) = delete; - - // Returns the XLA device properties from the TpuNodeContext object - // protected by this ScopedRef. - xla::Backend* backend() { return node_context_->backend(); } - int device_ordinal() { return ordinal_; } - se::DeviceMemoryAllocator* allocator() { - return backend()->memory_allocator(); - } - - private: - // XRTTpuDeviceAccessor::InitScopedRef is the only way to initialize - // ScopedRef. 
- friend class XRTTpuDeviceAccessor; - - Status Acquire(int device_ordinal); - - Status Acquire(OpKernelContext* ctx); - - std::unique_ptr node_context_; - int ordinal_ = 0; - }; - - static Status InitScopedRef(OpKernelContext* ctx, int device_ordinal, - ScopedRef* scoped_ref); - - static Status InitScopedRef(OpKernelContext* ctx, ScopedRef* scoped_ref); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_ diff --git a/tensorflow/compiler/xrt/xrt_util.cc b/tensorflow/compiler/xrt/xrt_util.cc deleted file mode 100644 index 5f1df1ff6dc0eb..00000000000000 --- a/tensorflow/compiler/xrt/xrt_util.cc +++ /dev/null @@ -1,450 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xrt/xrt_util.h" - -#include -#include - -#include -#include -#include -#include - -#include "xla/debug_options_flags.h" -#include "xla/types.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { -namespace { - -mutex nccl_factory_mutex(LINKER_INITIALIZED); -std::shared_ptr* nccl_factory; - -// The ScopedHandles data structure is used in the ExecuteChained() API and its -// task is to track tuple allocation registrations. It is used both the track -// intermediate results of a chained computation, or its final results. Anything -// which is marked to be released, will be released using the XRTMemoryManager -// once the object is destroyed (unless an explicit call to Drop() or Release() -// is made). -class ScopedHandles { - public: - explicit ScopedHandles(RefPtr memory_manager) - : memory_manager_(std::move(memory_manager)) {} - - ~ScopedHandles() { - for (size_t i = 0; i < handles_.size(); ++i) { - if (handles_release_[i]) { - memory_manager_->Release(handles_[i]).IgnoreError(); - } - } - } - - int64_t operator[](size_t index) const { return handles_.at(index); } - - size_t size() const { return handles_.size(); } - - // Adds the given handle at the index position, by marking it releasable - // according to the release argument. If an existing, and to-be-released - // handle already exists at the same index, it will be released. - Status Add(size_t index, int64_t handle, bool release) { - if (index >= handles_.size()) { - handles_.resize(index + 1, XRTMemoryManager::InvalidKey()); - handles_release_.resize(index + 1, false); - } - if (handles_release_[index]) { - Status status = memory_manager_->Release(handles_[index]); - if (!status.ok()) { - if (release) { - memory_manager_->Release(handle).IgnoreError(); - } - return status; - } - } - handles_[index] = handle; - handles_release_[index] = release; - return OkStatus(); - } - - // Adds a to-be-released tuple allocation at the given index. 
- Status Add(size_t index, RefPtr tuple) { - return Add(index, memory_manager_->Register(std::move(tuple)), - /*release=*/true); - } - - // Drops the handle at the given index, and releases it using the - // XRTMemoryManager::Release() if marked as to-be-released. - Status Drop(size_t index) { - if (handles_release_.at(index)) { - TF_RETURN_IF_ERROR(memory_manager_->Release(handles_[index])); - } - Release(index); - return OkStatus(); - } - - // Releases the handle at the given index. The destructor will not use that - // XRTMemoryManager::Release() API on such handle. - int64_t Release(size_t index) { - int64_t handle = handles_.at(index); - handles_[index] = XRTMemoryManager::InvalidKey(); - handles_release_[index] = false; - return handle; - } - - // Looks up the handle stored at the given index, and returns the matching - // tuple allocation. - xla::StatusOr> Lookup(size_t index) const { - return memory_manager_->Lookup(handles_.at(index)); - } - - private: - RefPtr memory_manager_; - std::vector handles_; - std::vector handles_release_; -}; - -bool DebugOptionsPassThroughEnabled() { - const char* env = getenv("TF_XLA_DEBUG_OPTIONS_PASSTHROUGH"); - bool enabled = - env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0); - if (enabled) { - LOG(WARNING) << "Passing through XLA debug options!"; - } else { - LOG(WARNING) << "TF_XLA_DEBUG_OPTIONS_PASSTHROUGH not set, not all options " - "will be retained"; - } - return enabled; -} - -string SafeDebugPath(const string& path) { - if (path.empty() || path.compare(0, 5, "gs://") == 0 || - path.compare(0, 11, "bigstore://") == 0) { - return path; - } - LOG(WARNING) << "Invalid config path (will be dropped): " << path; - return string(); -} - -Status MakeOutput(const RefPtr& output, int64_t index, - RefPtr* result) { - if (index == 0) { - *result = output; - } else { - XRTTupleAllocation* tuple; - TF_RETURN_IF_ERROR( - XRTTupleAllocation::MakeSubBuffer(output.get(), {index - 1}, &tuple, - /*alias_parent_allocation=*/true)); - result->reset(tuple); - } - return OkStatus(); -} - -Status PopulateOpWorkingSet(xla::Backend* backend, - const xrt::XRTChainedExecuteOp& op, - int current_index, const ScopedHandles& outputs, - XRTMemoryManager::WorkingSet* working_set, - se::DeviceMemoryAllocator* allocator) { - for (int i = 0; i < op.inputs_size(); ++i) { - auto& input = op.inputs(i); - if (input.op_index() >= current_index) { - return errors::InvalidArgument( - "Input index ", input.op_index(), - " is above the current position: ", current_index); - } - TF_RETURN_IF_ERROR(working_set->LookupAndPin( - backend, outputs[input.op_index()], allocator)); - } - return OkStatus(); -} - -} // namespace - -void SetNcclUniqueIdFactory(std::shared_ptr factory) { - mutex_lock lock(nccl_factory_mutex); - if (nccl_factory == nullptr) { - nccl_factory = new std::shared_ptr(); - } - *nccl_factory = std::move(factory); -} - -std::shared_ptr GetNcclUniqueIdFactory() { - mutex_lock lock(nccl_factory_mutex); - return nccl_factory != nullptr ? 
*nccl_factory : nullptr; -} - -xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) { - static const bool options_passthrough = DebugOptionsPassThroughEnabled(); - if (options_passthrough) { - return ref_options; - } - xla::DebugOptions options = xla::GetDebugOptionsFromFlags(); - options.set_xla_dump_to(SafeDebugPath(ref_options.xla_dump_to())); - options.set_xla_dump_hlo_as_proto(ref_options.xla_dump_hlo_as_proto()); - options.set_xla_dump_hlo_as_text(ref_options.xla_dump_hlo_as_text()); - options.set_xla_dump_hlo_snapshots(ref_options.xla_dump_hlo_snapshots()); - options.set_xla_dump_hlo_pass_re(ref_options.xla_dump_hlo_pass_re()); - options.set_xla_dump_include_timestamp( - ref_options.xla_dump_include_timestamp()); - options.set_xla_dump_max_hlo_modules(ref_options.xla_dump_max_hlo_modules()); - options.set_xla_dump_enable_mlir_pretty_form( - ref_options.xla_dump_enable_mlir_pretty_form()); - - for (auto& pass : ref_options.xla_disable_hlo_passes()) { - options.add_xla_disable_hlo_passes(pass); - } - return options; -} - -xla::StatusOr> GetComputationInputs( - OpKernelContext* context, const char* input_name) { - OpInputList arg_list; - TF_RETURN_IF_ERROR(context->input_list(input_name, &arg_list)); - // Concatenate all input uids from list of scalars-or-vectors carrying them. - std::vector input_coords; - for (int i = 0; i < arg_list.size(); ++i) { - const Tensor& arg = arg_list[i]; - if (TensorShapeUtils::IsScalar(arg.shape())) { - input_coords.emplace_back(arg.scalar()()); - } else { - TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape())); - auto arg_vec = arg.vec(); - const int64_t num_elts = arg.shape().dim_size(0); - for (int i = 0; i < num_elts; ++i) { - input_coords.emplace_back(arg_vec(i)); - } - } - } - return std::move(input_coords); -} - -bool InputShapeMatches(const xla::Shape& parameter_shape, - const xla::Shape& input_shape) { - auto shape_checker = [&](const xla::Shape& pshape, - const xla::ShapeIndex& index) { - if (pshape.IsArray()) { - TF_ASSIGN_OR_RETURN(const xla::Shape* ishape, - xla::ShapeUtil::TryGetSubshape(input_shape, index)); - if (pshape.rank() != ishape->rank() || - pshape.element_type() != ishape->element_type()) { - return errors::InvalidArgument("Mismatching shapes"); - } - if (pshape.is_static() && !xla::Layout::Equal().IgnoreTiles()( - pshape.layout(), ishape->layout())) { - return errors::InvalidArgument("Mismatching layouts"); - } - for (int64_t dim = 0; dim < pshape.rank(); ++dim) { - if (pshape.is_dynamic_dimension(dim)) { - if (pshape.dimensions(dim) < ishape->dimensions(dim)) { - return errors::InvalidArgument("Mismatching shapes"); - } - } else if (pshape.dimensions(dim) != ishape->dimensions(dim)) { - return errors::InvalidArgument("Mismatching shapes"); - } - } - } - return OkStatus(); - }; - return xla::ShapeUtil::ForEachSubshapeWithStatus(parameter_shape, - shape_checker) - .ok(); -} - -xla::StatusOr>> GetInputTupleAllocations( - const std::vector& input_coords, - XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend, - int64_t num_input_shapes, - const std::function& shape_getter, bool release_inputs, - se::DeviceMemoryAllocator* allocator) { - if (input_coords.size() != num_input_shapes) { - return errors::InvalidArgument( - "Number of inputs does not match executable proto input shapes: ", - input_coords.size(), " vs. 
", num_input_shapes); - } - std::vector> input_tuples; - input_tuples.reserve(input_coords.size()); - for (size_t i = 0; i < input_coords.size(); ++i) { - TF_RETURN_IF_ERROR( - working_set->LookupAndPin(backend, input_coords[i].handle, allocator)); - auto tuple = working_set->PinnedTuples().back(); - if (release_inputs) { - // We are holding a reference to the tuple, so we can safely delete it - // from the resource manager here. - TF_RETURN_IF_ERROR( - working_set->MemoryManager()->Release(input_coords[i].handle)); - VLOG(2) << "Released allocation handle " << input_coords[i].handle; - } - xla::Shape input_shape = shape_getter(i); - if (!InputShapeMatches(input_shape, tuple->on_host_shape())) { - return errors::InvalidArgument( - "Run-time shape mismatch for XRTExecute argument[", i, "] (", - input_coords[i].handle, "). Expected ", input_shape.DebugString(), - "; got ", tuple->on_host_shape().DebugString()); - } - if (input_coords[i].index.empty()) { - input_tuples.emplace_back(std::move(tuple)); - } else { - XRTTupleAllocation* sub_tuple; - TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( - tuple.get(), input_coords[i].index, &sub_tuple, - /*alias_parent_allocation=*/true)); - input_tuples.emplace_back(sub_tuple); - } - } - return std::move(input_tuples); -} - -Status RebuildOutputAliases( - const RefPtr& output_tuple, - absl::Span> input_tuples, - const xla::HloInputOutputAliasConfig& input_output_alias) { - auto alias_function = - [&](const xla::ShapeIndex& output_index, - const xla::HloInputOutputAliasConfig::Alias& alias) -> Status { - TF_RET_CHECK(alias.parameter_number < input_tuples.size()); - return output_tuple->AliasBufferFrom(*input_tuples[alias.parameter_number], - alias.parameter_index, output_index); - }; - return input_output_alias.ForEachAliasWithStatus(alias_function); -} - -xla::StatusOr> GetArgumentsBuffers( - const xla::HloInputOutputAliasConfig& input_output_alias, - absl::Span> input_tuples, - const std::vector& input_is_dynamic, bool release_inputs) { - auto is_dynamic = [&](size_t arg) { - return arg < input_is_dynamic.size() && input_is_dynamic[arg]; - }; - std::vector arguments; - // Don't alias dynamic input -- Due to the underlying implementation, - // aliased inputs have two owners: XRTAllocation and return value of - // this function. If an argument is dynamic and the ownership is - // released to output of this function, TPUExecute will free it and - // reallocate a new one, which creates a double freeing issue where - // XRTAllocation also attempts to release the buffer. 
- bool alias_outputs = release_inputs && input_tuples.size() == 1 && - input_tuples[0]->IsExclusiveOwner() && !is_dynamic(0); - arguments.reserve(input_tuples.size()); - for (int64_t i = 0; i < input_tuples.size(); ++i) { - auto alias_checker = - [&](const xla::ShapeIndex& index) -> xla::StatusOr { - if (input_output_alias.ParameterHasAlias(i, index)) { - TF_RET_CHECK(!is_dynamic(i)); - return true; - } - return alias_outputs; - }; - TF_ASSIGN_OR_RETURN(xla::ExecutionInput exec_input, - input_tuples[i]->ToExecutionInput(alias_checker)); - arguments.emplace_back(std::move(exec_input)); - } - return std::move(arguments); -} - -Status CreateExecuteOutput(OpKernelContext* context, - XRTMemoryManager* memory_manager, - RefPtr output_tuple, - bool return_exploded_tuple) { - if (return_exploded_tuple && output_tuple->on_host_shape().IsTuple()) { - int64_t tuple_element_count = - xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape()); - Tensor* output_tensor; - TF_RETURN_IF_ERROR(context->allocate_output( - 0, TensorShape({tuple_element_count}), &output_tensor)); - - for (int64_t i = 0; i < tuple_element_count; ++i) { - XRTTupleAllocation* suballocation; - TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( - output_tuple.get(), {i}, &suballocation, - /*alias_parent_allocation=*/false)); - output_tensor->vec()(i) = - memory_manager->Register(suballocation); - } - } else { - Tensor* output_tensor; - TF_RETURN_IF_ERROR( - context->allocate_output(0, TensorShape({}), &output_tensor)); - output_tensor->scalar()() = - memory_manager->Register(std::move(output_tuple)); - } - return OkStatus(); -} - -Status ExecuteChained(OpKernelContext* context, - const RefPtr& memory_manager, - xla::Backend* backend, int device_ordinal, - const xrt::XRTChainedExecutePlan& plan, - const xrt::XRTChainedExecuteConfig& config, - const ChainedExecuteFn& execute_op, - se::DeviceMemoryAllocator* allocator) { - // Create the vector which tracks the uses of the intermediate chained - // operations outputs. - std::vector uses(plan.ops_size(), 0); - for (auto& op : plan.ops()) { - for (auto& input : op.inputs()) { - uses[input.op_index()] += 1; - } - } - - ScopedHandles outputs(memory_manager); - ScopedHandles results(memory_manager); - for (int i = 0; i < plan.ops_size(); ++i) { - auto& op = plan.ops(i); - if (op.op_oneof_case() == xrt::XRTChainedExecuteOp::kDataHandle) { - // This operation is a device data load. Set the handle as output and - // leave the release flag off, since this is not an intermediate output. - TF_RETURN_IF_ERROR(outputs.Add(i, op.data_handle(), /*release=*/false)); - } else if (op.op_oneof_case() == - xrt::XRTChainedExecuteOp::kComputationHandle) { - // This is an XRT execute operation, forward to the device specific - // handler. Populating the working set makes sure the input allocations - // for this execute operations are pinned to device memory. - XRTMemoryManager::WorkingSet working_set(memory_manager); - TF_RETURN_IF_ERROR(PopulateOpWorkingSet(backend, op, i, outputs, - &working_set, allocator)); - TF_ASSIGN_OR_RETURN(auto tuple, - execute_op(op, working_set.PinnedTuples())); - TF_RETURN_IF_ERROR(outputs.Add(i, std::move(tuple))); - } else { - return errors::InvalidArgument( - "Undefined operation kind at post-order position ", i); - } - // If the result of this chained operation is an output result, feed the - // results at the desired position. 
- for (auto& output : op.outputs()) { - TF_ASSIGN_OR_RETURN(auto tuple, outputs.Lookup(i)); - RefPtr result; - TF_RETURN_IF_ERROR(MakeOutput(tuple, output.output_index(), &result)); - TF_RETURN_IF_ERROR(results.Add(output.result_index(), std::move(result))); - } - // Drop intermediate results which have no more users. - for (auto& input : op.inputs()) { - uses[input.op_index()] -= 1; - if (uses[input.op_index()] == 0) { - TF_RETURN_IF_ERROR(outputs.Drop(input.op_index())); - } - } - } - - Tensor* output_tensor; - TF_RETURN_IF_ERROR(context->allocate_output( - 0, TensorShape({static_cast(results.size())}), &output_tensor)); - for (size_t i = 0; i < results.size(); ++i) { - output_tensor->vec()(i) = results.Release(i); - } - return OkStatus(); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_util.h b/tensorflow/compiler/xrt/xrt_util.h deleted file mode 100644 index a9f68d676efa6b..00000000000000 --- a/tensorflow/compiler/xrt/xrt_util.h +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Utility functions in support of the XRT API. - -#ifndef TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ -#define TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ - -#include -#include -#include -#include -#include - -#include "xla/hlo/ir/hlo_input_output_alias_config.h" -#include "xla/service/backend.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/statusor.h" -#include "xla/xla.pb.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/compiler/xrt/xrt_memory_manager.h" -#include "tensorflow/compiler/xrt/xrt_refptr.h" -#include "tensorflow/compiler/xrt/xrt_state.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { - -// Factory class which creates NCCL unique IDs based on the replicas -// participating to a given communication. This is only used for GPU backends. -struct NcclUniqueIdFactory { - virtual ~NcclUniqueIdFactory() = default; - - // Generates the NCCL unique ID for the given set of replica IDs. - virtual std::string GetUniqueId(absl::Span replicas) = 0; -}; - -void SetNcclUniqueIdFactory(std::shared_ptr factory); - -std::shared_ptr GetNcclUniqueIdFactory(); - -struct InputCoords { - explicit InputCoords(int64_t handle) : handle(handle) {} - InputCoords(int64_t handle, xla::ShapeIndex index) - : handle(handle), index(std::move(index)) {} - - int64_t handle = 0; - xla::ShapeIndex index; -}; - -// Filters the debug options provided as argument according to the value of the -// TF_XLA_DEBUG_OPTIONS_PASSTHROUGH environment variable. If such variable is -// set to "1" or "true", the debug options will be returned as is. Otherwise -// only a subset of them will be set in the returned ones, and all the paths -// contained in it, will be limited to gs:// and bigstore:// ones. 
-xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options); - -// Populates the input_coords with a list of input coordinates from a input_name -// op argument. -xla::StatusOr> GetComputationInputs( - OpKernelContext* context, const char* input_name); - -bool InputShapeMatches(const xla::Shape& parameter_shape, - const xla::Shape& input_shape); - -xla::StatusOr>> GetInputTupleAllocations( - const std::vector& input_coords, - XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend, - int64_t num_input_shapes, - const std::function& shape_getter, bool release_inputs, - se::DeviceMemoryAllocator* allocator); - -Status RebuildOutputAliases( - const RefPtr& output_tuple, - absl::Span> input_tuples, - const xla::HloInputOutputAliasConfig& input_output_alias); - -xla::StatusOr> GetArgumentsBuffers( - const xla::HloInputOutputAliasConfig& input_output_alias, - absl::Span> input_tuples, - const std::vector& input_is_dynamic, bool release_inputs); - -// Create the XRT execute output tensor given the computation result -// (output_tuple). The return_exploded_tuple tells whether a tuple result should -// be returned as vector of handles representing each tuple child. -Status CreateExecuteOutput(OpKernelContext* context, - XRTMemoryManager* memory_manager, - RefPtr output_tuple, - bool return_exploded_tuple); - -// Drives the XRT chained computation execution given the supplied core execute -// function. -using ChainedExecuteFn = - std::function>( - const xrt::XRTChainedExecuteOp&, - absl::Span>)>; -Status ExecuteChained(OpKernelContext* context, - const RefPtr& memory_manager, - xla::Backend* backend, int device_ordinal, - const xrt::XRTChainedExecutePlan& plan, - const xrt::XRTChainedExecuteConfig& config, - const ChainedExecuteFn& execute_op, - se::DeviceMemoryAllocator* allocator); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 8491b7fbf2f2a2..b90c7cbbb5cc44 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -121,7 +121,6 @@ package( "//tensorflow_models:__subpackages__", ], features = if_google([ - "-layering_check", "-parse_headers", ]), licenses = ["notice"], @@ -1310,6 +1309,7 @@ cc_library( ], hdrs = [":lib_internal_public_headers"], copts = tf_copts(), + features = ["-layering_check"], deps = tf_additional_lib_deps() + [ ":core_stringpiece", ":lib_proto_parsing", @@ -1456,6 +1456,7 @@ cc_library( }) + # The TF proto implementations that we will statically link here. 
[ + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc_impl", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc_impl", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc_impl", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc_impl", @@ -1641,6 +1642,7 @@ tf_cuda_library( ], hdrs = [":framework_internal_public_headers"], copts = tf_copts(), + features = ["-layering_check"], linkopts = select({ "//tensorflow:freebsd": ["-lm"], "//tensorflow:windows": [], @@ -1770,7 +1772,11 @@ tf_cuda_library( ":protos_all_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", ], ) diff --git a/tensorflow/core/api_def/base_api/api_def_GlobalIterId.pbtxt b/tensorflow/core/api_def/base_api/api_def_GlobalIterId.pbtxt new file mode 100644 index 00000000000000..7ec4d4db81f96c --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_GlobalIterId.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "GlobalIterId" + visibility: HIDDEN +} \ No newline at end of file diff --git a/tensorflow/core/api_def/base_api/api_def_ListSnapshotChunksDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ListSnapshotChunksDataset.pbtxt new file mode 100644 index 00000000000000..83bce65aa59919 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ListSnapshotChunksDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ListSnapshotChunksDataset" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_GlobalIterId.pbtxt b/tensorflow/core/api_def/python_api/api_def_GlobalIterId.pbtxt new file mode 100644 index 00000000000000..7ec4d4db81f96c --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_GlobalIterId.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "GlobalIterId" + visibility: HIDDEN +} \ No newline at end of file diff --git a/tensorflow/core/build_defs.bzl b/tensorflow/core/build_defs.bzl index b9952c2214524f..9d948028278ffd 100644 --- a/tensorflow/core/build_defs.bzl +++ b/tensorflow/core/build_defs.bzl @@ -4,12 +4,18 @@ load("//third_party/bazel_rules/rules_python/python:py_binary.bzl", "py_binary") def _tf_core_transition_impl(settings, attr): _ignore = (settings, attr) # @unused - return {"@local_tsl//tsl/framework/contraction:disable_onednn_contraction_kernel": True} + return { + "@local_tsl//tsl/framework/contraction:disable_onednn_contraction_kernel": True, + "//tensorflow/compiler/mlir/python:disable_mlir": True, + } _tf_core_transition = transition( implementation = _tf_core_transition_impl, inputs = [], - outputs = ["@local_tsl//tsl/framework/contraction:disable_onednn_contraction_kernel"], + outputs = [ + "@local_tsl//tsl/framework/contraction:disable_onednn_contraction_kernel", + "//tensorflow/compiler/mlir/python:disable_mlir", + ], ) def _py_binary_tf_core_impl(ctx): diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index bf5b15eebdc72f..cfec2624420c3f 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -131,14 +131,12 @@ cc_library( srcs = ["collective_test_util.cc"], hdrs = ["collective_test_util.h"], copts = tf_copts(), - features = ["-layering_check"], deps = [ ":device_resolver_local", 
":process_util", "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:session_options", "//tensorflow/core:testlib", "//tensorflow/core/framework:allocator", "//tensorflow/core/framework:device_attributes_proto_cc", @@ -146,6 +144,8 @@ cc_library( "//tensorflow/core/nccl:collective_communicator", "//tensorflow/core/platform:refcount", "//tensorflow/core/platform:status", + "//tensorflow/core/platform:unbounded_work_queue", + "@com_google_absl//absl/synchronization", ], ) @@ -329,7 +329,6 @@ cc_library( srcs = ["all_to_all.cc"], hdrs = ["all_to_all.h"], copts = tf_copts(), - features = ["-layering_check"], deps = [ ":base_collective_executor", ":collective_rma_local", @@ -341,7 +340,7 @@ cc_library( ":process_util", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/platform:blocking_counter", ], alwayslink = 1, ) @@ -385,13 +384,13 @@ cc_library( srcs = ["buf_rendezvous.cc"], hdrs = ["buf_rendezvous.h"], copts = tf_copts(), - features = ["-layering_check"], deps = [ ":device", ":device_mgr", ":process_util", "//tensorflow/core:framework", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) @@ -462,12 +461,13 @@ cc_library( srcs = ["collective_param_resolver_local.cc"], hdrs = ["collective_param_resolver_local.h"], copts = tf_copts(), - features = ["-layering_check"], deps = [ ":device_mgr", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", ], ) @@ -1373,6 +1373,7 @@ cc_library( ":bfc_allocator", ":pool_allocator", "//tensorflow/core:lib", + "//tensorflow/core/util:env_var", "//tensorflow/core/util:onednn_env_vars", ], ) @@ -1388,6 +1389,7 @@ cc_library( deps = [ ":function", ":optimization_registry", + ":process_util", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:graph", diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index edeca472fa9b4a..376b6d81351458 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -1,5 +1,3 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup") -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "if_zendnn", @@ -9,6 +7,8 @@ load( "tf_cuda_library", "tf_mkl_kernel_library", ) +load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//third_party/mkl:build_defs.bzl", "if_mkl", @@ -119,6 +119,8 @@ tf_cuda_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", @@ -301,6 +303,8 @@ tf_cuda_library( "//tensorflow/core/platform:platform_port", "//tensorflow/core/util:managed_stack_trace", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", diff --git a/tensorflow/core/common_runtime/eager/attr_builder_test.cc 
b/tensorflow/core/common_runtime/eager/attr_builder_test.cc index 185acbf9463428..1baf0ddcdceb48 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder_test.cc +++ b/tensorflow/core/common_runtime/eager/attr_builder_test.cc @@ -162,10 +162,12 @@ TEST(AttrBuilder, BuildNodeDef_Modified) { AttrBuilder a("MatMul"); a.Set("transpose_a", true); a.Set("transpose_b", false); + a.Set("grad_x", true); + a.Set("grad_y", false); a.NumInputs(2); const NodeDef& node_def = a.BuildNodeDef(); - EXPECT_EQ(node_def.attr().size(), 2); + EXPECT_EQ(node_def.attr().size(), 6); a.Set("new_attr", 15); a.NumInputs(3); @@ -173,11 +175,15 @@ TEST(AttrBuilder, BuildNodeDef_Modified) { const NodeDef& node_def2 = a.BuildNodeDef(); auto attrs = node_def2.attr(); - EXPECT_EQ(attrs.size(), 3); + EXPECT_EQ(attrs.size(), 7); ASSERT_NE(attrs.find("transpose_a"), attrs.end()); EXPECT_EQ(attrs.find("transpose_a")->second.b(), true); ASSERT_NE(attrs.find("transpose_b"), attrs.end()); EXPECT_EQ(attrs.find("transpose_b")->second.b(), false); + ASSERT_NE(attrs.find("grad_x"), attrs.end()); + EXPECT_EQ(attrs.find("grad_x")->second.b(), true); + ASSERT_NE(attrs.find("grad_y"), attrs.end()); + EXPECT_EQ(attrs.find("grad_y")->second.b(), false); ASSERT_NE(attrs.find("new_attr"), attrs.end()); EXPECT_EQ(attrs.find("new_attr")->second.i(), 15); } diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 1ba7291a3e07be..a7306be3b8b431 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -26,6 +26,8 @@ limitations under the License. // clang-format off // Required for IS_MOBILE_PLATFORM +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" @@ -1008,6 +1010,42 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef, return OkStatus(); } +Status EagerContext::AddComponentFunction(const FunctionDef& fdef, + const FunctionDefLibrary& library) { + { + mutex_lock l(cache_mu_); + auto iter = component_function_libraries_.find(fdef.signature().name()); + if (iter == component_function_libraries_.end()) { + // TODO(mrry): For any functions in the main function library, consider + // deduplicating them here. + auto component_func_lib_def = std::make_unique( + OpRegistry::Global(), library); + TF_RETURN_IF_ERROR(component_func_lib_def->AddFunctionDef(fdef, {})); + component_function_libraries_.insert( + {fdef.signature().name(), std::move(component_func_lib_def)}); + } else { + // The function has been registered before. If the function is different, + // we error out. + const FunctionDef* prev_fdef = + iter->second->Find(fdef.signature().name()); + if (prev_fdef == nullptr) { + return absl::InternalError( + absl::StrCat("Component function: ", fdef.signature().name(), + " is in the cache but not in the library")); + } + if (!FunctionDefsEqual(fdef, *prev_fdef)) { + return absl::InvalidArgumentError(absl::StrCat( + "Attempting to add a duplicate function with name: ", + fdef.signature().name(), " where the previous and current ", + "definitions differ. 
Previous definition: ", + prev_fdef->DebugString(), + " and current definition: ", fdef.DebugString())); + } + } + } + return OkStatus(); +} + const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) { return func_lib_def_.Find(function_name); } diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 3aa9a5a3d03890..075849fae3304b 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -251,6 +251,14 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { bool add_to_local_only = false, const StackTracesMap& stack_traces = {}); + // Adds a component function (i.e. containing a subgraph of a multi-process + // function) implemented as `fdef`. + // + // REQUIRES: `library` must contain all functions reachable from `fdef`. It + // should not contain `fdef` itself. + Status AddComponentFunction(const FunctionDef& fdef, + const FunctionDefLibrary& library); + const FunctionDef* GetFunctionDef(const string& function_name); std::vector ListFunctionNames() override; @@ -385,6 +393,16 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { FunctionLibraryDefinition* FuncLibDef() override { return &func_lib_def_; } + FunctionLibraryDefinition* GetComponentFunctionFunctionLibraryDefinition( + const string& function_name) { + tf_shared_lock lock(cache_mu_); + auto iter = component_function_libraries_.find(function_name); + if (iter != component_function_libraries_.end()) { + return iter->second.get(); + } + return nullptr; + } + #if !defined(IS_MOBILE_PLATFORM) // Assign the EagerClient pointer to `client` based on the given device / task // name, and increment the refcount of the client. The reference ownership is @@ -756,6 +774,9 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { kernel_cache_ TF_GUARDED_BY(cache_mu_); std::unordered_map registered_functions_ TF_GUARDED_BY(cache_mu_); + + std::unordered_map> + component_function_libraries_ TF_GUARDED_BY(cache_mu_); absl::flat_hash_map device_cache_ TF_GUARDED_BY(device_cache_mu_); std::unordered_map>> diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 545585750b6abb..58888afece8bd1 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_operation.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" @@ -27,6 +29,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/custom_device.h" #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/host_info.h" @@ -333,16 +336,24 @@ Status EagerOperation::Reset( if (!is_function) { const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get(); colocation_exempt_ = exempt_ops.find(op) != exempt_ops.end(); - TF_RETURN_IF_ERROR(OpDefForOp(op, &op_def_)); - } else if (!remote && !ctx_.FindFunctionByName(op)) { - return errors::NotFound( - "'", op, - "' is neither a type of a primitive operation nor a name " - "of a function registered in binary running on ", - port::Hostname(), - ". Make sure the operation or function is " - "registered in the binary running in this process."); + } else if (!remote) { + const FunctionLibraryDefinition* func_lib_def; + if (eager_func_params.has_value() && + eager_func_params.value().func_lib_def_override != nullptr) { + func_lib_def = eager_func_params.value().func_lib_def_override; + } else { + func_lib_def = ctx_.FuncLibDef(); + } + if (func_lib_def->Find(op) == nullptr) { + return absl::NotFoundError(absl::StrCat( + "'", op, + "' is neither a type of a primitive operation nor a name " + "of a function registered in binary running on ", + port::Hostname(), + ". Make sure the operation or function is " + "registered in the binary running in this process.")); + } } attrs_.Reset(op); stack_trace_.reset(); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index ccde391e8dc53d..3ddf91c5ed5f52 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/managed_stack_trace.h" @@ -153,6 +154,23 @@ class EagerOperation : public ImmediateExecutionOperation { tensorflow::EagerContext& EagerContext() const { return ctx_; } + const FunctionLibraryDefinition* FuncLibDef() const { + if (eager_func_params_.has_value() && + eager_func_params_.value().func_lib_def_override) { + return eager_func_params_.value().func_lib_def_override; + } else { + return ctx_.FuncLibDef(); + } + } + + const FunctionDef* GetFunctionDef() const { + if (is_function_) { + return FuncLibDef()->Find(attrs_.op_name()); + } else { + return nullptr; + } + } + AttrBuilder* MutableAttrs() { return &attrs_; } const AttrBuilder& Attrs() const { return attrs_; } diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index daaab604d2a01d..0d68aac0cff554 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -291,8 +291,7 @@ Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) { const auto& node_def = op->MutableAttrs()->BuildNodeDef(); const OpDef* op_def = nullptr; - const FunctionDef* function_def = - op->EagerContext().FuncLibDef()->Find(op->Name()); + const FunctionDef* function_def = op->GetFunctionDef(); if (function_def != nullptr) { op_def = &(function_def->signature()); } else { @@ -420,8 +419,7 @@ Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx, return OkStatus(); } - const FunctionDef* function_def = - ctx.pflr()->GetFunctionLibraryDefinition()->Find(op->Name()); + const FunctionDef* function_def = op->GetFunctionDef(); if (function_def == nullptr) { return errors::NotFound("Failed to find function '", op->Name(), "'"); } @@ -445,8 +443,7 @@ Status HasTPUReplication(const EagerOperation& op, const EagerContext& ctx, return OkStatus(); } - const FunctionDef* function_def = - ctx.pflr()->GetFunctionLibraryDefinition()->Find(op.Name()); + const FunctionDef* function_def = op.GetFunctionDef(); if (function_def == nullptr) { return errors::NotFound("Failed to find function '", op.Name(), "'"); } @@ -513,11 +510,12 @@ Status HasNestedJitCompile(const EagerOperation& op, const EagerContext& ctx, std::queue function_names; function_names.push(op.Name()); + const FunctionLibraryDefinition* func_lib_def = op.FuncLibDef(); + while (!function_names.empty()) { const string& function_name = function_names.front(); - const FunctionDef* function_def = - ctx.pflr()->GetFunctionLibraryDefinition()->Find(function_name); + const FunctionDef* function_def = func_lib_def->Find(function_name); if (function_def == nullptr) { return errors::NotFound("Failed to find function '", function_name, "'"); } @@ -1537,8 +1535,8 @@ Status GetOrCreateKernelAndDevice( ctx.GetCollectiveExecutorHandle(), ctx.HostCPU())); } - TF_RETURN_IF_ERROR( - kernel->Init(ctx.LogDevicePlacement(), ndef, graph_collector)); + TF_RETURN_IF_ERROR(kernel->Init(ctx.LogDevicePlacement(), ndef, + graph_collector, op->eager_func_params())); // Exclude tf.data op kernels from being cached. 
The reason for this is // that tf.data op kernels that accept a user-defined function will have a @@ -1548,8 +1546,7 @@ Status GetOrCreateKernelAndDevice( // programs that build input pipeline graphs in a loop. const OpDef* op_def; if (op->is_function()) { - const FunctionDef* function_def = - op->EagerContext().FuncLibDef()->Find(op->Name()); + const FunctionDef* function_def = op->GetFunctionDef(); if (function_def != nullptr) { op_def = &(function_def->signature()); } else { @@ -1976,8 +1973,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, std::unique_ptr node(new eager::RemoteExecuteNode( &op->EagerContext(), std::move(request), op_device, ctx.GetContextViewId(), eager_client.get(), op->GetCancellationManager(), - op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(), - *inputs, {retvals, num_outputs})); + op->MutableAttrs()->BuildNodeDef(), op->FuncLibDef(), *inputs, + {retvals, num_outputs})); if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) { string msg = strings::StrCat( diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 7b3b383b3ddb44..460fab04252ece 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/match.h" +#include "absl/types/optional.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" @@ -103,9 +104,14 @@ KernelAndDeviceFunc::~KernelAndDeviceFunc() { } } -Status KernelAndDeviceOp::Init(const bool log_device_placement, - const NodeDef& ndef, - GraphCollector* graph_collector) { +Status KernelAndDeviceOp::Init( + const bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional<EagerFunctionParams>& eager_func_params) { + if (eager_func_params.has_value()) { + return absl::InternalError( + "KernelAndDeviceOp does not support EagerFunctionParams."); + } OpKernel* k = nullptr; if (flr_ == nullptr) { return errors::Internal( @@ -141,22 +147,31 @@ Status KernelAndDeviceOp::Init(const bool log_device_placement, return OkStatus(); } -Status KernelAndDeviceFunc::InstantiateFunc(const bool log_device_placement, - const NodeDef& ndef, - GraphCollector* graph_collector) { +Status KernelAndDeviceFunc::InstantiateFunc( + const bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional<EagerFunctionParams>& eager_func_params) { const OpDef* op_def = nullptr; - const FunctionDef* function_def; - if (flr_ == nullptr) { - // If function is being executed without an explicit device request, - // lookup the FunctionDef in the CPU's FLR. All FLRs share the same - // library.
- function_def = pflr_->GetFLR(host_cpu_device_->name()) - ->GetFunctionLibraryDefinition() - ->Find(ndef.op()); + const FunctionLibraryDefinition* func_lib_def; + FunctionLibraryRuntime::InstantiateOptions options; + + if (eager_func_params.has_value() && + eager_func_params.value().func_lib_def_override != nullptr) { + func_lib_def = eager_func_params.value().func_lib_def_override; + options.lib_def = func_lib_def; } else { - function_def = flr_->GetFunctionLibraryDefinition()->Find(ndef.op()); + if (flr_ == nullptr) { + // If function is being executed without an explicit device request, + // lookup the FunctionDef in the CPU's FLR. All FLRs share the same + // library. + func_lib_def = pflr_->GetFLR(host_cpu_device_->name()) + ->GetFunctionLibraryDefinition(); + } else { + func_lib_def = flr_->GetFunctionLibraryDefinition(); + } } + const FunctionDef* function_def = func_lib_def->Find(ndef.op()); if (function_def != nullptr) { op_def = &(function_def->signature()); } else { @@ -165,7 +180,6 @@ Status KernelAndDeviceFunc::InstantiateFunc(const bool log_device_placement, TF_RETURN_IF_ERROR( InOutTypesForNode(ndef, *op_def, &input_dtypes_, &output_dtypes_)); - FunctionLibraryRuntime::InstantiateOptions options; options.target = device_ == nullptr ? "" : device_->name(); options.is_multi_device_function = true; for (const Device* device : input_devices_) { @@ -174,13 +188,10 @@ Status KernelAndDeviceFunc::InstantiateFunc(const bool log_device_placement, options.composite_devices = composite_devices_; options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_; if (outputs_on_op_device_) { - const FunctionLibraryDefinition* lib_def = - pflr_->GetFunctionLibraryDefinition(); - const FunctionDef* fdef = lib_def->Find(ndef.op()); - if (fdef == nullptr) { + if (function_def == nullptr) { return errors::InvalidArgument("Failed to find function ", ndef.op()); } - for (int i = 0; i < fdef->signature().output_arg_size(); ++i) { + for (int i = 0; i < function_def->signature().output_arg_size(); ++i) { options.output_devices.push_back(options.target); } } @@ -248,11 +259,12 @@ Status KernelAndDeviceFunc::InstantiateFunc(const bool log_device_placement, return pflr_->IsCrossProcess(handle_, &is_cross_process_); } -Status KernelAndDeviceFunc::Init(const bool log_device_placement, - const NodeDef& ndef, - GraphCollector* graph_collector) { - TF_RETURN_IF_ERROR( - InstantiateFunc(log_device_placement, ndef, graph_collector)); +Status KernelAndDeviceFunc::Init( + const bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params) { + TF_RETURN_IF_ERROR(InstantiateFunc(log_device_placement, ndef, + graph_collector, eager_func_params)); return pflr_->GetOutputDevices(handle_, &output_devices_); } diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index a98427a9e04d27..7a800f9b2a15d1 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -60,14 +60,19 @@ class FunctionLibraryRuntime; const int64_t kInvalidOpId = -1; -// This struc is used for: -// 1. setting op_id and step_id, is_component_function for single-client +// This struct is used for: +// 1. Setting `op_id` and `step_id`, `is_component_function` for single-client // remote function scenario, -// 2. setting step_id for multi-client parallel_device scenario. +// 2. 
Setting `step_id` for multi-client parallel_device scenario. +// 3. Supplying an overriding, private `FunctionLibraryDefinition` for component +// functions. struct EagerFunctionParams { int64_t op_id = kInvalidOpId; bool is_component_function; std::optional step_id = std::nullopt; + FunctionLibraryDefinition* func_lib_def_override = + nullptr; // Not owned (owned by `EagerContext`). If not null, functions + // called by the function will be looked up in this library. }; class EagerKernelArgs : public FunctionArgsInterface { @@ -113,8 +118,10 @@ class KernelAndDevice : public core::RefCounted { // // The provided FunctionLibraryRuntime MUST outlive all calls to // Run() on the returned KernelAndDevice. - virtual Status Init(bool log_device_placement, const NodeDef& ndef, - GraphCollector* graph_collector) = 0; + virtual Status Init( + bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params) = 0; // Non-multi-device functions are run using regular CallOp and look like // primitive operations from KernelAndDevice perspective. @@ -215,8 +222,10 @@ class KernelAndDeviceOp final : public KernelAndDevice { ~KernelAndDeviceOp() override = default; - Status Init(bool log_device_placement, const NodeDef& ndef, - GraphCollector* graph_collector) override; + Status Init( + bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params) override; Status Run( ScopedStepContainer* step_container, const EagerKernelArgs& inputs, @@ -316,11 +325,15 @@ class KernelAndDeviceFunc : public KernelAndDevice { bool IsCrossProcess() override { return is_cross_process_; } - Status InstantiateFunc(bool log_device_placement, const NodeDef& ndef, - GraphCollector* graph_collector); + Status InstantiateFunc( + bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params); - Status Init(bool log_device_placement, const NodeDef& ndef, - GraphCollector* graph_collector) override; + Status Init( + bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params) override; Status Run( ScopedStepContainer* step_container, const EagerKernelArgs& inputs, diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc index 33122bc4c38105..bda3e5f582fc05 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc @@ -118,7 +118,7 @@ void BM_KernelAndDeviceInit(::testing::benchmark::State& state) { KernelAndDeviceOp k(nullptr, false, env.function_library_runtime(), nullptr, nullptr, env.cpu_device()); for (auto s : state) { - TF_CHECK_OK(k.Init({}, ndef, nullptr)); + TF_CHECK_OK(k.Init({}, ndef, nullptr, std::nullopt)); } } BENCHMARK(BM_KernelAndDeviceInit); @@ -138,7 +138,7 @@ void BM_KernelAndDeviceRun(::testing::benchmark::State& state) { TestEnv env; KernelAndDeviceOp k(nullptr, false, env.function_library_runtime(), nullptr, nullptr, env.cpu_device()); - TF_CHECK_OK(k.Init({}, ndef, nullptr)); + TF_CHECK_OK(k.Init({}, ndef, nullptr, std::nullopt)); const EagerKernelArgs args(std::move(inputs)); for (auto s : state) { TF_CHECK_OK(k.Run(nullptr, args, &outputs, nullptr, std::nullopt, diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 
93371be2fc12ab..1f7613b9ec48c1 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -1586,7 +1586,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // impact. TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_a", &trans_a)); - return !trans_a; + // Only rewrite float and bfloat16. + DataType T_m; + TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T_m)); + + return !trans_a && (T_m == DT_FLOAT || T_m == DT_BFLOAT16); } // Check if we are performing pooling on depth or batch. If it is, then we @@ -1864,6 +1868,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { fused_ops == std::vector{"BiasAdd", "Relu"} || fused_ops == std::vector{"BiasAdd", "Relu6"} || fused_ops == std::vector{"BiasAdd", "Elu"} || + fused_ops == std::vector{"BiasAdd", "_FusedHardSwish"} || fused_ops == std::vector{"BiasAdd", "Add"} || fused_ops == std::vector{"BiasAdd", "Add", "Relu"} || fused_ops == std::vector{"BiasAdd", "Add", "Relu6"} || @@ -1899,7 +1904,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return (fused_ops == std::vector{"BiasAdd"} || fused_ops == std::vector{"BiasAdd", "Relu"} || fused_ops == std::vector{"BiasAdd", "Relu6"} || - fused_ops == std::vector{"BiasAdd", "Elu"}); + fused_ops == std::vector{"BiasAdd", "Elu"} || + fused_ops == std::vector{"BiasAdd", "_FusedHardSwish"}); } // Rewrites input node to a new node specified by its matching rewrite info. diff --git a/tensorflow/core/common_runtime/next_pluggable_device/BUILD b/tensorflow/core/common_runtime/next_pluggable_device/BUILD index a5ace31ae81401..9c9ce942b78e78 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/BUILD +++ b/tensorflow/core/common_runtime/next_pluggable_device/BUILD @@ -211,6 +211,7 @@ cc_library( name = "c_plugin_op_kernel", srcs = ["c_plugin_op_kernel.cc"], hdrs = ["c_plugin_op_kernel.h"], + copts = ["-DTF_CAPI_WEAK"], visibility = ["//visibility:public"], deps = [ ":c_plugin_variable", @@ -300,7 +301,7 @@ cc_library( name = "c_plugin_coordination_service_agent", srcs = ["c_plugin_coordination_service_agent.cc"], hdrs = ["c_plugin_coordination_service_agent.h"], - defines = ["TF_CAPI_WEAK"], + copts = ["-DTF_CAPI_WEAK"], visibility = ["//visibility:public"], deps = [ ":plugin_coordination_service_agent", @@ -352,7 +353,7 @@ cc_library( name = "c_plugin_variable", srcs = ["c_plugin_variable.cc"], hdrs = ["c_plugin_variable.h"], - defines = ["TF_CAPI_WEAK"], + copts = ["-DTF_CAPI_WEAK"], visibility = ["//visibility:public"], deps = [ ":plugin_variable", diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc index a266fe7bcf8f3a..109d9ed62b95b2 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc @@ -83,7 +83,7 @@ Status CPluginOpKernelConstruction::GetInt32AttrList( &total_size, status); TF_RETURN_IF_ERROR(StatusFromTF_Status(status)); - value->reserve(list_size); + value->resize(list_size); TF_OpKernelConstruction_GetAttrInt32List( ctx_, attr_name.data(), value->data(), /*max_vals=*/list_size, status); diff --git a/tensorflow/core/common_runtime/replicate_constants_pass.cc b/tensorflow/core/common_runtime/replicate_constants_pass.cc index d81db57beb23b3..376129bad99719 100644 --- a/tensorflow/core/common_runtime/replicate_constants_pass.cc +++ 
b/tensorflow/core/common_runtime/replicate_constants_pass.cc @@ -67,13 +67,39 @@ bool HasCpuDevice(const Node* node) { return device.type == "CPU"; } +// Converts a device name to the corresponding CPU device name. If aggressive +// constant replication across local CPU devices is enabled, the resulting CPU +// device name will also contain the device id. +Status DeviceNameToCpuDeviceNameWithDeviceId(const string& device_name, + string* host_device_name) { + DeviceNameUtils::ParsedName device; + if (!DeviceNameUtils::ParseFullName(device_name, &device)) { + return absl::InternalError( + absl::StrCat("Could not parse device name ", device_name)); + } + // If aggressive constant replication is enabled and the dst node is on CPU, + // we just use the device name of the dst for the src. + if (flags::Global().enable_aggressive_constant_replication.value() && + device.type == "CPU") { + *host_device_name = device_name; + } else { + // Otherwise, assign the corresponding CPU 0 device to it. + device.type = "CPU"; + device.has_type = true; + device.id = 0; + device.has_id = true; + *host_device_name = DeviceNameUtils::ParsedNameToString(device); + } + return OkStatus(); +} + // Get the CPU device on the same host as dst. Status GetDestinationCpuDevice(const Node* dst, std::string* device) { if (!dst->has_assigned_device_name()) return absl::AbortedError( absl::StrCat("Node name: ", dst->name(), " has no assigned device.")); - return DeviceNameUtils::DeviceNameToCpuDeviceName(dst->assigned_device_name(), - device); + return DeviceNameToCpuDeviceNameWithDeviceId(dst->assigned_device_name(), + device); } // Collect the successor edges of the constant. Group them by the device of the diff --git a/tensorflow/core/config/BUILD b/tensorflow/core/config/BUILD index 7a2400c2f64206..53f42c5759ecfb 100644 --- a/tensorflow/core/config/BUILD +++ b/tensorflow/core/config/BUILD @@ -95,6 +95,7 @@ py_strict_test( python_version = "PY3", deps = [ ":flags_py", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", ], ) diff --git a/tensorflow/core/config/flag_defs.h b/tensorflow/core/config/flag_defs.h index 6a99c548ac9cfc..2061add2f1d4b8 100644 --- a/tensorflow/core/config/flag_defs.h +++ b/tensorflow/core/config/flag_defs.h @@ -53,6 +53,9 @@ class Flags { "Enables the publication of partitioned function graphs " "via StatsPublisherInterface.
Disabling this flag can " "reduce memory consumption."); + TF_DECLARE_FLAG(enable_aggressive_constant_replication, true, + "Replicate constants across CPU devices and even for local " + "CPUs within the same task if available.") // LINT.ThenChange(//tensorflow/core/config/flags_api_wrapper.cc) }; diff --git a/tensorflow/core/config/flags_api_wrapper.cc b/tensorflow/core/config/flags_api_wrapper.cc index 3d0a001aecf903..769a9a4db2d983 100644 --- a/tensorflow/core/config/flags_api_wrapper.cc +++ b/tensorflow/core/config/flags_api_wrapper.cc @@ -52,5 +52,6 @@ PYBIND11_MODULE(flags_pybind, m) { TF_PY_DECLARE_FLAG(tf_shape_default_int64); TF_PY_DECLARE_FLAG(more_stack_traces); TF_PY_DECLARE_FLAG(publish_function_graphs); + TF_PY_DECLARE_FLAG(enable_aggressive_constant_replication); // LINT.ThenChange(//tensorflow/core/config/flag_defs.h) }; diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index a6debf6a378624..2f02dd9b8c88d6 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -1,4 +1,3 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "if_not_mobile", @@ -8,6 +7,7 @@ load( "//tensorflow/core/platform:build_config.bzl", "tf_protos_all", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -39,6 +39,10 @@ exports_files([ "serialization_utils.h", "split_utils.cc", "split_utils.h", + "file_logger_client_no_op.h", + "file_logger_client_no_op.cc", + "file_logger_client_interface.h", + "file_logger_client_interface.cc", "stats_utils.cc", "stats_utils.h", "tfdataz_metrics.h", @@ -393,7 +397,9 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform:stringpiece", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@local_tsl//tsl/platform:statusor", ], ) @@ -439,6 +445,7 @@ cc_library( "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:thread_annotations", + "@local_tsl//tsl/platform:types", ], ) @@ -507,14 +514,15 @@ cc_library( deps = [ ":dataset_utils", ":root_dataset", + ":serialization_utils", ":unbounded_thread_pool", "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:session_options", - "//tensorflow/core/data:serialization_utils", "//tensorflow/core/framework:graph_proto_cc", "@com_google_absl//absl/memory", "@local_tsl//tsl/platform:env", @@ -630,9 +638,38 @@ cc_library( hdrs = ["utils.h"], # copybara:uncomment copts = ["-Wthread-safety-analysis"], deps = [ + ":file_logger_client_interface", + ":file_logger_client_no_op", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", ], ) + +tf_cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + # copybara:uncomment extra_copts = ["-Wthread-safety-analysis"], + deps = [ + ":file_logger_client_interface", + ":file_logger_client_no_op", + ":utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "file_logger_client_interface", + hdrs = ["file_logger_client_interface.h"], + visibility = [ + "//learning/processing/tf_data_logger/client:__subpackages__", + 
"//tensorflow:internal", + ], +) + +cc_library( + name = "file_logger_client_no_op", + hdrs = ["file_logger_client_no_op.h"], + deps = [":file_logger_client_interface"], +) diff --git a/tensorflow/core/data/captured_function.h b/tensorflow/core/data/captured_function.h index c3d489dd855263..5d9a573aad0d3f 100644 --- a/tensorflow/core/data/captured_function.h +++ b/tensorflow/core/data/captured_function.h @@ -290,6 +290,8 @@ class InstantiatedCapturedFunction { FunctionLibraryRuntime::DoneCallback done, const std::shared_ptr& node) const; + std::string func_name() const { return captured_func_->func().name(); } + private: friend class CapturedFunction; diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index 0d662a7a938d33..7d0081c2e6de1e 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -1006,8 +1006,8 @@ REGISTER_DATASET_EXPERIMENT("no_compression", RandomJobSamplePercentage<50>, REGISTER_DATASET_EXPERIMENT("inject_io_prefetch", RandomJobSamplePercentage<0>, AllTasks); REGISTER_DATASET_EXPERIMENT("reduce_array_record_dataset_memory_usage", - RandomJobSamplePercentage<0>, AllTasks); -REGISTER_DATASET_EXPERIMENT("map_fusion", RandomJobSamplePercentage<10>, + RandomJobSamplePercentage<50>, AllTasks); +REGISTER_DATASET_EXPERIMENT("map_fusion", RandomJobSamplePercentage<0>, AllTasks); } // namespace } // namespace data diff --git a/tensorflow/core/data/file_logger_client_interface.h b/tensorflow/core/data/file_logger_client_interface.h new file mode 100644 index 00000000000000..afa6cda0cf15f5 --- /dev/null +++ b/tensorflow/core/data/file_logger_client_interface.h @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_FILE_LOGGER_CLIENT_INTERFACE_H_ +#define TENSORFLOW_CORE_DATA_FILE_LOGGER_CLIENT_INTERFACE_H_ + +#include +#include + +namespace tensorflow::data { + +// An abstract class to provides an easy and thread safe api to make +// asynchronous calls to the TFDataLoggerService. +// LogFilesAsync is guaranteed to be non blocking. +// The destructor however might be blocking. +class FileLoggerClientInterface { + public: + // Default constructor + FileLoggerClientInterface() = default; + + // Sends file names in `files` to the TFDataLoggerService. Asynchronously. + virtual void LogFilesAsync(std::vector files) = 0; + + // Default destructor. May block depending on implementation of the derived + // class. + virtual ~FileLoggerClientInterface() = default; +}; +} // namespace tensorflow::data + +#endif // TENSORFLOW_CORE_DATA_FILE_LOGGER_CLIENT_INTERFACE_H_ diff --git a/tensorflow/core/data/file_logger_client_no_op.h b/tensorflow/core/data/file_logger_client_no_op.h new file mode 100644 index 00000000000000..65247844f741c4 --- /dev/null +++ b/tensorflow/core/data/file_logger_client_no_op.h @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_FILE_LOGGER_CLIENT_NO_OP_H_ +#define TENSORFLOW_CORE_DATA_FILE_LOGGER_CLIENT_NO_OP_H_ + +#include +#include + +#include "tensorflow/core/data/file_logger_client_interface.h" + +namespace tensorflow::data { + +// Implementation of the abstract class FileLoggerClientInterface, which does +// nothing. It does not allocate any resources and immediately returns in +// LogFilesAsync. This is used in the 3rd party version of the tf.data library. +class FileLoggerClientNoOp : public FileLoggerClientInterface { + public: + // Default constructor + FileLoggerClientNoOp() = default; + + // Does not do anything + void LogFilesAsync(std::vector files) override{}; + + // Default destructor + ~FileLoggerClientNoOp() override = default; +}; +} // namespace tensorflow::data + +#endif // TENSORFLOW_CORE_DATA_FILE_LOGGER_CLIENT_NO_OP_H_ diff --git a/tensorflow/core/data/rewrite_utils.cc b/tensorflow/core/data/rewrite_utils.cc index 707e4d8264118d..76c05e6e47f2fc 100644 --- a/tensorflow/core/data/rewrite_utils.cc +++ b/tensorflow/core/data/rewrite_utils.cc @@ -249,7 +249,8 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, } std::unique_ptr GetGrapplerItem( - GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks) { + GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks, + bool apply_optimizations) { // Add an identity node as the fetch node, otherwise we might get 'placeholder // is both fed and fetched' errors in some cases when using input list with // placeholder dataset nodes. @@ -285,7 +286,7 @@ std::unique_ptr GetGrapplerItem( // Create Grappler item. tensorflow::grappler::ItemConfig item_config; - item_config.apply_optimizations = true; + item_config.apply_optimizations = apply_optimizations; std::unique_ptr grappler_item = tensorflow::grappler::GrapplerItemFromMetaGraphDef( "graph", meta_graph_def, item_config); diff --git a/tensorflow/core/data/rewrite_utils.h b/tensorflow/core/data/rewrite_utils.h index 23ea965d67e105..44205dc83b24f5 100644 --- a/tensorflow/core/data/rewrite_utils.h +++ b/tensorflow/core/data/rewrite_utils.h @@ -57,10 +57,13 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, // `dataset_node` is the name of the node corresponding to the dataset. // If `add_fake_sinks` is true, it adds fake sink nodes to the graph and functions to // allow rewriting the actual sink nodes. +// If `apply_optimizations` is true, general grappler optimizations at level +// `tensorflow::OptimizerOptions::L1` are applied to the graph. // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals to // be optimizable, we will no longer need to add fake nodes.
std::unique_ptr GetGrapplerItem( - GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks); + GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks, + bool apply_optimizations = true); // Returns the name of the node corresponding to the dataset. It is indicated by // the symbolic `_Retval` node. diff --git a/tensorflow/core/data/root_dataset.cc b/tensorflow/core/data/root_dataset.cc index 55ff2bc8122213..bba8a426366329 100644 --- a/tensorflow/core/data/root_dataset.cc +++ b/tensorflow/core/data/root_dataset.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringprintf.h" +#include "tsl/platform/host_info.h" namespace tensorflow { namespace data { @@ -46,6 +47,8 @@ constexpr char kDatasetType[] = "Root"; constexpr char kAlgorithm[] = "algorithm"; constexpr char kCpuBudget[] = "cpu_budget"; constexpr char kExperiments[] = "experiments"; +constexpr char kReadRoundtripLatency[] = "read_latency_usec"; +constexpr char kReadResponseBytes[] = "read_bytes"; constexpr char kIntraOpParallelism[] = "intra_op_parallelism"; constexpr char kMemBandwidth[] = "mem_bw_used_megabytes_per_sec"; constexpr char kPrivateThreadpoolSize[] = "threadpool_size"; @@ -277,6 +280,27 @@ class RootDataset::Iterator : public DatasetIterator { "%lld", static_cast( model_node()->TotalMaximumBufferedBytes() / 1.0e6)))); } + const auto io_statistics = tsl::port::GetIOStatistics(); + if (io_statistics.roundtrip_latency_usec.count > 0) { + traceme_metadata.push_back(std::make_pair( + kReadRoundtripLatency, + strings::Printf( + "(count: %lld, mean: %lld, std dev: %lld)", + static_cast( + io_statistics.roundtrip_latency_usec.count), + static_cast(io_statistics.roundtrip_latency_usec.mean), + static_cast( + io_statistics.roundtrip_latency_usec.std_dev)))); + } + if (io_statistics.response_bytes.count > 0) { + traceme_metadata.push_back(std::make_pair( + kReadResponseBytes, + strings::Printf( + "(count: %lld, mean: %lld, std dev: %lld)", + static_cast(io_statistics.response_bytes.count), + static_cast(io_statistics.response_bytes.mean), + static_cast(io_statistics.response_bytes.std_dev)))); + } return traceme_metadata; } diff --git a/tensorflow/core/data/serialization_utils.cc b/tensorflow/core/data/serialization_utils.cc index e07ec49b9137de..01b5a1289e1257 100644 --- a/tensorflow/core/data/serialization_utils.cc +++ b/tensorflow/core/data/serialization_utils.cc @@ -14,12 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/data/serialization_utils.h" +#include #include #include #include #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/data/compression_utils.h" @@ -30,6 +32,7 @@ limitations under the License. 
#include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/platform/stringpiece.h" namespace tensorflow { namespace data { @@ -118,20 +121,39 @@ Status ReadElementsFromCheckpoint(IteratorContext* ctx, return OkStatus(); } +Status WriteElement(IteratorStateWriter* writer, StringPiece key_prefix, + const std::vector>& elements, + int64_t index) { + const std::vector& element = elements[index]; + std::string element_prefix = absl::StrCat(key_prefix, "::", index); + TF_RETURN_IF_ERROR( + writer->WriteScalar(element_prefix, kNumComponents, element.size())); + for (int j = 0; j < element.size(); ++j) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + element_prefix, absl::StrCat(kComponent, "[", j, "]"), element[j])); + } + return OkStatus(); +} + Status WriteElementsToCheckpoint( IteratorStateWriter* writer, StringPiece key_prefix, const std::vector>& elements) { TF_RETURN_IF_ERROR( writer->WriteScalar(key_prefix, kNumElements, elements.size())); for (int i = 0; i < elements.size(); ++i) { - const std::vector& element = elements[i]; - std::string element_prefix = absl::StrCat(key_prefix, "::", i); - TF_RETURN_IF_ERROR( - writer->WriteScalar(element_prefix, kNumComponents, element.size())); - for (int j = 0; j < elements[i].size(); ++j) { - TF_RETURN_IF_ERROR(writer->WriteTensor( - element_prefix, absl::StrCat(kComponent, "[", j, "]"), element[j])); - } + TF_RETURN_IF_ERROR(WriteElement(writer, key_prefix, elements, i)); + } + return OkStatus(); +} + +Status UpdateCheckpointElements( + IteratorStateWriter* writer, StringPiece key_prefix, + const std::vector>& elements, + const absl::flat_hash_set& checkpoint_indices) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(key_prefix, kNumElements, elements.size())); + for (int64_t i : checkpoint_indices) { + TF_RETURN_IF_ERROR(WriteElement(writer, key_prefix, elements, i)); } return OkStatus(); } diff --git a/tensorflow/core/data/serialization_utils.h b/tensorflow/core/data/serialization_utils.h index d5e83c32eb488f..b55dfdfb7eca8c 100644 --- a/tensorflow/core/data/serialization_utils.h +++ b/tensorflow/core/data/serialization_utils.h @@ -47,6 +47,15 @@ Status WriteElementsToCheckpoint( IteratorStateWriter* writer, StringPiece key_prefix, const std::vector>& elements); +// Updates the dataset elements in the checkpoint for given `checkpoint_indices` +// using the given key prefix, assuming that vector of elements have +// checkpointed these before. The elements can be read back by passing the same +// key prefix to ReadElementsFromCheckpoint. +Status UpdateCheckpointElements( + IteratorStateWriter* writer, StringPiece key_prefix, + const std::vector>& elements, + const absl::flat_hash_set& checkpoint_indices); + // Helper class for reading data from a vector of VariantTensorData objects. class VariantTensorDataReader : public IteratorStateReader { public: diff --git a/tensorflow/core/data/serialization_utils_test.cc b/tensorflow/core/data/serialization_utils_test.cc index ddd424c519841c..5de7acfdc30f53 100644 --- a/tensorflow/core/data/serialization_utils_test.cc +++ b/tensorflow/core/data/serialization_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/core/data/serialization_utils.h" +#include #include #include #include @@ -203,6 +204,16 @@ class ParameterizedIteratorStateVariantTest } }; +class ParemeterizedCheckpointIndicesTest + : public DatasetOpsTestBase, + public ::testing::WithParamInterface> { + protected: + absl::flat_hash_set GetCheckpointIndices() const { + absl::flat_hash_set checkpoint_indices = GetParam(); + return checkpoint_indices; + } +}; + std::vector> TestCases() { return { CreateTensors(TensorShape{1}, {{1}}), // int64 @@ -216,6 +227,18 @@ std::vector> TestCases() { }; } +std::vector> CheckpointIndicesTestCases() { + return { + {/*checkpoint_indices*/}, + {/*checkpoint_indices*/ 0}, + {/*checkpoint_indices*/ 0, 1}, + {/*checkpoint_indices*/ 0, 1, 2}, + {/*checkpoint_indices*/ 1, 3, 4}, + {/*checkpoint_indices*/ 1, 2, 3, 4}, + {/*checkpoint_indices*/ 0, 1, 2, 3, 4}, + }; +} + TEST_P(ParameterizedIteratorStateVariantTest, EncodeAndDecode) { VariantTensorData data = GetVariantTensorData(); TF_ASSERT_OK_AND_ASSIGN(VariantTensorData result, EncodeAndDecode(data)); @@ -236,9 +259,58 @@ TEST_P(ParameterizedIteratorStateVariantTest, DecodeUncompressed) { } } +TEST_P(ParemeterizedCheckpointIndicesTest, + CheckpointElementsRoundTripUsingIndices) { + std::vector> elements; + elements.push_back(CreateTensors(TensorShape({3}), {{1, 2, 3}})); + elements.push_back(CreateTensors(TensorShape({2}), {{4, 5}})); + elements.push_back( + CreateTensors(TensorShape({5}), {{6, 7, 8, 9, 10}})); + elements.push_back( + CreateTensors(TensorShape({4}), {{11, 12, 13, 14}})); + elements.push_back(CreateTensors(TensorShape({2}), {{15, 16}})); + VariantTensorDataWriter writer; + tstring test_prefix = full_name("test_prefix"); + // Generate checkpoint for entire buffer + absl::flat_hash_set checkpoint_indices_write = {0, 1, 2, 3, 4}; + TF_ASSERT_OK(WriteElementsToCheckpoint(&writer, test_prefix, elements)); + // Update the elements at checkpoint indices + for (auto index : GetCheckpointIndices()) { + elements.at(index) = CreateTensors(TensorShape({1}), {{1}}); + } + TF_ASSERT_OK(UpdateCheckpointElements(&writer, test_prefix, elements, + GetCheckpointIndices())); + std::vector data; + writer.GetData(&data); + + VariantTensorDataReader reader(data); + std::vector> read_elements; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr ctx, + TestContext::Create()); + TF_ASSERT_OK(ReadElementsFromCheckpoint(ctx->iter_ctx(), &reader, test_prefix, + &read_elements)); + + ASSERT_EQ(elements.size(), read_elements.size()); + // Check if checkpoint state of entire buffer is as expected + for (int index = 0; index < elements.size(); ++index) { + std::vector& original = elements[index]; + std::vector& read = read_elements[index]; + + ASSERT_EQ(original.size(), read.size()); + for (int j = 0; j < original.size(); ++j) { + EXPECT_EQ(original[j].NumElements(), read[j].NumElements()); + EXPECT_EQ(original[j].flat()(0), read[j].flat()(0)); + } + } +} + INSTANTIATE_TEST_SUITE_P(Instantiation, ParameterizedIteratorStateVariantTest, ::testing::ValuesIn(TestCases())); +INSTANTIATE_TEST_SUITE_P(Instantiation, ParemeterizedCheckpointIndicesTest, + ::testing::ValuesIn(CheckpointIndicesTestCases())); + } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index ee93b3bc916f20..90acfd91600efb 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -1,16 +1,16 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") 
load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) +load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "get_compatible_with_portable", "tf_grpc_cc_dependencies") load( "//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library", "tf_protos_profiler_service", ) -load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "get_compatible_with_portable", "tf_grpc_cc_dependencies") -load( - "//tensorflow:tensorflow.bzl", - "tf_cc_test", -) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package_group( name = "data_transfer_visibility", @@ -739,6 +739,23 @@ cc_library( ], ) +tf_cc_test( + name = "split_provider_test", + srcs = ["split_provider_test.cc"], + # copybara:uncomment extra_copts = ["-Wthread-safety-analysis"], + deps = [ + ":common_proto_cc", + ":split_provider", + ":test_util", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "task_remover", srcs = ["task_remover.cc"], @@ -765,6 +782,7 @@ cc_library( ":thread_safe_buffer", ":worker_proto_cc", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/data:standalone", @@ -847,15 +865,14 @@ cc_library( "//tensorflow/core/framework:node_def_proto_cc", "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core/framework:types_proto_cc", - "//tensorflow/core/platform:errors", "//tensorflow/core/platform:path", - "//tensorflow/core/platform:protobuf", - "//tensorflow/core/platform:status", - "//tensorflow/core/platform:statusor", "//tensorflow/core/platform:tstring", "//tensorflow/core/platform:types", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/tensorflow/core/data/service/graph_rewriters.cc b/tensorflow/core/data/service/graph_rewriters.cc index af154691549f2a..114ae7c336cedb 100644 --- a/tensorflow/core/data/service/graph_rewriters.cc +++ b/tensorflow/core/data/service/graph_rewriters.cc @@ -95,10 +95,18 @@ RemoveCompressionMapRewriter::ApplyRemoveCompressionMapRewrite( tensorflow::RewriterConfig::CustomGraphOptimizer config = GetRewriteConfig(); TF_RETURN_IF_ERROR(remove_compression_map.Init(&config)); + // Don't apply general grappler optimizations. Sometimes there is a conflict + // between two applications of these optimizations to the same graph (see + // b/303524867). This conflict isn't worth resolving in the context of this + // rewrite: the point of this rewrite is to remove one node and change one + // reference to it, not to apply any general optimizations. 
+ bool apply_general_grappler_optimizations = false; + GraphDef input_graph = graph_def; TF_ASSIGN_OR_RETURN(std::string dataset_node, GetDatasetNode(input_graph)); std::unique_ptr grappler_item = - GetGrapplerItem(&input_graph, &dataset_node, /*add_fake_sinks=*/false); + GetGrapplerItem(&input_graph, &dataset_node, /*add_fake_sinks=*/false, + apply_general_grappler_optimizations); GraphDef rewritten_graph; std::unordered_map device_map; diff --git a/tensorflow/core/data/service/snapshot/BUILD b/tensorflow/core/data/service/snapshot/BUILD index 46e420aac744bb..e3a16ec3b8b856 100644 --- a/tensorflow/core/data/service/snapshot/BUILD +++ b/tensorflow/core/data/service/snapshot/BUILD @@ -1,7 +1,7 @@ # Distributed snapshot library. load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow:tensorflow.default.bzl", "tf_grpc_cc_dependencies") +load("//tensorflow:tensorflow.default.bzl", "tf_grpc_cc_dependencies", "tf_kernel_library") load("//tensorflow/core/platform:build_config.bzl", "tf_protos_profiler_service") load("@local_tsl//tsl:tsl.default.bzl", "get_compatible_with_portable") @@ -35,6 +35,7 @@ tf_cc_test( "@com_google_absl//absl/time", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tstring", ] + tf_grpc_cc_dependencies() + tf_protos_profiler_service(), ) @@ -85,6 +86,28 @@ tf_cc_test( ], ) +tf_kernel_library( + name = "list_snapshot_chunks_dataset_op", + srcs = ["list_snapshot_chunks_dataset_op.cc"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":snapshot_chunk_provider", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/data:name_utils", + "//tensorflow/core/framework:allocator", + "//tensorflow/core/framework:op_requires", + "//tensorflow/core/framework:types_proto_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:tstring", + ], +) + cc_library( name = "path_utils", srcs = ["path_utils.cc"], @@ -242,6 +265,58 @@ cc_library( ], ) +cc_library( + name = "snapshot_chunk_provider", + srcs = ["snapshot_chunk_provider.cc"], + hdrs = ["snapshot_chunk_provider.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":file_utils", + ":path_utils", + "//tensorflow/core:framework", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/distributed_runtime/rpc:grpc_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:status_to_from_proto", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:tstring", + "@local_tsl//tsl/protobuf:status_proto_cc", + ], +) + +tf_cc_test( + name = "snapshot_chunk_provider_test", + size = "small", + srcs = ["snapshot_chunk_provider_test.cc"], + deps = [ + ":file_utils", + ":path_utils", + ":snapshot_chunk_provider", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + 
"@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:status_to_from_proto", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/protobuf:status_proto_cc", + ], +) + cc_library( name = "snapshot_stream_writer", srcs = ["snapshot_stream_writer.cc"], diff --git a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc index ccd795706ea376..0a91582823f676 100644 --- a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc +++ b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc @@ -32,13 +32,14 @@ limitations under the License. #include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" +#include "tsl/platform/tstring.h" namespace tensorflow { namespace data { namespace { -using testing::ChooseFromDatasets; using testing::CreateDummyDistributedSnapshotMetadata; using ::testing::ElementsAre; using ::testing::IsEmpty; @@ -150,7 +151,8 @@ TEST_P(DistributedSnapshotTest, ChooseFromDatasets) { // choice_dataset = tf.data.Dataset.range(3).repeat() // dataset = tf.data.Dataset.choose_from_datasets(datasets, choice_dataset) TestSnapshotCluster data_service(NumWorkers()); - TF_ASSERT_OK_AND_ASSIGN(DatasetDef dataset, ChooseFromDatasets()); + TF_ASSERT_OK_AND_ASSIGN(DatasetDef dataset, + testing::GetTestDataset("choose_from_datasets")); experimental::DistributedSnapshotMetadata metadata = CreateDummyDistributedSnapshotMetadata(); std::string snapshot_path = LocalTempFilename(); @@ -158,8 +160,8 @@ TEST_P(DistributedSnapshotTest, ChooseFromDatasets) { data_service.dispatcher().Snapshot(dataset, snapshot_path, metadata)); TF_ASSERT_OK(WaitForSnapshotComplete(snapshot_path)); EXPECT_THAT( - testing::ReadSnapshot(snapshot_path, - tsl::io::compression::kNone), + testing::ReadSnapshot(snapshot_path, + tsl::io::compression::kNone), IsOkAndHolds(UnorderedElementsAre("a", "b", "c", "a", "b", "c", "a", "b", "c", "a", "b", "c", "a", "b", "c"))); } diff --git a/tensorflow/core/data/service/snapshot/list_snapshot_chunks_dataset_op.cc b/tensorflow/core/data/service/snapshot/list_snapshot_chunks_dataset_op.cc new file mode 100644 index 00000000000000..666374ffe7693c --- /dev/null +++ b/tensorflow/core/data/service/snapshot/list_snapshot_chunks_dataset_op.cc @@ -0,0 +1,198 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/data/name_utils.h" +#include "tensorflow/core/data/service/snapshot/snapshot_chunk_provider.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/tstring.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr const char kListSnapshotChunksDataset[] = "ListSnapshotChunksDataset"; +constexpr const char kSnapshotPath[] = "snapshot_path"; + +Tensor ConvertToTensor(absl::string_view s, Allocator* allocator) { + Tensor tensor(allocator, DT_STRING, TensorShape({})); + tensor.scalar()() = tsl::tstring(s); + return tensor; +} + +// TODO(b/297930782): Implement split provider for this dataset. +class ListSnapshotChunksDatasetOp : public DatasetOpKernel { + public: + explicit ListSnapshotChunksDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +class ListSnapshotChunksDatasetOp::Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, tsl::tstring snapshot_path, + const DataTypeVector& output_types, + const std::vector& output_shapes) + : DatasetBase(DatasetContext(ctx)), + snapshot_path_(std::move(snapshot_path)), + output_types_(output_types), + output_shapes_(output_shapes) {} + + absl::string_view snapshot_path() const { return snapshot_path_; } + + const DataTypeVector& output_dtypes() const override { return output_types_; } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + int64_t CardinalityInternal(CardinalityOptions options) const override { + // TODO(b/297930782): Implement this. 
+ return kUnknownCardinality; + } + + std::string DebugString() const override { + return name_utils::DatasetDebugString(kListSnapshotChunksDataset); + } + + absl::Status InputDatasets( + std::vector* inputs) const override { + inputs->clear(); + return absl::OkStatus(); + } + + absl::Status CheckExternalState() const override { return absl::OkStatus(); } + + protected: + std::unique_ptr MakeIteratorInternal( + const std::string& prefix) const override; + + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* snapshot_path = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(snapshot_path_, &snapshot_path)); + return b->AddDataset(this, /*inputs=*/{snapshot_path}, output); + } + + private: + class Iterator; + + const tsl::tstring snapshot_path_; + const DataTypeVector output_types_; + const std::vector output_shapes_; +}; + +class ListSnapshotChunksDatasetOp::Dataset::Iterator + : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + absl::Status Initialize(IteratorContext* ctx) override { + if (!snapshot_chunk_provider_) { + snapshot_chunk_provider_ = std::make_unique( + dataset()->snapshot_path(), ctx->env()); + } + return absl::OkStatus(); + } + + private: + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + TF_ASSIGN_OR_RETURN(std::optional chunk, + snapshot_chunk_provider_->GetNext()); + if (!chunk.has_value()) { + *end_of_sequence = true; + return absl::OkStatus(); + } + out_tensors->push_back(ConvertToTensor(*chunk, ctx->allocator({}))); + *end_of_sequence = false; + return absl::OkStatus(); + } + + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + return snapshot_chunk_provider_->Save( + [&](const std::string& key) { return full_name(key); }, writer); + } + + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return snapshot_chunk_provider_->Restore( + [&](const std::string& key) { return full_name(key); }, reader); + } + + std::unique_ptr snapshot_chunk_provider_; +}; + +ListSnapshotChunksDatasetOp::ListSnapshotChunksDatasetOp( + OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); +} + +void ListSnapshotChunksDatasetOp::MakeDataset(OpKernelContext* ctx, + DatasetBase** output) { + tsl::tstring snapshot_path; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSnapshotPath, &snapshot_path)); + OP_REQUIRES(ctx, !snapshot_path.empty(), + absl::InvalidArgumentError( + "snapshot_path is required to list snapshot chunks.")); + *output = new ListSnapshotChunksDatasetOp::Dataset( + ctx, std::move(snapshot_path), output_types_, output_shapes_); +} + +std::unique_ptr +ListSnapshotChunksDatasetOp::Dataset::MakeIteratorInternal( + const std::string& prefix) const { + return std::make_unique( + ListSnapshotChunksDatasetOp::Dataset::Iterator::Params{ + this, + name_utils::IteratorPrefix(kListSnapshotChunksDataset, prefix)}); +} + +REGISTER_KERNEL_BUILDER(Name(kListSnapshotChunksDataset).Device(DEVICE_CPU), + ListSnapshotChunksDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc index 
ad2e7ba41bdea8..49ec21ecf6e6b2 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc @@ -125,11 +125,12 @@ class SnapshotChunkDatasetOp::Dataset : public DatasetBase { explicit Iterator(const Params& params) : DatasetIterator(params) {} + ~Iterator() override { RecordBytesRead(); } + absl::Status Initialize(IteratorContext* ctx) override { reader_ = std::make_unique( TranslateFileName(dataset()->chunk_file_), dataset()->compression_, dataset()->dtypes_, kTFRecordReaderOutputBufferSize); - bytes_read_ = 0; return reader_->Initialize(ctx->env()); } @@ -147,7 +148,6 @@ class SnapshotChunkDatasetOp::Dataset : public DatasetBase { status, " Failed to read tf.data snapshot file: ", dataset()->chunk_file_); ++start_index_; - RecordBytesRead(); return status; } @@ -180,15 +180,12 @@ class SnapshotChunkDatasetOp::Dataset : public DatasetBase { void RecordBytesRead() { uint64_t bytes_read = reader_->BytesRead(); - static auto* bytes_counter = - metrics::GetTFDataBytesReadCounter(kSnapshotChunkDataset); - bytes_counter->IncrementBy(bytes_read - bytes_read_); - bytes_read_ = bytes_read; + metrics::GetTFDataBytesReadCounter(kSnapshotChunkDataset) + ->IncrementBy(bytes_read); } std::unique_ptr reader_; int64_t start_index_ = 0; - uint64_t bytes_read_ = 0; }; const tstring chunk_file_; diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc new file mode 100644 index 00000000000000..b5910e081d1e7b --- /dev/null +++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc @@ -0,0 +1,162 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/data/service/snapshot/snapshot_chunk_provider.h" + +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/data/service/snapshot/file_utils.h" +#include "tensorflow/core/data/service/snapshot/path_utils.h" +#include "tensorflow/core/framework/dataset.h" +#include "tsl/distributed_runtime/rpc/grpc_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/path.h" +#include "tsl/platform/status_to_from_proto.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/tstring.h" +#include "tsl/protobuf/status.pb.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kChunksRead[] = "chunks_read"; +constexpr absl::string_view kSetElementDelimiter = ","; + +// Waits for a short period of time before retrying. 
+void Backoff(int num_retries, tsl::Env* env) { + if (num_retries >= 1) { // Does not backoff for the first try. + env->SleepForMicroseconds(tsl::ComputeBackoffMicroseconds(num_retries - 1)); + } +} + +std::string SetToString(const absl::flat_hash_set& s) { + return absl::StrJoin(s, kSetElementDelimiter); +} + +absl::flat_hash_set SetFromString(absl::string_view s) { + if (s.empty()) { + return {}; + } + std::vector split = absl::StrSplit(s, kSetElementDelimiter); + return absl::flat_hash_set(split.begin(), split.end()); +} + +} // namespace + +SnapshotChunkProvider::SnapshotChunkProvider(absl::string_view snapshot_path, + tsl::Env* env) + : snapshot_path_(snapshot_path), env_(env) {} + +absl::StatusOr> SnapshotChunkProvider::GetNext() + ABSL_LOCKS_EXCLUDED(mu_) { + for (int num_retries = 0;; ++num_retries) { + Backoff(num_retries, env_); + absl::MutexLock l(&mu_); + TF_RETURN_IF_ERROR(snapshot_state_.status); + if (!chunks_unread_.empty()) { + std::string next_chunk = *chunks_unread_.begin(); + chunks_read_.insert(next_chunk); + chunks_unread_.erase(next_chunk); + return tsl::io::JoinPath(CommittedChunksDirectory(snapshot_path_), + next_chunk); + } + if (snapshot_state_.snapshot_is_done) { + return std::nullopt; + } + TF_RETURN_IF_ERROR(UpdateSnapshot()); + } +} + +absl::Status SnapshotChunkProvider::UpdateSnapshot() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // Reads the state files first then reads the chunks. If we read chunks before + // reading the state files, the writer could write more chunks in between, and + // we may see the DONE file but miss those final chunks. + TF_ASSIGN_OR_RETURN(snapshot_state_, GetSnapshotState()); + TF_RETURN_IF_ERROR(snapshot_state_.status); + TF_ASSIGN_OR_RETURN(std::vector chunks, GetAvailableChunks()); + for (absl::string_view chunk : chunks) { + if (!chunks_read_.contains(chunk)) { + chunks_unread_.insert(std::string(chunk)); + } + } + return absl::OkStatus(); +} + +absl::StatusOr +SnapshotChunkProvider::GetSnapshotState() { + std::string error_file_path = SnapshotErrorFilePath(snapshot_path_); + if (env_->FileExists(error_file_path).ok()) { + StatusProto status_proto; + TF_RETURN_IF_ERROR(ReadTextProto(env_, error_file_path, &status_proto)); + absl::Status status = tsl::StatusFromProto(status_proto); + if (status.ok()) { + return absl::InternalError(absl::StrCat( + "Unexpected snapshot ERROR file contains an OK status at ", + error_file_path, ".")); + } + return SnapshotState(status); + } + return SnapshotState( + env_->FileExists(SnapshotDoneFilePath(snapshot_path_)).ok()); +} + +absl::StatusOr> +SnapshotChunkProvider::GetAvailableChunks() { + absl::StatusOr> status_or_chunks = + GetChildren(CommittedChunksDirectory(snapshot_path_), env_); + if (status_or_chunks.ok()) { + return *std::move(status_or_chunks); + } else if (absl::IsNotFound(status_or_chunks.status())) { + return std::vector{}; + } + return status_or_chunks.status(); +} + +absl::Status SnapshotChunkProvider::Save( + std::function full_name, + IteratorStateWriter* writer) { + absl::MutexLock l(&mu_); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(kChunksRead), SetToString(chunks_read_))); + return absl::OkStatus(); +} + +absl::Status SnapshotChunkProvider::Restore( + std::function full_name, + IteratorStateReader* reader) { + absl::MutexLock l(&mu_); + tsl::tstring chunks_read; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kChunksRead), &chunks_read)); + chunks_read_ = SetFromString(chunks_read); + return UpdateSnapshot(); +} + +} // namespace data +} // namespace tensorflow 
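The snapshot chunk provider above is consumed by looping on GetNext() until it returns std::nullopt, which is also what the GetAllChunks() test helper later in this diff does. The following is a minimal usage sketch, not part of the change itself; `ReadAllChunks` and `process_chunk_file` are hypothetical names supplied here for illustration.

// Minimal usage sketch (assumes the SnapshotChunkProvider API added in this
// diff; `process_chunk_file` is a hypothetical consumer).
#include <optional>
#include <string>

#include "absl/status/status.h"
#include "tensorflow/core/data/service/snapshot/snapshot_chunk_provider.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace tensorflow {
namespace data {

// Reads every committed chunk of the snapshot at `snapshot_path`, blocking
// until the snapshot is done or fails.
absl::Status ReadAllChunks(const std::string& snapshot_path,
                           absl::Status (*process_chunk_file)(
                               const std::string& chunk_path)) {
  SnapshotChunkProvider provider(snapshot_path, tsl::Env::Default());
  while (true) {
    // GetNext() blocks until another chunk is committed or the snapshot
    // finishes, and returns std::nullopt once every chunk has been handed out.
    TF_ASSIGN_OR_RETURN(std::optional<std::string> chunk, provider.GetNext());
    if (!chunk.has_value()) return absl::OkStatus();
    TF_RETURN_IF_ERROR(process_chunk_file(*chunk));
  }
}

}  // namespace data
}  // namespace tensorflow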
diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.h b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.h new file mode 100644 index 00000000000000..17d932ea38d5ce --- /dev/null +++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.h @@ -0,0 +1,100 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_CHUNK_PROVIDER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_CHUNK_PROVIDER_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/framework/dataset.h" +#include "tsl/platform/env.h" + +namespace tensorflow { +namespace data { + +// Provides the next chunk to read. Blocks until the next chunk is available, +// or all the chunks have been read. This class is thread-safe. +class SnapshotChunkProvider { + public: + SnapshotChunkProvider(absl::string_view snapshot_path, tsl::Env* env); + virtual ~SnapshotChunkProvider() = default; + SnapshotChunkProvider(const SnapshotChunkProvider&) = delete; + SnapshotChunkProvider& operator=(const SnapshotChunkProvider&) = delete; + + // Returns the absolute file path of the next snapshot chunk to read. If there is + // no available chunk, blocks until the next chunk becomes available, or all the + // chunks have been read. Returns std::nullopt if all chunks have been read. + absl::StatusOr> GetNext(); + + // Supports checkpointing. + absl::Status Save(std::function full_name, + IteratorStateWriter* writer); + absl::Status Restore(std::function full_name, + IteratorStateReader* reader); + + // TODO(b/297930782): Support cancellation. + + private: + // State of the snapshot. + struct SnapshotState { + SnapshotState() = default; + explicit SnapshotState(bool snapshot_is_done) + : snapshot_is_done(snapshot_is_done) {} + explicit SnapshotState(absl::Status status) : status(std::move(status)) {} + + // True if the snapshot is done without errors. + bool snapshot_is_done = false; + + // Non-OK status if writing the snapshot fails. + absl::Status status = absl::OkStatus(); + }; + + // Updates the snapshot state and available chunks. + absl::Status UpdateSnapshot(); + + // Reads the DONE or ERROR file and returns a SnapshotState indicating whether + // the snapshot is complete. + absl::StatusOr GetSnapshotState(); + + // Reads the available chunks from disk and returns a vector of chunk file + // names. + absl::StatusOr> GetAvailableChunks(); + + const std::string snapshot_path_; + tsl::Env* const env_; + + mutable absl::Mutex mu_; + + // The set of read chunks. + absl::flat_hash_set chunks_read_ ABSL_GUARDED_BY(mu_); + + // The set of unread chunks.
+  absl::flat_hash_set<std::string> chunks_unread_ ABSL_GUARDED_BY(mu_);
+
+  // State of the snapshot.
+  SnapshotState snapshot_state_ ABSL_GUARDED_BY(mu_);
+};
+
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_CHUNK_PROVIDER_H_
diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc
new file mode 100644
index 00000000000000..28e31cae660e56
--- /dev/null
+++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc
@@ -0,0 +1,242 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/data/service/snapshot/snapshot_chunk_provider.h"
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/core/data/service/snapshot/file_utils.h"
+#include "tensorflow/core/data/service/snapshot/path_utils.h"
+#include "tsl/lib/core/status_test_util.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/path.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/status_to_from_proto.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+#include "tsl/protobuf/status.pb.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+using ::testing::UnorderedElementsAreArray;
+using ::tsl::testing::IsOkAndHolds;
+using ::tsl::testing::StatusIs;
+
+absl::StatusOr<std::string> CreateSnapshotDirectory() {
+  std::string snapshot_path;
+  if (!tsl::Env::Default()->LocalTempFilename(&snapshot_path)) {
+    return absl::FailedPreconditionError(
+        "Failed to create local temp file for snapshot.");
+  }
+  TF_RETURN_IF_ERROR(tsl::Env::Default()->RecursivelyCreateDir(
+      CommittedChunksDirectory(snapshot_path)));
+  return snapshot_path;
+}
+
+absl::Status WriteChunk(absl::string_view snapshot_path,
+                        absl::string_view chunk_file) {
+  return AtomicallyWriteStringToFile(
+      tsl::io::JoinPath(CommittedChunksDirectory(snapshot_path), chunk_file),
+      "", tsl::Env::Default());
+}
+
+absl::Status SetDone(absl::string_view snapshot_path) {
+  return AtomicallyWriteStringToFile(SnapshotDoneFilePath(snapshot_path), "",
+                                     tsl::Env::Default());
+}
+
+absl::Status SetStatus(absl::string_view snapshot_path,
+                       const absl::Status& status) {
+  return AtomicallyWriteTextProto(SnapshotErrorFilePath(snapshot_path),
+                                  tsl::StatusToProto(status),
+                                  tsl::Env::Default());
+}
+
+absl::StatusOr<std::vector<std::string>> GetAllChunks(
+    SnapshotChunkProvider& snapshot_chunk_provider) {
+  std::vector<std::string> chunks;
+  while (true) {
+    TF_ASSIGN_OR_RETURN(std::optional<std::string> chunk,
+                        snapshot_chunk_provider.GetNext());
+    if (!chunk.has_value()) {
+      break;
+    }
+    chunks.push_back(*chunk);
+  }
+  return chunks;
+}
+
+std::vector<std::string> JoinPaths(absl::string_view snapshot_path,
+                                   const std::vector<std::string> chunks) {
+  std::vector<std::string> joined_chunks;
+  for (absl::string_view chunk : chunks) {
+    joined_chunks.push_back(
+        tsl::io::JoinPath(CommittedChunksDirectory(snapshot_path), chunk));
+  }
+  return joined_chunks;
+}
+
+TEST(SnapshotChunkProviderTest, EmptySnapshot) {
+  TF_ASSERT_OK_AND_ASSIGN(std::string snapshot_path, CreateSnapshotDirectory());
+  TF_ASSERT_OK(SetDone(snapshot_path));
+
+  SnapshotChunkProvider snapshot_chunk_provider(snapshot_path,
+                                                tsl::Env::Default());
+  EXPECT_THAT(GetAllChunks(snapshot_chunk_provider), IsOkAndHolds(IsEmpty()));
+  EXPECT_THAT(GetAllChunks(snapshot_chunk_provider), IsOkAndHolds(IsEmpty()));
+}
+
+TEST(SnapshotChunkProviderTest, SingleReader) {
+  TF_ASSERT_OK_AND_ASSIGN(std::string snapshot_path, CreateSnapshotDirectory());
+  std::vector<std::string> chunks = {"chunk_0_0_0", "chunk_1_1_1",
+                                     "chunk_2_2_2", "chunk_3_3_3",
+                                     "chunk_4_4_4"};
+  for (absl::string_view chunk : chunks) {
+    TF_ASSERT_OK(WriteChunk(snapshot_path, chunk));
+  }
+  TF_ASSERT_OK(SetDone(snapshot_path));
+
+  SnapshotChunkProvider snapshot_chunk_provider(snapshot_path,
+                                                tsl::Env::Default());
+  EXPECT_THAT(GetAllChunks(snapshot_chunk_provider),
+              IsOkAndHolds(
+                  UnorderedElementsAreArray(JoinPaths(snapshot_path, chunks))));
+}
+
+TEST(SnapshotChunkProviderTest, WaitForSnapshot) {
+  std::string snapshot_path;
+  ASSERT_TRUE(tsl::Env::Default()->LocalTempFilename(&snapshot_path));
+
+  absl::Mutex mu;
+  std::vector<std::string> result;  // Guarded by `mu`.
+  std::unique_ptr<tsl::Thread> reader_thread =
+      absl::WrapUnique(tsl::Env::Default()->StartThread(
+          /*thread_options=*/{}, /*name=*/"Reader",
+          [&snapshot_path, &mu, &result]() {
+            SnapshotChunkProvider snapshot_chunk_provider(snapshot_path,
+                                                          tsl::Env::Default());
+            TF_ASSERT_OK_AND_ASSIGN(std::vector<std::string> chunks,
+                                    GetAllChunks(snapshot_chunk_provider));
+            absl::MutexLock l(&mu);
+            result = std::move(chunks);
+          }));
+
+  {  // The reader should wait when there are no chunks.
+    absl::MutexLock l(&mu);
+    EXPECT_TRUE(result.empty());
+  }
+
+  TF_ASSERT_OK(tsl::Env::Default()->RecursivelyCreateDir(
+      CommittedChunksDirectory(snapshot_path)));
+  TF_ASSERT_OK(WriteChunk(snapshot_path, "chunk_0_0_0"));
+  TF_ASSERT_OK(SetDone(snapshot_path));
+
+  // The reader should be able to get chunks now.
+  reader_thread.reset();
+  absl::MutexLock l(&mu);
+  EXPECT_THAT(result, UnorderedElementsAreArray(
+                          JoinPaths(snapshot_path, {"chunk_0_0_0"})));
+}
+
+TEST(SnapshotChunkProviderTest, ConcurrentReadWrite) {
+  TF_ASSERT_OK_AND_ASSIGN(std::string snapshot_path, CreateSnapshotDirectory());
+
+  const int num_readers = 10;
+  absl::Mutex mu;
+  SnapshotChunkProvider snapshot_chunk_provider(snapshot_path,
+                                                tsl::Env::Default());
+  std::vector<std::string> result;  // Guarded by `mu`.
+ std::vector> reader_threads; + for (int i = 0; i < num_readers; ++i) { + reader_threads.push_back(absl::WrapUnique(tsl::Env::Default()->StartThread( + /*thread_options=*/{}, /*name=*/absl::StrCat("Reader_", i), + [&snapshot_chunk_provider, &mu, &result]() { + while (true) { + tsl::Env::Default()->SleepForMicroseconds(25); + TF_ASSERT_OK_AND_ASSIGN(std::optional chunk, + snapshot_chunk_provider.GetNext()); + if (!chunk.has_value()) { + break; + } + absl::MutexLock l(&mu); + result.push_back(std::move(*chunk)); + } + }))); + } + + int num_streams = 10, num_chunks_per_stream = 50; + std::vector> stream_threads; + for (int i = 0; i < num_streams; ++i) { + stream_threads.push_back(absl::WrapUnique(tsl::Env::Default()->StartThread( + /*thread_options=*/{}, /*name=*/absl::StrCat("Writer_", i), + [&snapshot_path, num_chunks_per_stream, i]() { + for (int j = 0; j < num_chunks_per_stream; ++j) { + std::string filename = absl::StrCat("chunk_", i, "_", j); + TF_ASSERT_OK(WriteChunk(snapshot_path, filename)); + tsl::Env::Default()->SleepForMicroseconds(35); + } + }))); + } + + stream_threads.clear(); + TF_ASSERT_OK(SetDone(snapshot_path)); + + reader_threads.clear(); + std::vector expected; + for (int i = 0; i < num_streams; ++i) { + for (int j = 0; j < num_chunks_per_stream; ++j) { + expected.push_back(absl::StrCat("chunk_", i, "_", j)); + } + } + EXPECT_THAT(result, + UnorderedElementsAreArray(JoinPaths(snapshot_path, expected))); +} + +TEST(SnapshotChunkProviderTest, SnapshotError) { + TF_ASSERT_OK_AND_ASSIGN(std::string snapshot_path, CreateSnapshotDirectory()); + std::unique_ptr reader_thread = + absl::WrapUnique(tsl::Env::Default()->StartThread( + /*thread_options=*/{}, /*name=*/"Reader", [&snapshot_path]() { + SnapshotChunkProvider snapshot_chunk_provider(snapshot_path, + tsl::Env::Default()); + EXPECT_THAT( + GetAllChunks(snapshot_chunk_provider), + StatusIs(absl::StatusCode::kFailedPrecondition, "Test error.")); + })); + + TF_ASSERT_OK(WriteChunk(snapshot_path, "chunk_0_0_0")); + TF_ASSERT_OK(WriteChunk(snapshot_path, "chunk_1_0_0")); + TF_ASSERT_OK(WriteChunk(snapshot_path, "chunk_2_0_0")); + TF_ASSERT_OK( + SetStatus(snapshot_path, absl::FailedPreconditionError("Test error."))); + reader_thread.reset(); +} + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager.cc b/tensorflow/core/data/service/snapshot/snapshot_manager.cc index f02652c49a0ca3..b4bd82b5ad2184 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_manager.cc @@ -38,6 +38,7 @@ limitations under the License. 
#include "tensorflow/core/data/service/snapshot/path_utils.h" #include "tensorflow/core/data/service/split_provider.h" #include "tensorflow/core/data/snapshot_utils.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/platform/status.h" #include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" @@ -99,7 +100,7 @@ absl::Status SnapshotManager::Start(const SnapshotRequest& request) } tsl::mutex_lock l(mu_); TF_ASSIGN_OR_RETURN(sources_, CreateSources(request.dataset())); - TF_ASSIGN_OR_RETURN(num_total_splits_, CountSplits()); + TF_ASSIGN_OR_RETURN(num_total_splits_, GetSplitsCardinality()); TF_RETURN_IF_ERROR(WriteOnDiskSkeleton()); TF_RETURN_IF_ERROR(WriteOnDiskMetadata(request)); metadata_ = request.metadata(); @@ -120,6 +121,31 @@ SnapshotManager::CreateSources(const DatasetDef& dataset_def) const return sources; } +absl::StatusOr SnapshotManager::GetSplitsCardinality() + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (ShouldCountSplits()) { + return CountSplits(); + } + + int64_t num_splits = 0; + for (const auto& source : sources_) { + if (source.split_provider->Cardinality() > 0) { + num_splits += source.split_provider->Cardinality(); + } + } + return num_splits; +} + +bool SnapshotManager::ShouldCountSplits() const + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + for (const auto& source : sources_) { + if (source.split_provider->Cardinality() == kUnknownCardinality) { + return true; + } + } + return false; +} + absl::StatusOr SnapshotManager::CountSplits() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { int64_t num_splits = 0; @@ -210,7 +236,7 @@ absl::Status SnapshotManager::ReadOnDiskMetadata() ReadBinaryProto(env_, DatasetDefFilePath(path_), &dataset_def)); TF_ASSIGN_OR_RETURN(sources_, CreateSources(dataset_def)); - TF_ASSIGN_OR_RETURN(num_total_splits_, CountSplits()); + TF_ASSIGN_OR_RETURN(num_total_splits_, GetSplitsCardinality()); return absl::OkStatus(); } diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager.h b/tensorflow/core/data/service/snapshot/snapshot_manager.h index 77a24ce915bc73..fe903fde98847f 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager.h +++ b/tensorflow/core/data/service/snapshot/snapshot_manager.h @@ -247,6 +247,11 @@ class SnapshotManager { // Creates sources for the specified dataset. absl::StatusOr> CreateSources( const DatasetDef& dataset_def) const; + // Returns the total number of splits. + absl::StatusOr GetSplitsCardinality(); + // Returns true if we need to count the total number of splits for progress + // reporting. + bool ShouldCountSplits() const; // Counts the number of splits for a single repetition of the data in // `sources_`. 
absl::StatusOr CountSplits(); diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc index d22e9573a15148..fcebb32e82a539 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc @@ -93,8 +93,6 @@ class ElementOrErrorIterator : public TaskIterator { int64_t Cardinality() const override { return elements_.size(); } - std::optional GetProcessingTimeNsec() const override { return 1.0e7; } - private: const std::vector> elements_; int64_t next_ = 0; diff --git a/tensorflow/core/data/service/split_provider_test.cc b/tensorflow/core/data/service/split_provider_test.cc new file mode 100644 index 00000000000000..08adc907058af2 --- /dev/null +++ b/tensorflow/core/data/service/split_provider_test.cc @@ -0,0 +1,115 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/data/service/split_provider.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/test_util.h" +#include "tensorflow/core/framework/dataset.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace tensorflow { +namespace data { +namespace { + +using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; + +std::vector GetCardinalities( + const std::vector>& split_providers) { + std::vector cardinalities; + for (const auto& split_provider : split_providers) { + cardinalities.push_back(split_provider->Cardinality()); + } + return cardinalities; +} + +TEST(SplitProviderTest, RangeCardinality) { + DatasetDef range_dataset = testing::RangeDataset(10); + std::vector> split_providers; + TF_ASSERT_OK(CreateSplitProviders(range_dataset, split_providers)); + EXPECT_THAT(GetCardinalities(split_providers), UnorderedElementsAre(10)); +} + +class RepeatedSplitProviderTest + : public ::testing::TestWithParam> { + public: + int64_t Range() const { return std::get<0>(GetParam()); } + int64_t RepeatCount() const { return std::get<1>(GetParam()); } + int64_t ExpectedCardinality() const { return std::get<2>(GetParam()); } +}; + +// Test cases for the `RepeatedDatasetCardinality` test. The tuples specify +// {range, repeat count, expected cardinality}. 
+constexpr std::array, 5> + kRepeatedSplitProviderTestCases{{{9, 9, 81}, + {9, 0, 0}, + {9, -1, kInfiniteCardinality}, + {0, -1, 0}, + {-1, 1, 0}}}; + +TEST_P(RepeatedSplitProviderTest, RepeatedDatasetCardinality) { + TF_ASSERT_OK_AND_ASSIGN( + DatasetDef repeated_dataset, + testing::GetTestDataset( + "repeated_dataset", + {absl::StrCat(Range()), absl::StrCat(RepeatCount())})); + std::vector> split_providers; + TF_ASSERT_OK(CreateSplitProviders(repeated_dataset, split_providers)); + EXPECT_THAT(GetCardinalities(split_providers), + ElementsAre(ExpectedCardinality())); +} + +INSTANTIATE_TEST_SUITE_P(MyGroup, RepeatedSplitProviderTest, + ::testing::ValuesIn(kRepeatedSplitProviderTestCases)); + +TEST(SplitProviderTest, EnumerateCardinality) { + TF_ASSERT_OK_AND_ASSIGN(DatasetDef enumerate_dataset, + testing::GetTestDataset("enumerate_dataset")); + std::vector> split_providers; + TF_ASSERT_OK(CreateSplitProviders(enumerate_dataset, split_providers)); + EXPECT_THAT(GetCardinalities(split_providers), + UnorderedElementsAre(3, kInfiniteCardinality)); +} + +TEST(SplitProviderTest, ChooseFromDatasetsCardinality) { + TF_ASSERT_OK_AND_ASSIGN(DatasetDef sample_from_datasets, + testing::GetTestDataset("choose_from_datasets")); + std::vector> split_providers; + TF_ASSERT_OK(CreateSplitProviders(sample_from_datasets, split_providers)); + EXPECT_THAT(GetCardinalities(split_providers), + UnorderedElementsAre(5, 5, 5, kInfiniteCardinality)); +} + +TEST(SplitProviderTest, SampleFromDatasetsCardinality) { + TF_ASSERT_OK_AND_ASSIGN(DatasetDef sample_from_datasets, + testing::GetTestDataset("sample_from_datasets")); + std::vector> split_providers; + TF_ASSERT_OK(CreateSplitProviders(sample_from_datasets, split_providers)); + EXPECT_THAT(GetCardinalities(split_providers), + UnorderedElementsAre(5, 5, 5, kInfiniteCardinality)); +} + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/task_runner.cc b/tensorflow/core/data/service/task_runner.cc index c0240f27235557..6c169d2bf90a93 100644 --- a/tensorflow/core/data/service/task_runner.cc +++ b/tensorflow/core/data/service/task_runner.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/core/data/standalone.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/model.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -73,8 +74,8 @@ Status StandaloneTaskIterator::Restore( return iterator_->Restore(saved_iterator); } -std::optional StandaloneTaskIterator::GetProcessingTimeNsec() const { - return iterator_->GetProcessingTimeNsec(); +std::shared_ptr StandaloneTaskIterator::model() const { + return iterator_->model(); } Status TaskRunner::Create(const experimental::WorkerConfig& worker_config, @@ -168,10 +169,8 @@ void FirstComeFirstServedTaskRunner::Cancel() { buffer_.Cancel(errors::Cancelled("tf.data service FCFS task is cancelled.")); } -std::optional FirstComeFirstServedTaskRunner::GetProcessingTimeNsec() - TF_LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - return iterator_->GetProcessingTimeNsec(); +std::shared_ptr FirstComeFirstServedTaskRunner::model() const { + return model_; } CachingTaskRunner::CachingTaskRunner(std::unique_ptr iterator, @@ -223,8 +222,8 @@ void CachingTaskRunner::Cancel() { fcfs_task_runner_.Cancel(); } -std::optional CachingTaskRunner::GetProcessingTimeNsec() { - return fcfs_task_runner_.GetProcessingTimeNsec(); +std::shared_ptr CachingTaskRunner::model() const { + return fcfs_task_runner_.model(); } RoundRobinTaskRunner::RoundRobinTaskRunner( @@ -361,8 +360,8 @@ void RoundRobinTaskRunner::Cancel() { new_round_cv_.notify_all(); } -std::optional RoundRobinTaskRunner::GetProcessingTimeNsec() { - return prefetch_thread_.GetProcessingTimeNsec(); +std::shared_ptr RoundRobinTaskRunner::model() const { + return prefetch_thread_.model(); } PrefetchThread::PrefetchThread(std::unique_ptr iterator, @@ -447,8 +446,8 @@ Status PrefetchThread::GetStatus() { return status_; } -std::optional PrefetchThread::GetProcessingTimeNsec() const { - return iterator_->GetProcessingTimeNsec(); +std::shared_ptr PrefetchThread::model() const { + return iterator_->model(); } } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/task_runner.h b/tensorflow/core/data/service/task_runner.h index a1db121ac602ca..565a7f4727a477 100644 --- a/tensorflow/core/data/service/task_runner.h +++ b/tensorflow/core/data/service/task_runner.h @@ -63,12 +63,8 @@ class TaskIterator { "Restoring from a tf.data service task iterator is unsupported."); } - // Returns the time it takes the pipeline associated with this task iterator - // to process an element. - // Returns std::nullopt if there is not currently enough information to - // determine the processing time, e.g. because not enough data has been - // produced yet from the iterator. - virtual std::optional GetProcessingTimeNsec() const = 0; + // Returns the dataset model for performance analysis. + virtual std::shared_ptr model() const { return nullptr; } }; // Implementation of TaskIterator wrapping a standalone iterator. @@ -83,7 +79,7 @@ class StandaloneTaskIterator : public TaskIterator { int64_t Cardinality() const override; StatusOr> Save() override; Status Restore(const std::vector& saved_iterator) override; - std::optional GetProcessingTimeNsec() const override; + std::shared_ptr model() const override; private: std::unique_ptr dataset_; @@ -102,14 +98,10 @@ class TaskRunner { // Gets the next element for the given request. 
virtual Status GetNext(const GetElementRequest& req, GetElementResult& result) = 0; - // Returns the time it takes the pipeline associated with this task runner to - // process an element. Returns 0 if the model is null or empty. - // Returns std::nullopt if there is not currently enough information to - // determine the processing time, e.g. because not enough data has been - // produced yet from the iterator. - virtual std::optional GetProcessingTimeNsec() = 0; // Cancels in-progress `GetNext` requests. virtual void Cancel() = 0; + // Returns the dataset model for performance analysis. + virtual std::shared_ptr model() const = 0; }; // A task runner which provides elements on a first-come first-served basis. @@ -127,7 +119,7 @@ class FirstComeFirstServedTaskRunner : public TaskRunner { void Cancel() override; - std::optional GetProcessingTimeNsec() override TF_LOCKS_EXCLUDED(mu_); + std::shared_ptr model() const override; private: // Function to continually prefetch the next element. Returns an error if the @@ -140,6 +132,7 @@ class FirstComeFirstServedTaskRunner : public TaskRunner { // Gets the next element from the input iterator. StatusOr GetNextFromInputIterator() TF_LOCKS_EXCLUDED(mu_); + const std::shared_ptr model_; mutex mu_; std::unique_ptr iterator_ TF_GUARDED_BY(mu_); int64_t element_index_ TF_GUARDED_BY(mu_) = 0; @@ -173,7 +166,8 @@ class CachingTaskRunner : public TaskRunner { // return a Cancelled status. void Cancel() override; - std::optional GetProcessingTimeNsec() override; + // Returns the dataset model for performance analysis. + std::shared_ptr model() const override; private: // The `GetElementResultSequence` generates a sequence of elements from the @@ -224,7 +218,8 @@ class PrefetchThread { std::vector>& out); // Returns the status for any failures encountered by the prefetch thread. Status GetStatus(); - std::optional GetProcessingTimeNsec() const; + // Returns the dataset model for performance analysis. + std::shared_ptr model() const; private: const std::unique_ptr iterator_; @@ -269,7 +264,7 @@ class RoundRobinTaskRunner : public TaskRunner { Status GetNext(const GetElementRequest& req, GetElementResult& result) override; void Cancel() override; - std::optional GetProcessingTimeNsec() override; + std::shared_ptr model() const override; private: // Prepares a full round of data. `wait_us` indicates how long to wait before diff --git a/tensorflow/core/data/service/task_runner_test.cc b/tensorflow/core/data/service/task_runner_test.cc index 5650e28627e631..0c1ef895742b0c 100644 --- a/tensorflow/core/data/service/task_runner_test.cc +++ b/tensorflow/core/data/service/task_runner_test.cc @@ -77,8 +77,6 @@ class RangeIterator : public TaskIterator { return repeat_ ? 
kInfiniteCardinality : range_; } - std::optional GetProcessingTimeNsec() const override { return 1.0e7; } - private: const int64_t range_; const bool repeat_; @@ -96,8 +94,6 @@ class InfiniteRangeIterator : public TaskIterator { int64_t Cardinality() const override { return kInfiniteCardinality; } - std::optional GetProcessingTimeNsec() const override { return 1.0e7; } - private: int64_t next_ = 0; }; @@ -121,8 +117,6 @@ class ElementOrErrorIterator : public TaskIterator { int64_t Cardinality() const override { return elements_.size(); } - std::optional GetProcessingTimeNsec() const override { return 1.0e7; } - private: const std::vector> elements_; int64_t next_ = 0; diff --git a/tensorflow/core/data/service/test_util.cc b/tensorflow/core/data/service/test_util.cc index 8d7abf3a9eb540..766fe07c4f7469 100644 --- a/tensorflow/core/data/service/test_util.cc +++ b/tensorflow/core/data/service/test_util.cc @@ -19,8 +19,11 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/types/span.h" +#include "absl/strings/substitute.h" #include "tensorflow/core/data/dataset_test_base.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/framework/function.h" @@ -33,13 +36,11 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/path.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/platform/tstring.h" -#include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/struct.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/protobuf.h" namespace tensorflow { namespace data { @@ -53,9 +54,11 @@ using ::tensorflow::test::function::NDef; constexpr int64_t kShardHint = -1; constexpr const char kTestdataDir[] = "tensorflow/core/data/service/testdata"; +constexpr const char kEnumerateDatasetFile[] = "enumerate_dataset.pbtxt"; constexpr const char kInterleaveTextlineDatasetFile[] = "interleave_textline_dataset.pbtxt"; constexpr const char kChooseFromDatasetsFile[] = "choose_from_datasets.pbtxt"; +constexpr const char kSampleFromDatasetsFile[] = "sample_from_datasets.pbtxt"; NodeDef GetMapNode(absl::string_view name, absl::string_view input_node_name, absl::string_view function_name) { @@ -77,16 +80,16 @@ FunctionDef XTimesX() { /*ret_def=*/{{"y", "y:z:0"}}); } -Status CreateTestFiles(const std::vector& filenames, - const std::vector& contents) { +absl::Status CreateTestFiles(const std::vector& filenames, + const std::vector& contents) { if (filenames.size() != contents.size()) { - return errors::InvalidArgument( + return absl::InvalidArgumentError( "The number of files does not match with the contents."); } for (int i = 0; i < filenames.size(); ++i) { TF_RETURN_IF_ERROR(WriteDataToFile(filenames[i], contents[i].data())); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -96,6 +99,32 @@ std::string LocalTempFilename() { return path; } +absl::StatusOr GetTestDataset( + absl::string_view dataset_name, const std::vector& args) { + std::string graph_file = + io::JoinPath(kTestdataDir, absl::StrCat(dataset_name, ".pbtxt")); + std::string graph_str; + TF_RETURN_IF_ERROR(ReadFileToString(Env::Default(), graph_file, 
&graph_str)); + if (args.size() == 1) { + graph_str = absl::Substitute(graph_str, args[0]); + } else if (args.size() == 2) { + graph_str = absl::Substitute(graph_str, args[0], args[1]); + } else if (args.size() == 3) { + graph_str = absl::Substitute(graph_str, args[0], args[1], args[2]); + } else if (args.size() > 3) { + return absl::UnimplementedError( + "GetTestDataset does not support more than 3 arguments."); + } + + DatasetDef dataset; + if (!tsl::protobuf::TextFormat::ParseFromString(graph_str, + dataset.mutable_graph())) { + return absl::FailedPreconditionError( + absl::StrCat("Can't parse ", graph_file, " as text proto.")); + } + return dataset; +} + DatasetDef RangeDataset(int64_t range) { DatasetDef dataset_def; *dataset_def.mutable_graph() = GDef( @@ -182,14 +211,6 @@ DatasetDef InfiniteDataset() { return dataset_def; } -StatusOr ChooseFromDatasets() { - DatasetDef dataset; - std::string graph_file = io::JoinPath(kTestdataDir, kChooseFromDatasetsFile); - TF_RETURN_IF_ERROR( - ReadTextProto(Env::Default(), graph_file, dataset.mutable_graph())); - return dataset; -} - experimental::DistributedSnapshotMetadata CreateDummyDistributedSnapshotMetadata() { StructuredValue decoded_spec; @@ -204,7 +225,7 @@ CreateDummyDistributedSnapshotMetadata() { return metadata; } -StatusOr InterleaveTextlineDataset( +absl::StatusOr InterleaveTextlineDataset( const std::vector& filenames, const std::vector& contents) { TF_RETURN_IF_ERROR(CreateTestFiles(filenames, contents)); @@ -222,11 +243,11 @@ StatusOr InterleaveTextlineDataset( return dataset; } -Status WaitWhile(std::function()> f) { +absl::Status WaitWhile(std::function()> f) { while (true) { TF_ASSIGN_OR_RETURN(bool result, f()); if (!result) { - return OkStatus(); + return absl::OkStatus(); } Env::Default()->SleepForMicroseconds(10 * 1000); // 10ms. } diff --git a/tensorflow/core/data/service/test_util.h b/tensorflow/core/data/service/test_util.h index a175c543cbeef8..2180675b74e16f 100644 --- a/tensorflow/core/data/service/test_util.h +++ b/tensorflow/core/data/service/test_util.h @@ -20,9 +20,10 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/platform/types.h" @@ -35,6 +36,13 @@ namespace testing { // Creates a local tempfile and returns the path. std::string LocalTempFilename(); +// Creates a dataset graph for testing. `dataset_name` is one of the filenames +// defined in `testdata` (without `.pbtxt`). `args` specifies arguments passed +// to the dataset. These args appear as `$0`, `$1`, etc, in the dataset +// definition and will be replaced with the specified args. +absl::StatusOr GetTestDataset( + absl::string_view dataset_name, const std::vector& args = {}); + // Returns a test dataset representing // tf.data.Dataset.range(range). Useful for testing dataset graph execution. DatasetDef RangeDataset(int64_t range); @@ -51,14 +59,6 @@ DatasetDef RangeDatasetWithShardHint(int64_t range); // tf.data.Dataset.range(100000000).repeat(). 
DatasetDef InfiniteDataset(); -// Returns a test dataset representing -// datasets = [tf.data.Dataset.from_tensor_slices(["a", "a", "a", "a", "a"]), -// tf.data.Dataset.from_tensor_slices(["b", "b", "b", "b", "b"]), -// tf.data.Dataset.from_tensor_slices(["c", "c", "c", "c", "c"])] -// choice_dataset = tf.data.Dataset.range(3).repeat() -// dataset = tf.data.Dataset.choose_from_datasets(datasets, choice_dataset) -StatusOr ChooseFromDatasets(); - // Returns a distributed snapshot metadata for a dummy dataset. experimental::DistributedSnapshotMetadata CreateDummyDistributedSnapshotMetadata(); @@ -67,14 +67,14 @@ CreateDummyDistributedSnapshotMetadata(); // tf.data.Dataset.from_tensor_slices(["filenames"]).interleave( // lambda filepath: tf.data.TextLineDataset(filepath), // cycle_length=10) -StatusOr InterleaveTextlineDataset( +absl::StatusOr InterleaveTextlineDataset( const std::vector& filenames, const std::vector& contents); // Repeatedly calls `f()`, blocking until `f()` returns `false`. // // Returns an error if `f()` returns an error. -Status WaitWhile(std::function()> f); +absl::Status WaitWhile(std::function()> f); // TODO(b/229726259): Make EqualsProto available in Googletest // (Public feature request: https://github.com/google/googletest/issues/1761). diff --git a/tensorflow/core/data/service/test_util_test.cc b/tensorflow/core/data/service/test_util_test.cc index 9608163e9f47e8..0cf43eb404631e 100644 --- a/tensorflow/core/data/service/test_util_test.cc +++ b/tensorflow/core/data/service/test_util_test.cc @@ -158,8 +158,9 @@ TEST(TestUtilTest, InterleaveTextlineEmptyFiles) { EXPECT_THAT(GetIteratorOutput(*iterator), IsOkAndHolds(IsEmpty())); } -TEST(TestUtilTest, ChooseFromDatasets) { - TF_ASSERT_OK_AND_ASSIGN(const DatasetDef dataset_def, ChooseFromDatasets()); +TEST(TestUtilTest, GetTestDataset) { + TF_ASSERT_OK_AND_ASSIGN(const DatasetDef dataset_def, + GetTestDataset("choose_from_datasets")); standalone::Dataset::Params params; std::unique_ptr dataset; TF_ASSERT_OK( diff --git a/tensorflow/core/data/service/testdata/enumerate_dataset.pbtxt b/tensorflow/core/data/service/testdata/enumerate_dataset.pbtxt new file mode 100644 index 00000000000000..c0066f9da50e72 --- /dev/null +++ b/tensorflow/core/data/service/testdata/enumerate_dataset.pbtxt @@ -0,0 +1,288 @@ +# proto-file: third_party/tensorflow/core/framework/graph.proto +# proto-message: GraphDef +# +# Proto content generated by +# +# import tensorflow as tf +# +# dataset = tf.data.Dataset.from_tensor_slices(["a", "b", "c"]) +# dataset = dataset.enumerate() +# +# g = tf.compat.v1.GraphDef() +# g.ParseFromString(dataset._as_serialized_graph().numpy()) +# print(g) + +node { + name: "Const/_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } +} +node { + name: "Const/_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 9223372036854775807 + } + } + } +} +node { + name: "Const/_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } +} +node { + name: "RangeDataset/_3" + op: "RangeDataset" + input: "Const/_0" + input: "Const/_1" + input: "Const/_2" + attr { + key: "metadata" + value { + s: "\n\016RangeDataset:9" + } + } + attr { + key: "output_shapes" 
+ value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "replicate_on_split" + value { + b: true + } + } + experimental_type { + type_id: TFT_PRODUCT + args { + type_id: TFT_DATASET + args { + type_id: TFT_PRODUCT + args { + type_id: TFT_TENSOR + args { + type_id: TFT_INT64 + } + } + } + } + } +} +node { + name: "Const/_4" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 3 + } + } + tensor_content: "\001\001\001abc" + } + } + } +} +node { + name: "TensorSliceDataset/_5" + op: "TensorSliceDataset" + input: "Const/_4" + attr { + key: "Toutput_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "is_files" + value { + b: false + } + } + attr { + key: "metadata" + value { + s: "\n\024TensorSliceDataset:7" + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "replicate_on_split" + value { + b: false + } + } + experimental_type { + type_id: TFT_PRODUCT + args { + type_id: TFT_DATASET + args { + type_id: TFT_PRODUCT + args { + type_id: TFT_TENSOR + args { + type_id: TFT_STRING + } + } + } + } + } +} +node { + name: "ZipDataset/_6" + op: "ZipDataset" + input: "RangeDataset/_3" + input: "TensorSliceDataset/_5" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "metadata" + value { + s: "\n\rZipDataset:10" + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT64 + type: DT_STRING + } + } + } + experimental_type { + type_id: TFT_PRODUCT + args { + type_id: TFT_DATASET + args { + type_id: TFT_PRODUCT + args { + type_id: TFT_TENSOR + args { + type_id: TFT_INT64 + } + } + args { + type_id: TFT_TENSOR + args { + type_id: TFT_STRING + } + } + } + } + } +} +node { + name: "dataset" + op: "_Retval" + input: "ZipDataset/_6" + attr { + key: "T" + value { + type: DT_VARIANT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +library { +} +versions { + producer: 1700 +} diff --git a/tensorflow/core/data/service/testdata/repeated_dataset.pbtxt b/tensorflow/core/data/service/testdata/repeated_dataset.pbtxt new file mode 100644 index 00000000000000..8dfca4717c97e3 --- /dev/null +++ b/tensorflow/core/data/service/testdata/repeated_dataset.pbtxt @@ -0,0 +1,215 @@ +# proto-file: third_party/tensorflow/core/framework/graph.proto +# proto-message: GraphDef +# +# Proto content generated by +# +# import tensorflow as tf +# +# dataset = tf.data.Dataset.range($0) +# dataset = dataset.repeat($1) +# +# g = tf.compat.v1.GraphDef() +# g.ParseFromString(dataset._as_serialized_graph().numpy()) +# print(g) + +node { + name: "Const/_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } +} +node { + name: "Const/_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: $0 + } + } + } +} +node { + name: "Const/_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } +} +node { + name: "RangeDataset/_3" + op: "RangeDataset" + input: "Const/_0" + input: "Const/_1" + 
input: "Const/_2" + attr { + key: "metadata" + value { + s: "\n\017RangeDataset:15" + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "replicate_on_split" + value { + b: false + } + } + experimental_type { + type_id: TFT_PRODUCT + args { + type_id: TFT_DATASET + args { + type_id: TFT_PRODUCT + args { + type_id: TFT_TENSOR + args { + type_id: TFT_INT64 + } + } + } + } + } +} +node { + name: "Const/_4" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: $1 + } + } + } +} +node { + name: "RepeatDataset/_5" + op: "RepeatDataset" + input: "RangeDataset/_3" + input: "Const/_4" + attr { + key: "metadata" + value { + s: "\n\020RepeatDataset:16" + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT64 + } + } + } + experimental_type { + type_id: TFT_PRODUCT + args { + type_id: TFT_DATASET + args { + type_id: TFT_PRODUCT + args { + type_id: TFT_TENSOR + args { + type_id: TFT_INT64 + } + } + } + } + } +} +node { + name: "dataset" + op: "_Retval" + input: "RepeatDataset/_5" + attr { + key: "T" + value { + type: DT_VARIANT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +library { +} +versions { + producer: 1700 +} diff --git a/tensorflow/core/data/service/testdata/sample_from_datasets.pbtxt b/tensorflow/core/data/service/testdata/sample_from_datasets.pbtxt new file mode 100644 index 00000000000000..4b6276b1923db7 --- /dev/null +++ b/tensorflow/core/data/service/testdata/sample_from_datasets.pbtxt @@ -0,0 +1,762 @@ +# proto-file: third_party/tensorflow/core/framework/graph.proto +# proto-message: GraphDef +# +# Proto content generated by +# +# import tensorflow as tf +# +# datasets = [tf.data.Dataset.from_tensor_slices(["a", "a", "a", "a", "a"]), +# tf.data.Dataset.from_tensor_slices(["b", "b", "b", "b", "b"]), +# tf.data.Dataset.from_tensor_slices(["c", "c", "c", "c", "c"])] +# dataset = tf.data.Dataset.sample_from_datasets( +# datasets, weights=[1.0] * len(datasets)) +# +# g = tf.compat.v1.GraphDef() +# g.ParseFromString(dataset._as_serialized_graph().numpy()) +# print(g) + +node { + name: "Const/_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } +} +node { + name: "Const/_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } +} +node { + name: "RandomDataset/_2" + op: "RandomDataset" + input: "Const/_0" + input: "Const/_1" + attr { + key: "metadata" + value { + s: "\n\017RandomDataset:3" + } + } + attr { + key: "output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT64 + } + } + } + experimental_type { + type_id: TFT_PRODUCT + args { + type_id: TFT_DATASET + args { + type_id: TFT_PRODUCT + args { + type_id: TFT_TENSOR + args { + type_id: TFT_INT64 + } + } + } + } + } +} +node { + name: "Const/_3" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 2 + } + } + } +} +node { + name: 
"Const/_4" + op: "Const" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BOOL + tensor_shape { + } + bool_val: false + } + } + } +} +node { + name: "BatchDatasetV2/_5" + op: "BatchDatasetV2" + input: "RandomDataset/_2" + input: "Const/_3" + input: "Const/_4" + attr { + key: "metadata" + value { + s: "\n\020BatchDatasetV2:4" + } + } + attr { + key: "output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "parallel_copy" + value { + b: false + } + } + experimental_type { + type_id: TFT_PRODUCT + args { + type_id: TFT_DATASET + args { + type_id: TFT_PRODUCT + args { + type_id: TFT_TENSOR + args { + type_id: TFT_INT64 + } + } + } + } + } +} +node { + name: "Const/_6" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 3 + } + } + tensor_content: "\000\000\000\000\000\000\000\000\000\000\000\000" + } + } + } +} +node { + name: "MapDataset/_7" + op: "MapDataset" + input: "BatchDatasetV2/_5" + input: "Const/_6" + attr { + key: "Targuments" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "f" + value { + func { + name: "__inference_Dataset_map_select_dataset_constant_logits_24" + } + } + } + attr { + key: "metadata" + value { + s: "\n\014MapDataset:5" + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "preserve_cardinality" + value { + b: true + } + } + attr { + key: "use_inter_op_parallelism" + value { + b: false + } + } + experimental_type { + type_id: TFT_PRODUCT + args { + type_id: TFT_DATASET + args { + type_id: TFT_PRODUCT + args { + type_id: TFT_TENSOR + args { + type_id: TFT_INT64 + } + } + } + } + } +} +node { + name: "Const/_8" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 5 + } + } + tensor_content: "\001\001\001\001\001aaaaa" + } + } + } +} +node { + name: "TensorSliceDataset/_9" + op: "TensorSliceDataset" + input: "Const/_8" + attr { + key: "Toutput_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "is_files" + value { + b: false + } + } + attr { + key: "metadata" + value { + s: "\n\024TensorSliceDataset:0" + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "replicate_on_split" + value { + b: false + } + } + experimental_type { + type_id: TFT_PRODUCT + args { + type_id: TFT_DATASET + args { + type_id: TFT_PRODUCT + args { + type_id: TFT_TENSOR + args { + type_id: TFT_STRING + } + } + } + } + } +} +node { + name: "Const/_10" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 5 + } + } + tensor_content: "\001\001\001\001\001bbbbb" + } + } + } +} +node { + name: "TensorSliceDataset/_11" + op: "TensorSliceDataset" + input: "Const/_10" + attr { + key: "Toutput_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "is_files" + value { + b: false + } + } + attr { + key: "metadata" + value { + s: "\n\024TensorSliceDataset:1" + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + 
} + } + attr { + key: "replicate_on_split" + value { + b: false + } + } + experimental_type { + type_id: TFT_PRODUCT + args { + type_id: TFT_DATASET + args { + type_id: TFT_PRODUCT + args { + type_id: TFT_TENSOR + args { + type_id: TFT_STRING + } + } + } + } + } +} +node { + name: "Const/_12" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 5 + } + } + tensor_content: "\001\001\001\001\001ccccc" + } + } + } +} +node { + name: "TensorSliceDataset/_13" + op: "TensorSliceDataset" + input: "Const/_12" + attr { + key: "Toutput_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "is_files" + value { + b: false + } + } + attr { + key: "metadata" + value { + s: "\n\024TensorSliceDataset:2" + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "replicate_on_split" + value { + b: false + } + } + experimental_type { + type_id: TFT_PRODUCT + args { + type_id: TFT_DATASET + args { + type_id: TFT_PRODUCT + args { + type_id: TFT_TENSOR + args { + type_id: TFT_STRING + } + } + } + } + } +} +node { + name: "DirectedInterleaveDataset/_14" + op: "DirectedInterleaveDataset" + input: "MapDataset/_7" + input: "TensorSliceDataset/_9" + input: "TensorSliceDataset/_11" + input: "TensorSliceDataset/_13" + attr { + key: "N" + value { + i: 3 + } + } + attr { + key: "output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "stop_on_empty_dataset" + value { + b: false + } + } + experimental_type { + type_id: TFT_PRODUCT + args { + type_id: TFT_DATASET + args { + type_id: TFT_PRODUCT + args { + type_id: TFT_TENSOR + args { + type_id: TFT_STRING + } + } + } + } + } +} +node { + name: "dataset" + op: "_Retval" + input: "DirectedInterleaveDataset/_14" + attr { + key: "T" + value { + type: DT_VARIANT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +library { + function { + signature { + name: "__inference_Dataset_map_select_dataset_constant_logits_24" + input_arg { + name: "args_0" + type: DT_INT64 + } + input_arg { + name: "statelessmultinomial_logits" + type: DT_FLOAT + } + output_arg { + name: "identity" + type: DT_INT64 + } + } + node_def { + name: "StatelessMultinomial/num_samples" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "StatelessMultinomial" + op: "StatelessMultinomial" + input: "statelessmultinomial_logits" + input: "StatelessMultinomial/num_samples:output:0" + input: "args_0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tseed" + value { + type: DT_INT64 + } + } + attr { + key: "output_dtype" + value { + type: DT_INT64 + } + } + } + node_def { + name: "Squeeze" + op: "Squeeze" + input: "StatelessMultinomial:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "squeeze_dims" + value { + list { + i: 0 + i: 1 + } + } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "Squeeze:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + } + ret { + key: "identity" + value: "Identity:output:0" + } + attr { + key: "_construction_context" + value { + s: "kEagerRuntime" + } + } + attr { + key: "_tf_data_function" + value { + b: true + } + } + arg_attr { + key: 0 + value 
{ + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "args_0" + } + } + } + } + arg_attr { + key: 1 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + dim { + size: 3 + } + } + } + } + } + } + } + } +} +versions { + producer: 1700 +} diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index ce1a3feed067fd..70e8311c2d0b3d 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -47,6 +47,7 @@ limitations under the License. #include "tensorflow/core/data/standalone.h" #include "tensorflow/core/framework/dataset.pb.h" #include "tensorflow/core/framework/metrics.h" +#include "tensorflow/core/framework/model.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -615,11 +616,14 @@ std::vector DataServiceWorkerImpl::GetActiveTasks() const mutex_lock task_lock(task->mu); task_initialized = task->initialized; } - if (task_initialized && task->task_runner != nullptr) { - std::optional processing_time_nsec = - task->task_runner->GetProcessingTimeNsec(); - active_task.set_processing_time_nsec( - processing_time_nsec ? processing_time_nsec.value() : 0.0); + + if (task_initialized && task->task_runner != nullptr && + task->task_runner->model() != nullptr) { + std::shared_ptr model = task->task_runner->model(); + double processing_time_nsec = model->ComputeSnapshotProcessingTimeNsec(); + if (processing_time_nsec > 0) { + active_task.set_processing_time_nsec(processing_time_nsec); + } } active_tasks.push_back(std::move(active_task)); } diff --git a/tensorflow/core/data/split_utils.cc b/tensorflow/core/data/split_utils.cc index 350c79b0897a72..da75c168126fb0 100644 --- a/tensorflow/core/data/split_utils.cc +++ b/tensorflow/core/data/split_utils.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/mutex.h" +#include "tsl/platform/types.h" namespace tensorflow { namespace data { @@ -78,6 +79,15 @@ absl::Status IndexSplitProvider::Restore( return reader->ReadScalar(full_name(kIndex), &i_); } +int64_t IndexSplitProvider::Cardinality() const { + // RandomDataset uses kint64max to simulate infinite splits. + // See RandomDatasetOp::Dataset::MakeSplitProviders. + if (n_ == tsl::kint64max) { + return kInfiniteCardinality; + } + return n_; +} + ShardingSplitProvider::ShardingSplitProvider( int64_t num_shards, int64_t shard_index, std::shared_ptr split_provider) diff --git a/tensorflow/core/data/split_utils.h b/tensorflow/core/data/split_utils.h index 0801d9afd546e7..a0fdef8d2d2213 100644 --- a/tensorflow/core/data/split_utils.h +++ b/tensorflow/core/data/split_utils.h @@ -42,6 +42,7 @@ class IndexSplitProvider : public SplitProvider { IteratorStateWriter* writer) override; absl::Status Restore(std::function full_name, IteratorStateReader* reader) override; + int64_t Cardinality() const override; private: tsl::mutex mu_; diff --git a/tensorflow/core/data/standalone.cc b/tensorflow/core/data/standalone.cc index 1790b29730249c..04a425170b27be 100644 --- a/tensorflow/core/data/standalone.cc +++ b/tensorflow/core/data/standalone.cc @@ -40,6 +40,7 @@ limitations under the License. 
#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_handle_cache.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/model.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -116,16 +117,7 @@ Status Iterator::Restore(const std::vector& saved_iterator) { return iterator_->Restore(ctx_.get(), &reader); } -std::optional Iterator::GetProcessingTimeNsec() const { - if (ctx_->model() == nullptr) return std::nullopt; - - double processing_time_nsec = - ctx_->model()->ComputeSnapshotProcessingTimeNsec(); - if (processing_time_nsec > 0) - return processing_time_nsec; - else - return std::nullopt; -} +std::shared_ptr Iterator::model() const { return ctx_->model(); } Status Dataset::FromGraph(Params params, const GraphDef& graph_def, std::unique_ptr* result) { diff --git a/tensorflow/core/data/standalone.h b/tensorflow/core/data/standalone.h index 0854869fb67a30..5de0d81b274b30 100644 --- a/tensorflow/core/data/standalone.h +++ b/tensorflow/core/data/standalone.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_handle_cache.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/model.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -89,12 +90,9 @@ class Iterator { // Restores the iterator from a checkpoint. `saved_iterator` is the serialized // iterator saved by calling `Save()`. Status Restore(const std::vector& saved_iterator); - // Returns the time it takes the pipeline associated with this iterator - // to process an element. - // Returns std::nullopt if there is not currently enough information to - // determine the processing time, e.g. because not enough data has been - // produced yet from the iterator. - std::optional GetProcessingTimeNsec() const; + + // Returns the dataset model for performance analysis. + std::shared_ptr model() const; private: friend class Dataset; diff --git a/tensorflow/core/data/standalone_test.cc b/tensorflow/core/data/standalone_test.cc index 964ec803a32df0..54f438b1cc2308 100644 --- a/tensorflow/core/data/standalone_test.cc +++ b/tensorflow/core/data/standalone_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace data { @@ -523,21 +524,15 @@ TEST(Scalar, Standalone) { GraphDef graph_def; protobuf::TextFormat::ParseFromString(test_case.graph_string, &graph_def); std::unique_ptr dataset; - auto s = Dataset::FromGraph({}, graph_def, &dataset); - TF_EXPECT_OK(s); + TF_EXPECT_OK(Dataset::FromGraph({}, graph_def, &dataset)); std::unique_ptr iterator; - s = dataset->MakeIterator(&iterator); - TF_EXPECT_OK(s); - - std::optional processing_time_nsec = - iterator->GetProcessingTimeNsec(); - EXPECT_EQ(processing_time_nsec, std::nullopt); + TF_EXPECT_OK(dataset->MakeIterator(&iterator)); + EXPECT_DOUBLE_EQ(iterator->model()->ComputeSnapshotProcessingTimeNsec(), 0); bool end_of_input = false; for (int num_outputs = 0; !end_of_input; ++num_outputs) { std::vector outputs; - s = iterator->GetNext(&outputs, &end_of_input); - TF_EXPECT_OK(s); + TF_EXPECT_OK(iterator->GetNext(&outputs, &end_of_input)); if (!end_of_input) { EXPECT_EQ(outputs[0].scalar()(), test_case.expected_outputs[num_outputs]); @@ -548,9 +543,7 @@ TEST(Scalar, Standalone) { // Wait for an optimization round in the pipeline model. absl::SleepFor(absl::Seconds(1)); - processing_time_nsec = iterator->GetProcessingTimeNsec(); - EXPECT_NE(processing_time_nsec, std::nullopt); - EXPECT_LT(0, processing_time_nsec.value()); + EXPECT_GT(iterator->model()->ComputeSnapshotProcessingTimeNsec(), 0); } } @@ -562,10 +555,7 @@ TEST(NoAutotune, Standalone) { TF_EXPECT_OK(Dataset::FromGraph({}, graph_def, &dataset)); std::unique_ptr iterator; TF_EXPECT_OK(dataset->MakeIterator(&iterator)); - - std::optional processing_time_nsec = - iterator->GetProcessingTimeNsec(); - EXPECT_EQ(processing_time_nsec, std::nullopt); + EXPECT_EQ(iterator->model(), nullptr); bool end_of_input = false; for (int num_outputs = 0; !end_of_input; ++num_outputs) { @@ -580,10 +570,8 @@ TEST(NoAutotune, Standalone) { // Wait for an optimization round in the pipeline model. absl::SleepFor(absl::Seconds(1)); - processing_time_nsec = iterator->GetProcessingTimeNsec(); - // Model should not be created and `GetProcessingTimeNsec()` should return - // `nullopt`. - EXPECT_EQ(processing_time_nsec, std::nullopt); + // Model should not be created. + EXPECT_EQ(iterator->model(), nullptr); } } // namespace diff --git a/tensorflow/core/data/utils.cc b/tensorflow/core/data/utils.cc index 7d346dcbecd319..73f8a75587e97e 100644 --- a/tensorflow/core/data/utils.cc +++ b/tensorflow/core/data/utils.cc @@ -14,11 +14,14 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/data/utils.h" +#include #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "tensorflow/core/data/file_logger_client_interface.h" +#include "tensorflow/core/data/file_logger_client_no_op.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/protobuf/data_service.pb.h" @@ -44,5 +47,9 @@ absl::StatusOr DisableCompressionAtRuntime( return false; } +std::unique_ptr CreateFileLoggerClient() { + return std::make_unique(); +} + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/utils.h b/tensorflow/core/data/utils.h index d80431c9680ab8..00fe795f9c7f3e 100644 --- a/tensorflow/core/data/utils.h +++ b/tensorflow/core/data/utils.h @@ -15,11 +15,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DATA_UTILS_H_ #define TENSORFLOW_CORE_DATA_UTILS_H_ +#include #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "tensorflow/core/data/file_logger_client_interface.h" #include "tensorflow/core/protobuf/data_service.pb.h" namespace tensorflow { @@ -48,6 +50,9 @@ std::string LocalityOptimizedPath(const std::string& path); absl::StatusOr DisableCompressionAtRuntime( const std::string& data_transfer_protocol, DeploymentMode deployment_mode); +// Creates a instance of a class derived from FileLoggerClientInterface. +std::unique_ptr CreateFileLoggerClient(); + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/utils_test.cc b/tensorflow/core/data/utils_test.cc new file mode 100644 index 00000000000000..1f908acb278b59 --- /dev/null +++ b/tensorflow/core/data/utils_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/data/utils.h" + +#include + +#include +#include "tensorflow/core/data/file_logger_client_interface.h" +#include "tensorflow/core/data/file_logger_client_no_op.h" + +namespace tensorflow::data { +namespace { + +TEST(Util, CreateFileLoggerClient) { + std::unique_ptr client = CreateFileLoggerClient(); + EXPECT_NE(dynamic_cast(client.get()), nullptr); +} + +TEST(Util, DefaultDataTransferProtocol) { + EXPECT_EQ(DefaultDataTransferProtocol(), "grpc"); +} + +TEST(TranslateFileName, NoOp) { + constexpr char file[] = "/home/tfdata/file1"; + EXPECT_EQ(TranslateFileName(file), file); +} + +TEST(TranslateFileName, EmptyPath) { + constexpr char file[] = ""; + EXPECT_EQ(TranslateFileName(file), file); +} + +TEST(TranslateFileName, TfDataPath) { + constexpr char file[] = "tfdata/file1"; + EXPECT_EQ(TranslateFileName(file), file); +} + +TEST(LocalityOptimizedPath, NoOp) { + constexpr char file[] = "/home/tfdata/file1"; + EXPECT_EQ(LocalityOptimizedPath(file), file); +} + +TEST(LocalityOptimizedPath, EmptyPath) { + constexpr char file[] = ""; + EXPECT_EQ(LocalityOptimizedPath(file), file); +} + +TEST(LocalityOptimizedPath, TfDataPath) { + constexpr char file[] = "tfdata/file1"; + EXPECT_EQ(LocalityOptimizedPath(file), file); +} + +} // namespace +} // namespace tensorflow::data diff --git a/tensorflow/core/distributed_runtime/README.md b/tensorflow/core/distributed_runtime/README.md index d22cd2a45bc68e..b4220beeae5f5f 100644 --- a/tensorflow/core/distributed_runtime/README.md +++ b/tensorflow/core/distributed_runtime/README.md @@ -4,5 +4,7 @@ This directory contains the initial open-source implementation of the distributed TensorFlow runtime, using [gRPC](http://grpc.io) for inter-process communication. -To learn how to use the distributed runtime to create a TensorFlow cluster, -see the [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed) How-To. +To learn how to use the distributed runtime to create a TensorFlow cluster, see +the +[Distributed TensorFlow](https://www.tensorflow.org/guide/distributed_training) +How-To. diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc index c4a7af7c6a26fd..0261268a589e2c 100644 --- a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc +++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
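The utils tests above pin down the behavior of the new factory: in this build it hands back the no-op implementation. A small sketch of how a caller might obtain the logger through the interface only; the wrapper name is hypothetical, and no interface methods are invoked because they are not part of this diff.

```cpp
#include <memory>

#include "tensorflow/core/data/file_logger_client_interface.h"
#include "tensorflow/core/data/utils.h"

namespace tensorflow {
namespace data {

// Hypothetical caller: depends only on FileLoggerClientInterface. In the
// open-source build the factory returns FileLoggerClientNoOp, so any calls
// made through the interface are cheap no-ops.
std::unique_ptr<FileLoggerClientInterface> MakeIteratorFileLogger() {
  return CreateFileLoggerClient();
}

}  // namespace data
}  // namespace tensorflow
```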
#include #include #include +#include #include #include @@ -78,36 +79,36 @@ class MockCoordinationServiceAgent : public CoordinationServiceAgent { MOCK_METHOD(Status, ReportError, (const Status& error), (override)); MOCK_METHOD(Status, Shutdown, (), (override)); MOCK_METHOD(Status, Reset, (), (override)); - MOCK_METHOD(StatusOr, GetKeyValue, (const std::string& key), + MOCK_METHOD(StatusOr, GetKeyValue, (std::string_view key), (override)); MOCK_METHOD(StatusOr, GetKeyValue, (const char* key, int64_t key_size), (override)); MOCK_METHOD(StatusOr, GetKeyValue, - (const std::string& key, absl::Duration timeout), (override)); + (std::string_view key, absl::Duration timeout), (override)); MOCK_METHOD(std::shared_ptr, GetKeyValueAsync, - (const std::string& key, StatusOrValueCallback done), (override)); - MOCK_METHOD(StatusOr, TryGetKeyValue, (const std::string& key), + (std::string_view key, StatusOrValueCallback done), (override)); + MOCK_METHOD(StatusOr, TryGetKeyValue, (std::string_view key), (override)); MOCK_METHOD(StatusOr>, GetKeyValueDir, - (const std::string& key), (override)); + (std::string_view key), (override)); MOCK_METHOD(void, GetKeyValueDirAsync, - (const std::string& key, StatusOrValueDirCallback done), + (std::string_view key, StatusOrValueDirCallback done), (override)); MOCK_METHOD(Status, InsertKeyValue, - (const std::string& key, const std::string& value), (override)); + (std::string_view key, std::string_view value), (override)); MOCK_METHOD(Status, InsertKeyValue, (const char* key, int64_t key_size, const char* value, int64_t value_size), (override)); - MOCK_METHOD(Status, DeleteKeyValue, (const std::string& key), (override)); + MOCK_METHOD(Status, DeleteKeyValue, (std::string_view key), (override)); MOCK_METHOD(Status, DeleteKeyValue, (const char* key, int64_t key_size), (override)); MOCK_METHOD(Status, UpdateKeyValue, - (const std::string& key, const std::string& value), (override)); + (std::string_view key, std::string_view value), (override)); MOCK_METHOD(Status, StartWatchKey, - (const std::string& key, ChangedKeyValuesCallback on_change), + (std::string_view key, ChangedKeyValuesCallback on_change), (override)); - MOCK_METHOD(Status, StopWatchKey, (const std::string& key), (override)); + MOCK_METHOD(Status, StopWatchKey, (std::string_view key), (override)); MOCK_METHOD(void, WaitAtBarrierAsync, (const std::string& barrier_id, absl::Duration timeout, const std::vector& tasks, StatusCallback done), @@ -117,7 +118,7 @@ class MockCoordinationServiceAgent : public CoordinationServiceAgent { MOCK_METHOD(StatusOr, GetEnv, (), (override)); MOCK_METHOD(void, SetError, (const Status& error), (override)); MOCK_METHOD(Status, ActivateWatch, - (const std::string& key, + (std::string_view key, (const std::map&)), (override)); }; diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD index 46e86a42be6734..a3a1c8ae937db7 100644 --- a/tensorflow/core/distributed_runtime/eager/BUILD +++ b/tensorflow/core/distributed_runtime/eager/BUILD @@ -1,9 +1,9 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_grpc_cc_dependencies") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", ) +load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_grpc_cc_dependencies") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -102,6 +102,7 @@ cc_library( 
":remote_tensor_handle", "//tensorflow/c/eager:immediate_execution_distributed_manager", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index f2fd43ca853156..f6f3bf1ee1668c 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/nccl/collective_communicator.h" #include "tensorflow/core/platform/errors.h" @@ -48,6 +49,7 @@ limitations under the License. #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringprintf.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tsl/distributed_runtime/preemption/preemption_notifier.h" #include "tsl/protobuf/coordination_config.pb.h" @@ -55,13 +57,14 @@ namespace tensorflow { namespace eager { namespace { -Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name, +Status GetNumRetvals(FunctionLibraryDefinition* func_lib_def, + const string& op_name, const google::protobuf::Map& attrs, int* num_retvals) { const tensorflow::OpRegistrationData* op_reg_data = nullptr; auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data); if (absl::IsNotFound(status)) { - status = context->FindFunctionOpData(op_name, &op_reg_data); + status = func_lib_def->LookUp(op_name, &op_reg_data); } TF_RETURN_IF_ERROR(status); @@ -100,14 +103,27 @@ Status GetEagerOperationAndNumRetvals(const Operation& operation, const char* name = operation.name().c_str(); // Shorthand std::optional remote_func_params = std::nullopt; + FunctionLibraryDefinition* func_lib_def; if (operation.is_function()) { if (operation.is_component_function()) { + func_lib_def = + eager_context->GetComponentFunctionFunctionLibraryDefinition( + operation.name()); + if (func_lib_def == nullptr) { + return absl::InternalError( + absl::StrCat("Could not find function library for registered " + "component function: ", + operation.name())); + } remote_func_params = {operation.id(), /*is_component_function=*/true, - operation.func_step_id()}; + operation.func_step_id(), func_lib_def}; } else { + func_lib_def = eager_context->FuncLibDef(); remote_func_params = {operation.id(), /*is_component_function=*/false, - std::nullopt}; + std::nullopt, /*func_lib_def=*/nullptr}; } + } else { + func_lib_def = eager_context->FuncLibDef(); } TF_RETURN_IF_ERROR(eager_op->Reset(name, operation.device().c_str(), false, eager_executor, remote_func_params)); @@ -143,7 +159,7 @@ Status GetEagerOperationAndNumRetvals(const Operation& operation, } // TODO(nareshmodi): Consider caching this. 
- return GetNumRetvals(eager_context, operation.name(), operation.attrs(), + return GetNumRetvals(func_lib_def, operation.name(), operation.attrs(), num_retvals); } @@ -770,9 +786,14 @@ Status EagerServiceImpl::RegisterFunction( const RegisterFunctionOp& register_function, EagerContext* eager_context) { // If the function is a component of a multi-device function, we only need to // register it locally. - return eager_context->AddFunctionDef( - register_function.function_def(), register_function.library(), - register_function.is_component_function()); + if (register_function.is_component_function()) { + return eager_context->AddComponentFunction(register_function.function_def(), + register_function.library()); + } else { + return eager_context->AddFunctionDef(register_function.function_def(), + register_function.library(), + /*add_to_local_only=*/false); + } } Status EagerServiceImpl::RemoveFunction(const RemoveFunctionOp& remove_function, diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 79f8ccb21d934a..2ab6631de71d9b 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -309,6 +309,46 @@ tensorflow::FunctionDef MatMulFunction() { return def; } +tensorflow::FunctionDef MatMulTransposeFunction() { + tensorflow::FunctionDef def; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + " signature {" + " name: 'MatMulFunction'" + " input_arg {" + " name: 'a'" + " type: DT_FLOAT" + " }" + " output_arg {" + " name: 'm'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'matmul'" + " op: 'MatMul'" + " input: 'a'" + " input: 'a'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " attr {" + " key: 'transpose_a'" + " value {" + " b: true" + " }" + " }" + " }" + " ret {" + " key: 'm'" + " value: 'matmul:product'" + " }", + &def)); + return def; +} + tensorflow::FunctionDef MatMulNestedFunction() { tensorflow::FunctionDef def; CHECK(tensorflow::protobuf::TextFormat::ParseFromString( @@ -710,15 +750,178 @@ TEST_F(EagerServiceImplFunctionTest, FunctionCancellationTest) { TEST_F(EagerServiceImplFunctionTest, ComponentFunctionTest) { RegisterFunctionOp register_op; *register_op.mutable_function_def() = MatMulFunction(); + register_op.set_is_component_function(true); TestComponentFunction(register_op, "MatMulFunction", false); } TEST_F(EagerServiceImplFunctionTest, ComponentFunctionCancellationTest) { RegisterFunctionOp register_op; *register_op.mutable_function_def() = SingleRecvNodeFunction(); + register_op.set_is_component_function(true); TestComponentFunction(register_op, "SingleRecvNodeFunction", true); } +TEST_F(EagerServiceImplFunctionTest, ComponentNestedFunctionTest) { + RegisterFunctionOp register_op; + *register_op.mutable_function_def() = MatMulNestedFunction(); + *register_op.mutable_library()->add_function() = MatMulFunction(); + register_op.set_is_component_function(true); + TestComponentFunction(register_op, "MatMulNestedFunction", false); +} + +TEST_F(EagerServiceImplFunctionTest, ComponentNestedFunctionWithNameClashTest) { + TestEagerServiceImpl eager_service_impl(&worker_env_); + uint64 context_id = random::New64(); + + // Create context. 
+ CreateContextRequest request; + request.mutable_server_def()->set_job_name("localhost"); + request.mutable_server_def()->set_task_index(0); + request.set_context_id(context_id); + CreateContextResponse response; + TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); + + // Register first function. + { + EnqueueRequest enqueue_request; + enqueue_request.set_context_id(context_id); + RegisterFunctionOp* register_op = + enqueue_request.add_queue()->mutable_register_function(); + *register_op->mutable_function_def() = MatMulNestedFunction(); + *register_op->mutable_library()->add_function() = MatMulFunction(); + register_op->set_is_component_function(true); + EnqueueResponse enqueue_response; + TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request, + &enqueue_response)); + } + + // Register second function. + // In the second registration, the library contains a function named + // "MatMulFunction" but a different body. + { + EnqueueRequest enqueue_request; + enqueue_request.set_context_id(context_id); + RegisterFunctionOp* register_op = + enqueue_request.add_queue()->mutable_register_function(); + + *register_op->mutable_function_def() = MatMulNestedFunction(); + register_op->mutable_function_def()->mutable_signature()->set_name( + "MatMulNestedTransposeFunction"); + *register_op->mutable_library()->add_function() = MatMulTransposeFunction(); + register_op->set_is_component_function(true); + EnqueueResponse enqueue_response; + TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request, + &enqueue_response)); + } + + // First run an op to generate input for the functions. + EnqueueRequest remote_enqueue_request; + remote_enqueue_request.set_context_id(context_id); + EnqueueResponse remote_enqueue_response; + + std::unordered_map const_attrs; + AttrValue val; + val.set_type(tensorflow::DataType::DT_FLOAT); + const_attrs.insert({"dtype", val}); + val.Clear(); + SetTensorProto(val.mutable_tensor()); + const_attrs.insert({"value", val}); + AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, + "/job:localhost/replica:0/task:0/device:CPU:0", + &remote_enqueue_request); + TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request, + &remote_enqueue_response)); + + { + // Run first function with input from the previous op. + RunComponentFunctionRequest run_comp_func_request; + run_comp_func_request.set_context_id(context_id); + RunComponentFunctionResponse run_comp_func_response; + const int output_num = 5; + AddOperationToRunComponentFunctionRequest( + 2, "MatMulNestedFunction", {std::make_pair(1, 0)}, + std::unordered_map(), + "/job:localhost/replica:0/task:0/device:CPU:0", output_num, + &run_comp_func_request); + + CallOptions call_opts; + Notification n; + Status status; + eager_service_impl.RunComponentFunction(&call_opts, &run_comp_func_request, + &run_comp_func_response, + [&status, &n](const Status& s) { + status.Update(s); + n.Notify(); + }); + n.WaitForNotification(); + + TF_ASSERT_OK(status); + // Retrieve the output. + const tensorflow::Tensor* t = nullptr; + tensorflow::TensorHandle* tensor_handle; + TF_ASSERT_OK(eager_service_impl.GetTensorHandle( + context_id, RemoteTensorHandleInternal(2, output_num), &tensor_handle)); + TF_ASSERT_OK(tensor_handle->Tensor(&t)); + + auto actual = t->flat(); + EXPECT_EQ(4, actual.size()); + + EXPECT_EQ(7, actual(0)); + EXPECT_EQ(10, actual(1)); + EXPECT_EQ(15, actual(2)); + EXPECT_EQ(22, actual(3)); + } + + { + // Run second function with input from the constant op. 
The result should + // be different, because we are using the transposed implementation of + // MatMulFunction in the second function's library. + RunComponentFunctionRequest run_comp_func_request; + run_comp_func_request.set_context_id(context_id); + RunComponentFunctionResponse run_comp_func_response; + const int output_num = 5; + AddOperationToRunComponentFunctionRequest( + 3, "MatMulNestedTransposeFunction", {std::make_pair(1, 0)}, + std::unordered_map(), + "/job:localhost/replica:0/task:0/device:CPU:0", output_num, + &run_comp_func_request); + + CallOptions call_opts; + Notification n; + Status status; + eager_service_impl.RunComponentFunction(&call_opts, &run_comp_func_request, + &run_comp_func_response, + [&status, &n](const Status& s) { + status.Update(s); + n.Notify(); + }); + n.WaitForNotification(); + + TF_ASSERT_OK(status); + // Retrieve the output. + const tensorflow::Tensor* t = nullptr; + tensorflow::TensorHandle* tensor_handle; + TF_ASSERT_OK(eager_service_impl.GetTensorHandle( + context_id, RemoteTensorHandleInternal(3, output_num), &tensor_handle)); + TF_ASSERT_OK(tensor_handle->Tensor(&t)); + + auto actual = t->flat(); + EXPECT_EQ(4, actual.size()); + + EXPECT_EQ(10, actual(0)); + EXPECT_EQ(14, actual(1)); + EXPECT_EQ(14, actual(2)); + EXPECT_EQ(20, actual(3)); + } + + CloseContextRequest close_context_request; + close_context_request.set_context_id(context_id); + close_context_request.set_context_view_id(0); + CloseContextResponse close_context_response; + TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request, + &close_context_response)); +} + class FunctionWithRemoteInputsTest : public EagerServiceImplTest { public: FunctionWithRemoteInputsTest() @@ -987,7 +1190,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) { // Instantiate MatMulFunction on remote_device. const NodeDef node_def = MatMulFunctionNodeDef(); - TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr)); + TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr, std::nullopt)); // Run MatMulFunction on remote_device. gtl::InlinedVector input_tensors = {TensorValue()}; @@ -1042,7 +1245,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) { // Instantiate MatMulFunction on remote_device. const NodeDef node_def = MatMulFunctionNodeDef(); - TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr)); + TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr, std::nullopt)); // Run MatMulFunction on remote_device. gtl::InlinedVector input_tensors = {TensorValue()}; diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc index 4cefc9433c2556..bd5bc39622b9d6 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc @@ -60,7 +60,8 @@ Status CreateUncachedKernelAndDeviceOp( const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef(); return kernel->get()->Init(ctx.LogDevicePlacement(), ndef, - /*graph_collector=*/nullptr); + /*graph_collector=*/nullptr, + /*eager_func_params=*/std::nullopt); } // This gets a unique wire ID. 
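A quick worked check of the values asserted in the test above, assuming (consistently with those assertions) that the Const op feeds the 2x2 matrix A to both component functions; the first run uses MatMul(a, a), the second uses the transposed variant from the clashing library:

```latex
A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}, \qquad
AA = \begin{pmatrix} 7 & 10 \\ 15 & 22 \end{pmatrix}, \qquad
A^{\top}A = \begin{pmatrix} 1 & 3 \\ 2 & 4 \end{pmatrix}
            \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}
          = \begin{pmatrix} 10 & 14 \\ 14 & 20 \end{pmatrix}.
```

So the non-transposed run yields 7, 10, 15, 22 and the transposed run yields 10, 14, 14, 20, which are exactly the EXPECT_EQ values in the two blocks.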
We add a random identifier so that if the diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h index 6aabf3ce209d7d..148e58a5b008c5 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h @@ -42,7 +42,8 @@ class RemoteExecuteNode : public AsyncRemoteExecuteNode { std::unique_ptr request, Device* device, uint64 context_view_id, EagerClient* eager_client, CancellationManager* cancellation_manager, - const NodeDef& ndef, FunctionLibraryDefinition* lib_def, + const NodeDef& ndef, + const FunctionLibraryDefinition* lib_def, const gtl::InlinedVector& inputs, absl::Span retvals) : AsyncRemoteExecuteNode(), diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 27c92c157eb7ae..01a6cfa158390a 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -712,6 +712,7 @@ cc_library( hdrs = ["resource_base.h"], visibility = default_visibility + [ "//learning/brain/google/data/core/kernels:__pkg__", + "//learning/deepmind/tensorflow/queues:__pkg__", "//learning/deepmind/tensorflow/sstable:__pkg__", ], deps = [ diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index a174234a86e483..5fc2119aaace5f 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -400,6 +400,12 @@ class SplitProvider { // Restores the state of this split provider. virtual Status Restore(std::function full_name, IteratorStateReader* reader) = 0; + // Returns the number of splits: + // - If there are a finite number of splits, returns a non-negative count. + // - If there are an infinite number of splits, returns kInfiniteCardinality. + // - If the number of splits is unknown or can't be efficiently computed, + // returns kUnknownCardinality. + virtual int64_t Cardinality() const { return kUnknownCardinality; } }; // Returns the runner threadpool size from an OpKernelContext. diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index d806545d573bda..080e8b5c98f719 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -137,6 +137,25 @@ FunctionDef XTimesTwo() { }); } +FunctionDef XTimesTwoWithControlInput() { + const Tensor kTwo = test::AsScalar(2); + return FDH::Define( + // Name + "XTimesTwo", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}}, + {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, + {{"y"}, "Mul", {"scale"}, {{"T", "$T"}}, /*dep=*/{"x"}}, + }); +} + FunctionDef TwoDeviceMult() { const Tensor kTwo = test::AsScalar(2); const Tensor kThree = test::AsScalar(3); diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h index 559e0d6d67d241..b0ce4fafde58af 100644 --- a/tensorflow/core/framework/function_testlib.h +++ b/tensorflow/core/framework/function_testlib.h @@ -70,6 +70,8 @@ GraphDef GDef(gtl::ArraySlice nodes, // x: T -> x * 2. FunctionDef XTimesTwo(); +// Same as `XTimesTwo` above, but with the `x` input as a control dependency. +FunctionDef XTimesTwoWithControlInput(); // x: T -> cpu(x * 2) + cpu(x * 3). 
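Since the new `SplitProvider::Cardinality()` hook above is documented but not exercised in this diff, here is a small illustrative consumer; `LogSplitCount` is a hypothetical name, while `SplitProvider`, `kInfiniteCardinality`, and `kUnknownCardinality` are the type and sentinels the new comment refers to.

```cpp
#include <cstdint>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace data {

// Illustrative consumer of the Cardinality() contract described above.
void LogSplitCount(const SplitProvider& provider) {
  const int64_t n = provider.Cardinality();
  if (n == kInfiniteCardinality) {
    LOG(INFO) << "Split provider yields an unbounded number of splits.";
  } else if (n == kUnknownCardinality) {
    LOG(INFO) << "Split count unknown (the base-class default).";
  } else {
    LOG(INFO) << "Split provider yields " << n << " splits.";
  }
}

}  // namespace data
}  // namespace tensorflow
```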
FunctionDef TwoDeviceTimesFive(); diff --git a/tensorflow/core/framework/metrics.cc b/tensorflow/core/framework/metrics.cc index 104af6dcfb1936..e6f94a8e444b4a 100644 --- a/tensorflow/core/framework/metrics.cc +++ b/tensorflow/core/framework/metrics.cc @@ -266,6 +266,12 @@ auto* tf_data_model_gauge = tsl::monitoring::Gauge, 1>::New( "/tensorflow/data/model", "tf.data autotuning model proto.", "id"); +auto* tf_data_pipeline_processing_time = tsl::monitoring::Gauge::New( + "/tensorflow/data/pipeline_processing_time", + "The total processing time of the slowest stage in the input pipeline " + "in microseconds", + "id"); + auto* tf_data_auto_shard = tsl::monitoring::Gauge::New( "/tensorflow/data/autoshard", "tf.data autoshard statistics.", "id", "name"); @@ -467,6 +473,11 @@ tsl::monitoring::GaugeCell>* GetTFDataModelGauge( return tf_data_model_gauge->GetCell(id); } +tsl::monitoring::GaugeCell* GetTFDataPipelineProcessingTimeGauge( + const string& id) { + return tf_data_pipeline_processing_time->GetCell(id); +} + void RecordTFDataBytesFetched(int64_t num_bytes) { tf_data_bytes_fetched_counter->GetCell()->IncrementBy(num_bytes); } @@ -822,6 +833,11 @@ void RecordUnusedOutput(const string& op_name) { graph_unused_outputs->GetCell(op_name)->IncrementBy(1); } +void RecordPipelineProcessingTime(const string& id, + double pipeline_processing_time_usec) { + GetTFDataPipelineProcessingTimeGauge(id)->Set(pipeline_processing_time_usec); +} + void IncrementTestCounter(const string& name, const string& label) { test_counters->GetCell(name, label)->IncrementBy(1); } diff --git a/tensorflow/core/framework/metrics.h b/tensorflow/core/framework/metrics.h index 5b15ee16b0a165..bcc5808cf7f8e8 100644 --- a/tensorflow/core/framework/metrics.h +++ b/tensorflow/core/framework/metrics.h @@ -243,6 +243,10 @@ void UpdateGraphPendingQueueLength(uint64 len); // Records that one output of an op of type `op_name` was unused. void RecordUnusedOutput(const string& op_name); +// Records the pipeline processing time in microseconds +void RecordPipelineProcessingTime(const string& id, + double pipeline_processing_time_usec); + // Updates the metrics stored about time spent building graphs. // // By "GraphBuild", we refer to building a client graph, which is a sub-graph of diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index 7cfbd96294136d..17bf8d731236b8 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -18,12 +18,14 @@ limitations under the License. 
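A minimal sketch of how the new metric is populated; `RecordPipelineProcessingTime` and the metric path come from the hunk above, while the wrapper name and id are illustrative. As the model changes later in this diff show, the id is the Model's address rendered as a string, and the gauge is reset to 0 when the model is destroyed.

```cpp
#include <string>

#include "tensorflow/core/framework/metrics.h"

namespace tensorflow {

// Illustrative wrapper: publishes the slowest-stage processing time for one
// input pipeline to /tensorflow/data/pipeline_processing_time, keyed by the
// pipeline/model id.
void PublishSlowestStageTime(const std::string& model_id,
                             double slowest_stage_usec) {
  metrics::RecordPipelineProcessingTime(model_id, slowest_stage_usec);
}

}  // namespace tensorflow
```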
#include #include #include +#include #include #include #include #include "absl/time/clock.h" #include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/model.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -549,7 +551,7 @@ class InterleaveMany : public Node { self_processing_time + inputs_processing_time; } - Status ToProto(ModelProto::Node* node_proto) const { + Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::INTERLEAVE_MANY); return OkStatus(); @@ -761,7 +763,7 @@ class AsyncInterleaveMany : public Node { self_processing_time + inputs_processing_time; } - double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) { + double MaximumBufferedBytes() const override TF_SHARED_LOCKS_REQUIRED(mu_) { auto* parameter = gtl::FindOrNull(parameters_, kMaxBufferedElements); if (parameter == nullptr) { parameter = gtl::FindOrNull(parameters_, kParallelism); @@ -772,7 +774,7 @@ class AsyncInterleaveMany : public Node { return (*parameter)->value * AverageBufferedElementSizeLocked(); } - Status ToProto(ModelProto::Node* node_proto) const { + Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::ASYNC_INTERLEAVE_MANY); return OkStatus(); @@ -864,7 +866,7 @@ class KnownRatio : public Node { self_processing_time + inputs_processing_time; } - Status ToProto(ModelProto::Node* node_proto) const { + Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::KNOWN_RATIO); node_proto->set_ratio(ratio_); @@ -1243,7 +1245,7 @@ class UnknownRatio : public Node { self_processing_time + inputs_processing_time; } - Status ToProto(ModelProto::Node* node_proto) const { + Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::UNKNOWN_RATIO); return OkStatus(); @@ -1297,7 +1299,7 @@ class Unknown : public Node { TotalProcessingTimeForInputs(*total_processing_times); } - Status ToProto(ModelProto::Node* node_proto) const { + Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::UNKNOWN); return OkStatus(); @@ -1326,7 +1328,7 @@ class AsyncKnownRatio : public AsyncRatio { parameters); } - Status ToProto(ModelProto::Node* node_proto) const { + Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::ASYNC_KNOWN_RATIO); node_proto->set_ratio(Ratio()); @@ -1371,7 +1373,7 @@ class AsyncUnknownRatio : public AsyncRatio { Args{id_, name_, std::move(output)}, parameters); } - Status ToProto(ModelProto::Node* node_proto) const { + Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::ASYNC_UNKNOWN_RATIO); return OkStatus(); @@ -2205,8 +2207,9 @@ Status Node::FromProto(ModelProto::Node node_proto, Model::Model() : optimization_period_ms_(kOptimizationPeriodMinMs), safe_to_collect_metrics_(std::make_shared(true)) { - model_gauge_cell_ = metrics::GetTFDataModelGauge( - strings::StrCat(reinterpret_cast(this))); + model_id_ = 
strings::StrCat(reinterpret_cast(this)); + model_gauge_cell_ = metrics::GetTFDataModelGauge(model_id_); + // Capture `safe_to_collect_metrics_` by value to avoid use-after-free issues // when the callback is invoked after the model has been destroyed. model_gauge_cell_->Set( @@ -2237,6 +2240,8 @@ Model::Model() Model::~Model() { mutex_lock l(safe_to_collect_metrics_->mu); safe_to_collect_metrics_->val = false; + // Reset the pipeline processing time to 0 + metrics::RecordPipelineProcessingTime(model_id_, 0); } void Model::AddNode(Node::Factory factory, const string& name, @@ -2356,6 +2361,38 @@ void Model::Optimize(AutotuneAlgorithm algorithm, mutex_lock l(mu_); snapshot_ = snapshot; optimization_params_ = optimization_params; + + if (snapshot_) { + double pipeline_processing_usec = 0; + ModelTiming model_timing(snapshot_); + auto bfs_stage_roots = model_timing.GetStageRoots(); + for (const auto& root : bfs_stage_roots) { + auto* root_timing = model_timing.GetTiming(root.get()); + if (root_timing == nullptr) { + constexpr int TEN_MINUTES = 60 * 10; + LOG_EVERY_N_SEC(ERROR, TEN_MINUTES) + << "Encounter an error when computing the pipeline processing " + "time for " + "/tensorflow/data/pipeline_processing_time"; + pipeline_processing_usec = 0; + break; + } + + double root_total_time_usec = root_timing->total_time_nsec * + root_timing->pipeline_ratio / + EnvTime::kMicrosToNanos; + + pipeline_processing_usec = + std::max(pipeline_processing_usec, root_total_time_usec); + } + // Only updates the pipeline processing time when it is greater than 0. + // If it is zero, we assume the pipeline processing time is the same + // as the previous one and do not update it. + if (pipeline_processing_usec > 0) { + metrics::RecordPipelineProcessingTime(model_id_, + pipeline_processing_usec); + } + } } } diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index a7b95dfe5cadd0..505fc2f8a5e3f2 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -1186,6 +1186,8 @@ class Model { std::shared_ptr snapshot_ TF_GUARDED_BY(mu_); // Stores the optimization parameters used by autotune. OptimizationParams optimization_params_ TF_GUARDED_BY(mu_); + // Stores the model id in the string format + std::string model_id_; }; // Class to compute timing information for a model. diff --git a/tensorflow/core/framework/op_requires.h b/tensorflow/core/framework/op_requires.h index c5fb7796ecf6e8..a009a11ea606e7 100644 --- a/tensorflow/core/framework/op_requires.h +++ b/tensorflow/core/framework/op_requires.h @@ -49,7 +49,7 @@ namespace tensorflow { #define OP_REQUIRES_OK(CTX, ...) \ do { \ - ::tensorflow::Status _s(__VA_ARGS__); \ + const ::tensorflow::Status& _s(__VA_ARGS__); \ if (!TF_PREDICT_TRUE(_s.ok())) { \ CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ diff --git a/tensorflow/core/framework/tensor_matcher.h b/tensorflow/core/framework/tensor_matcher.h index 094d66f81f72f3..e89cfc15cd1f2a 100644 --- a/tensorflow/core/framework/tensor_matcher.h +++ b/tensorflow/core/framework/tensor_matcher.h @@ -34,7 +34,7 @@ namespace test { // // Use this like: // -// EXPECT_EQ(lhs, TensorEq(rhs)); +// EXPECT_THAT(lhs, TensorEq(rhs)); // // All POD types and DT_STRING type tensors are supported. Note that this // utility requires Tensors to point to CPU memory. 
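In formula form, the value published by Model::Optimize() above on each round with a valid snapshot is the slowest stage root's total time, scaled by its pipeline ratio and converted from nanoseconds to microseconds (EnvTime::kMicrosToNanos = 1000):

```latex
\text{pipeline\_processing\_usec}
  \;=\; \max_{r \,\in\, \text{stage roots}}
        \frac{\text{total\_time\_nsec}(r)\,\cdot\,\text{pipeline\_ratio}(r)}{1000}
```

A value of zero (for example, when a stage root's timing cannot be computed) is treated as "no update", so the previously recorded gauge value is kept.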
diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h index c4e23a8d07ba5e..152e0538f81bfe 100644 --- a/tensorflow/core/framework/variant.h +++ b/tensorflow/core/framework/variant.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/strcat.h" diff --git a/tensorflow/core/function/capture/BUILD b/tensorflow/core/function/capture/BUILD index 9588fe68bdf6ee..31ecbe1a79c16f 100644 --- a/tensorflow/core/function/capture/BUILD +++ b/tensorflow/core/function/capture/BUILD @@ -34,9 +34,10 @@ py_strict_test( ], deps = [ ":free_vars_detect", - "//tensorflow/python/util:tf_decorator_py", - "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", + #internal proto upb dep + "//third_party/py/numpy", + "//tensorflow/python/util:tf_decorator_py", ], ) @@ -45,13 +46,14 @@ py_strict_test( srcs = ["by_ref_capture_test.py"], python_version = "PY3", deps = [ + "@absl_py//absl/testing:parameterized", + #internal proto upb dep "//tensorflow/python/compat:v2_compat", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:combinations", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:ops", "//tensorflow/python/platform:client_testlib", - "@absl_py//absl/testing:parameterized", ], ) @@ -76,9 +78,10 @@ py_strict_test( python_version = "PY3", deps = [ ":capture_container", + "@absl_py//absl/testing:parameterized", + #internal proto upb dep "//tensorflow/core/function/trace_type", "//tensorflow/python/platform:client_testlib", - "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/core/function/integration_test/BUILD b/tensorflow/core/function/integration_test/BUILD index 5c6f6d189d9e25..e78adb5486088d 100644 --- a/tensorflow/core/function/integration_test/BUILD +++ b/tensorflow/core/function/integration_test/BUILD @@ -12,8 +12,9 @@ py_strict_test( srcs = ["side_inputs_test.py"], python_version = "PY3", deps = [ - "//tensorflow:tensorflow_py", "@absl_py//absl/testing:parameterized", + #internal proto upb dep + "//tensorflow:tensorflow_py", ], ) @@ -22,7 +23,8 @@ py_strict_test( srcs = ["side_inputs_manual_api_test.py"], python_version = "PY3", deps = [ - "//tensorflow:tensorflow_py", "@absl_py//absl/testing:parameterized", + #internal proto upb dep + "//tensorflow:tensorflow_py", ], ) diff --git a/tensorflow/core/function/polymorphism/BUILD b/tensorflow/core/function/polymorphism/BUILD index 478ca86c222332..67165dfc34deca 100644 --- a/tensorflow/core/function/polymorphism/BUILD +++ b/tensorflow/core/function/polymorphism/BUILD @@ -32,6 +32,7 @@ py_strict_test( deps = [ ":function_type", ":type_dispatch", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", "//tensorflow/python/types:trace", ], @@ -60,6 +61,7 @@ py_strict_test( visibility = ["//learning/brain/contrib/eager/python/examples:__pkg__"], deps = [ ":function_cache", + #internal proto upb dep "//tensorflow/core/function/polymorphism:function_type", "//tensorflow/core/function/trace_type", "//tensorflow/python/ops:array_ops", @@ -116,6 +118,7 @@ py_strict_test( python_version = "PY3", deps = [ ":function_type", + #internal proto upb dep "//tensorflow/core/function/polymorphism:function_type_proto_py", "//tensorflow/core/function/trace_type", 
"//tensorflow/core/function/trace_type:serialization", diff --git a/tensorflow/core/function/runtime_client/BUILD b/tensorflow/core/function/runtime_client/BUILD index 2ad85369234ca7..d1046cdd72e9cd 100644 --- a/tensorflow/core/function/runtime_client/BUILD +++ b/tensorflow/core/function/runtime_client/BUILD @@ -19,13 +19,16 @@ cc_library( hdrs = [ "runtime_client.h", ], + defines = select({ + "//tensorflow/compiler/mlir/python:disable_mlir_config": ["DISABLE_MLIR"], + "//conditions:default": [], + }), visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", - "//tensorflow/compiler/mlir/python:mlir", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:export_graphdef", "//tensorflow/compiler/mlir/tensorflow:import_model", @@ -52,7 +55,12 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - ], + ] + select({ + "//tensorflow/compiler/mlir/python:disable_mlir_config": [], + "//conditions:default": [ + "//tensorflow/compiler/mlir/python:mlir", + ], + }), # TODO(mdan): Get rid of alwayslink, it's nonstandard. alwayslink = 1, ) @@ -156,6 +164,7 @@ py_strict_test( tags = ["no_oss"], # TODO(b/219089812) deps = [ ":runtime_client_py", + #internal proto upb dep "//tensorflow/core/framework:function_proto_py", "//tensorflow/core/function/testing:test_pass_py", "//tensorflow/python:tf2", diff --git a/tensorflow/core/function/runtime_client/runtime_client.cc b/tensorflow/core/function/runtime_client/runtime_client.cc index b10bcc3856e19b..6438a1ca2b83c1 100644 --- a/tensorflow/core/function/runtime_client/runtime_client.cc +++ b/tensorflow/core/function/runtime_client/runtime_client.cc @@ -31,7 +31,11 @@ limitations under the License. 
#include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" + +#if !defined(DISABLE_MLIR) #include "tensorflow/compiler/mlir/python/mlir.h" +#endif + #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" diff --git a/tensorflow/core/function/trace_type/BUILD b/tensorflow/core/function/trace_type/BUILD index a88f0a3aca20d6..b3c0f2e05f8fa6 100644 --- a/tensorflow/core/function/trace_type/BUILD +++ b/tensorflow/core/function/trace_type/BUILD @@ -51,6 +51,8 @@ py_strict_test( ":custom_nest_trace_type", ":default_types", ":trace_type", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", @@ -66,7 +68,6 @@ py_strict_test( "//tensorflow/python/ops:variables", "//tensorflow/python/ops/ragged:ragged_tensor", "//tensorflow/python/platform:client_testlib", - "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) @@ -94,6 +95,7 @@ py_strict_test( deps = [ ":default_types", ":serialization", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", "//tensorflow/python/types:trace", "@absl_py//absl/testing:parameterized", @@ -121,6 +123,7 @@ py_strict_test( deps = [ ":custom_nest_trace_type", ":default_types", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", "//tensorflow/python/types:trace", "@absl_py//absl/testing:parameterized", @@ -143,6 +146,7 @@ py_strict_test( python_version = "PY3", deps = [ ":serialization", + #internal proto upb dep "//tensorflow/core/function/trace_type:serialization_test_proto_py", "//tensorflow/python/platform:client_testlib", ], diff --git a/tensorflow/core/function/trace_type/trace_type_test.py b/tensorflow/core/function/trace_type/trace_type_test.py index 0ef6e8d8d75adc..3e9c7dbe06b05a 100644 --- a/tensorflow/core/function/trace_type/trace_type_test.py +++ b/tensorflow/core/function/trace_type/trace_type_test.py @@ -439,29 +439,29 @@ def testDictofTensorSpecs(self): class TraceTypeMemoryTest(test.TestCase): - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testGeneric(self): trace_type.from_value(1) trace_type.from_value(DummyGenericClass()) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testTensor(self): tensor = array_ops.zeros([10]) trace_type.from_value(tensor) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testTuple(self): trace_type.from_value((1, 2, 3)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDict(self): trace_type.from_value({1: 1, 2: 2, 3: 3}) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testList(self): trace_type.from_value([1, 2, 3]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testAttrs(self): trace_type.from_value(TestAttrsClass(1, 2)) diff --git a/tensorflow/core/function/transform/BUILD b/tensorflow/core/function/transform/BUILD index 91e7302db8229a..9d7b60dd1328e7 100644 
--- a/tensorflow/core/function/transform/BUILD +++ b/tensorflow/core/function/transform/BUILD @@ -43,6 +43,8 @@ py_strict_test( tags = ["no_oss"], # TODO(b/219089812) deps = [ ":transform", + "@absl_py//absl/testing:parameterized", + #internal proto upb dep "//tensorflow/core/function/testing:test_pass_py", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", @@ -57,6 +59,5 @@ py_strict_test( "//tensorflow/python/platform:client_testlib", "//tensorflow/python/saved_model:load", "//tensorflow/python/saved_model:save", - "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 4342ec637492ec..84b33460db1b03 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -841,7 +841,6 @@ class SymbolicShapeRefiner { } int output_port_num = input_tensor.index(); - AttrValue attr_output_shape; TensorShapeProto proto; const auto handle = input_ic->output(output_port_num); input_ic->ShapeHandleToProto(handle, &proto); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index ecd559734ea870..8e79043af832a6 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -910,7 +910,7 @@ tf_kernel_library( tf_cuda_cc_test( name = "remapper_test", srcs = ["remapper_test.cc"], - tags = ["no_rocm"], + tags = [], deps = [ ":remapper", "//tensorflow/cc:cc_ops", diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc index 2eed9cd40061a9..689185fb08923d 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc @@ -932,7 +932,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) { EXPECT_EQ(tensors.size(), tensors_expected.size()); EXPECT_EQ(tensors.size(), item.fetch.size()); for (int i = 0; i < item.fetch.size(); ++i) { - test::ExpectClose(tensors_expected[i], tensors[i], -1, 2e-4); + test::ExpectClose(tensors_expected[i], tensors[i], -1, 4e-4); } } diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 5519abcf25aaa9..4edc6cc529451a 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -1,6 +1,6 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/core/platform:build_config.bzl", "tf_protos_all") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -373,6 +373,7 @@ cc_library( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:functional_ops", + "@com_google_absl//absl/strings", ] + tf_protos_all(), ) @@ -389,6 +390,8 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/platform:types", + "@com_google_absl//absl/strings", ] + tf_protos_all(), ) @@ -688,6 +691,7 @@ cc_library( "//tensorflow/core/grappler/utils:topological_sort", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", ] + 
tf_protos_all(), alwayslink = 1, ) @@ -708,6 +712,7 @@ tf_cc_test( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/platform:status", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:errors", ], diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc index 0cd0db36808485..62c615c45905fc 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/fusion_utils.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/strip.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -31,15 +34,32 @@ limitations under the License. #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace grappler { namespace fusion_utils { namespace { + +// See the comment for the proto field `tensorflow.NodeDef.input`. +constexpr char kControlInputPrefix[] = "^"; + +bool IsControlInput(const string& node_input) { + return absl::StartsWith(node_input, kControlInputPrefix); +} + +string StripControlInputNotation(const string& node_input) { + return string(absl::StripPrefix(node_input, kControlInputPrefix)); +} + +string AddControlInputNotation(const string& node_input) { + return absl::StrCat(kControlInputPrefix, node_input); +} + +// Returns e.g. `"node"` given `"node:out"` or `"node:out:0"`. See the comment +// for the proto field `tensorflow.FunctionDef.node_def`. string ParseNodeConnection(const string& name) { - // If input/output node name has semicolon, take the prefix. Otherwise take - // the whole string. 
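The three helpers above exist because a control input is encoded as the producer's name with a `^` prefix, so any renaming must strip the prefix before parsing the node name and re-attach it afterwards. A standalone sketch of that dance follows; the `RenameInput` helper and the sample mapping are illustrative, and the real code paths are `GetUniqueSignature` and `FuseFunctionNodes` below.

```cpp
#include <iostream>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/strip.h"

// Renames the producer of `node_input` according to `renames`, preserving the
// control-input prefix and any ":out:idx" suffix.
std::string RenameInput(
    const std::string& node_input,
    const absl::flat_hash_map<std::string, std::string>& renames) {
  const bool is_control_input = absl::StartsWith(node_input, "^");
  const std::string stripped(absl::StripPrefix(node_input, "^"));
  // Only the node name (the part before any ":out:idx" suffix) is remapped.
  const std::string name = stripped.substr(0, stripped.find(':'));
  const std::string suffix = stripped.substr(name.size());
  auto it = renames.find(name);
  if (it == renames.end()) return node_input;
  const std::string renamed = absl::StrCat(it->second, suffix);
  return is_control_input ? absl::StrCat("^", renamed) : renamed;
}

int main() {
  absl::flat_hash_map<std::string, std::string> renames = {{"y", "y_0"}};
  std::cout << RenameInput("^y", renames) << "\n";       // prints "^y_0"
  std::cout << RenameInput("y:out:0", renames) << "\n";  // prints "y_0:out:0"
}
```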
return name.substr(0, name.find(':')); } @@ -194,10 +214,15 @@ OpDef GetUniqueSignature(const OpDef& first_signature, for (NodeDef& function_node : *nodes_to_fuse) { for (auto& node_input : *function_node.mutable_input()) { - const auto& input = ParseNodeConnection(node_input); + bool is_control_input = IsControlInput(node_input); + const auto& input = + ParseNodeConnection(StripControlInputNotation(node_input)); if (const string* new_name = gtl::FindOrNull(changed_input_names, input)) { node_input = *new_name + ParseOutputNode(node_input); + if (is_control_input) { + node_input = AddControlInputNotation(node_input); + } } } } @@ -215,7 +240,9 @@ void FuseFunctionNodes(const StringCollection& first_inputs, protobuf::RepeatedPtrField* nodes_to_fuse) { for (NodeDef& function_node : *nodes_to_fuse) { for (auto& node_input : *function_node.mutable_input()) { - auto parsed_name = ParseNodeConnection(node_input); + bool is_control_input = IsControlInput(node_input); + auto parsed_name = + ParseNodeConnection(StripControlInputNotation(node_input)); auto input_it = std::find(second_inputs.begin(), second_inputs.end(), parsed_name); @@ -224,6 +251,9 @@ void FuseFunctionNodes(const StringCollection& first_inputs, auto arg_num = std::distance(second_inputs.begin(), input_it); node_input = set_input(first_inputs, second_inputs, first_outputs, arg_num); + if (is_control_input) { + node_input = AddControlInputNotation(node_input); + } } } } diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc index e667affeeaf7f4..84c22590926a5b 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc @@ -15,38 +15,39 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/data/fusion_utils.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/data/function_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" - #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace grappler { namespace fusion_utils { namespace { -string ParseNodeConnection(const string &name) { +string ParseNodeConnection(const string& name) { return name.substr(0, name.find(':')); } -void CheckUniqueNames(const FunctionDef &function) { +void CheckUniqueNames(const FunctionDef& function) { std::unordered_set inputs; - for (const auto &input_arg : function.signature().input_arg()) + for (const auto& input_arg : function.signature().input_arg()) inputs.insert(input_arg.name()); EXPECT_EQ(inputs.size(), function.signature().input_arg_size()); std::unordered_set outputs; - for (const auto &output_arg : function.signature().output_arg()) + for (const auto& output_arg : function.signature().output_arg()) outputs.insert(output_arg.name()); EXPECT_EQ(outputs.size(), function.signature().output_arg_size()); std::unordered_set nodes; - for (const auto &node : function.node_def()) nodes.insert(node.name()); + for (const auto& node : function.node_def()) nodes.insert(node.name()); EXPECT_EQ(nodes.size(), function.node_def_size()); } @@ -71,7 +72,7 @@ TEST(FusionUtilsTest, FuseFunctionsByComposition) { CheckUniqueNames(*fused_function); const NodeDef *parent_mul = nullptr, *output_mul = nullptr; - for (const auto &fused_node : fused_function->node_def()) { + for (const auto& fused_node : fused_function->node_def()) { if (fused_node.op() == "Mul") { if (fused_node.name() == "y") parent_mul = &fused_node; @@ -89,6 +90,44 @@ TEST(FusionUtilsTest, FuseFunctionsByComposition) { EXPECT_EQ(ParseNodeConnection(output_value), output_mul->name()); } +TEST(FusionUtilsTest, FuseFunctionsWithControlInputs) { + GraphDef graph; + auto *parent_function = graph.mutable_library()->add_function(); + *parent_function = test::function::XTimesTwoWithControlInput(); + auto *function = graph.mutable_library()->add_function(); + *function = test::function::XTimesTwoWithControlInput(); + + auto *fused_function = FuseFunctions( + *parent_function, *function, "fused_maps", fusion_utils::ComposeSignature, + fusion_utils::ComposeInput, fusion_utils::ComposeOutput, + fusion_utils::MergeNodes, graph.mutable_library()); + + EXPECT_EQ(fused_function->signature().name(), "fused_maps"); + EXPECT_EQ(fused_function->signature().input_arg_size(), 1); + EXPECT_EQ(fused_function->signature().output_arg_size(), 1); + EXPECT_EQ(fused_function->ret_size(), 1); + CheckUniqueNames(*fused_function); + + const NodeDef *parent_mul = nullptr, *output_mul = nullptr; + for (const auto& fused_node : fused_function->node_def()) { + if (fused_node.op() == "Mul") { + if (fused_node.name() == "y") + parent_mul = &fused_node; + else + output_mul = &fused_node; + } + } + ASSERT_NE(parent_mul, nullptr); + ASSERT_NE(output_mul, nullptr); + EXPECT_EQ(ParseNodeConnection(output_mul->input(1)), + absl::StrCat("^", parent_mul->name())); + + auto output_value = fused_function->ret().at( + fused_function->signature().output_arg(0).name()); + 
+ EXPECT_EQ(ParseNodeConnection(output_value), output_mul->name()); +} + TEST(FusionUtilsTest, FuseFunctionWithPredicate) { GraphDef graph; auto *xtimes_two = graph.mutable_library()->add_function(); @@ -112,7 +151,7 @@ TEST(FusionUtilsTest, FuseFunctionWithPredicate) { ASSERT_TRUE( function_utils::ContainsFunctionNodeWithOp("Equal", *fused_function)); - const auto &equal_node = fused_function->node_def( + const auto& equal_node = fused_function->node_def( function_utils::FindFunctionNodeWithOp("Equal", *fused_function)); EXPECT_EQ(xtimes_two->signature().output_arg(0).name(), @@ -152,8 +191,8 @@ TEST(FusionUtilsTest, ZipFusion) { auto *function = graph.mutable_library()->add_function(); *function = test::function::XTimesTwo(); - auto zip_signature = [](const OpDef &parent_function_signature, - const OpDef &function_signature, + auto zip_signature = [](const OpDef& parent_function_signature, + const OpDef& function_signature, OpDef *fused_function_signature) { *fused_function_signature = parent_function_signature; fused_function_signature->mutable_input_arg()->MergeFrom( @@ -162,9 +201,9 @@ TEST(FusionUtilsTest, ZipFusion) { function_signature.output_arg()); }; - auto zip_input = [](const StringCollection &parent_inputs, - const StringCollection &function_inputs, - const StringCollection &parent_outputs, int arg_num) { + auto zip_input = [](const StringCollection& parent_inputs, + const StringCollection& function_inputs, + const StringCollection& parent_outputs, int arg_num) { // Take corresponding parent output. return function_inputs.at(arg_num); }; diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc index cfd6826ee8ccce..a537b794760c65 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -79,13 +80,30 @@ bool SameDeterministicAttr(const NodeDef& parallel_map_node, return false; } +// Returns a name for a new node or function that fuses the inputs. +// - For nodes, this is only for debugging. +// - For functions, this additionally prevents collisions (upstream of this +// optimizer, the act of optimizing a single graph entails individually +// optimizing each function in that graph and later aggregating any new +// functions introduced during these individual optimizations into that single +// graph's collective function library). +// TODO(mpcallanan): Look at deduping names in a more generic fashion upstream. +string GetFusedName(const NodeDef& parent, const NodeDef& child) { + return absl::StrCat("map_fusion_nodes/", parent.name(), "/", child.name()); +} +string GetFusedName(const FunctionDef& parent, const FunctionDef& child) { + return absl::StrCat("map_fusion_funcs/", parent.signature().name(), "/", + child.signature().name()); +} + // Sets basic function parameters and copies attributes from parent and map // node. 
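To make the naming scheme concrete: the expectations in map_fusion_test.cc later in this diff use node names map1/map2 with functions fnA/fnB, which produce "map_fusion_nodes/map1/map2" and "map_fusion_funcs/fnA/fnB". A tiny runnable check of the collision-avoidance property; the helper below merely mirrors the StrCat format used by GetFusedName above.

```cpp
#include <cassert>
#include <string>

#include "absl/strings/str_cat.h"

// Mirrors GetFusedName(const FunctionDef&, const FunctionDef&) for two names.
std::string FusedFuncName(const std::string& parent, const std::string& child) {
  return absl::StrCat("map_fusion_funcs/", parent, "/", child);
}

int main() {
  // Previously both fusions would have requested the same name ("fused_map");
  // with per-pair names they stay distinct across independent optimizations.
  assert(FusedFuncName("fnA", "fnB") != FusedFuncName("fnC", "fnD"));
  assert(FusedFuncName("fnA", "fnB") == "map_fusion_funcs/fnA/fnB");
  return 0;
}
```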
NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node, const FunctionDef& fused_function, MutableGraphView* graph) { NodeDef fused_node; - graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node); + graph_utils::SetUniqueGraphNodeName(GetFusedName(parent_map_node, map_node), + graph->graph(), &fused_node); if (map_node.op() == kMapDatasetOp) { fused_node.set_op(kMapDatasetOp); @@ -185,9 +203,10 @@ Status MapFusion::OptimizeAndCollectStats(Cluster* cluster, return nullptr; } return fusion_utils::FuseFunctions( - *parent_func, *func, "fused_map", fusion_utils::ComposeSignature, - fusion_utils::ComposeInput, fusion_utils::ComposeOutput, - fusion_utils::MergeNodes, output->mutable_library()); + *parent_func, *func, GetFusedName(*parent_func, *func), + fusion_utils::ComposeSignature, fusion_utils::ComposeInput, + fusion_utils::ComposeOutput, fusion_utils::MergeNodes, + output->mutable_library()); }; for (const NodeDef& node : sorted_old_graph.node()) { diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc index 76b45248c3f17c..c81191ecd823df 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -186,6 +187,55 @@ TEST(MapFusionTest, FuseTwoParallelMapNodesIntoOne) { EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output)); } +TEST(MapFusionTest, FusedNodesAndFunctionsAreNamedAfterOldNodesAndFunctions) { + using test::function::NDef; + NodeDef num_parallel_calls_node = CreateScalarConstNodeHelper( + "num_parallel_calls", DT_INT64, + [](TensorProto* proto) { proto->add_int64_val(-1); }); + auto graph = [&num_parallel_calls_node]( + const std::string& parent_map_node_name, + const std::string& map_node_name, + const std::string& parent_function_name, + const std::string& function_name) { + FunctionDef parent_fn = test::function::XTimesTwo(); + FunctionDef fn = test::function::XTimesTwo(); + parent_fn.mutable_signature()->set_name(parent_function_name); + fn.mutable_signature()->set_name(function_name); + return test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + num_parallel_calls_node, + MakeParallelMapV2Node(parent_map_node_name, "range", + num_parallel_calls_node.name(), + parent_function_name, "default"), + MakeParallelMapV2Node(map_node_name, parent_map_node_name, + num_parallel_calls_node.name(), function_name, + "default")}, + // FunctionLib + {parent_fn, fn}); + }; + + GrapplerItem item_1; + item_1.graph = graph("map1", "map2", "fnA", "fnB"); + GraphDef output_1; + TF_ASSERT_OK(OptimizeWithMapFusion(item_1, &output_1, true)); + EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName( + "map_fusion_nodes/map1/map2", output_1)); + EXPECT_TRUE(graph_utils::ContainsGraphFunctionWithName( + "map_fusion_funcs/fnA/fnB", output_1.library())); + + GrapplerItem item_2; + item_2.graph = graph("map3", "map4", "fnC", "fnD"); + GraphDef output_2; + TF_ASSERT_OK(OptimizeWithMapFusion(item_2, 
&output_2, true)); + EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName( + "map_fusion_nodes/map3/map4", output_2)); + EXPECT_TRUE(graph_utils::ContainsGraphFunctionWithName( + "map_fusion_funcs/fnC/fnD", output_2.library())); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc index 2896e3e703caa5..98caa6218376e3 100644 --- a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc @@ -1551,6 +1551,142 @@ TEST_F(MklFuseInstanceNormTest, FuseMklInstanceNormWithActivation4D_FP32_NCHW) { FuseMklInstanceNorm4D("NCHW", true); } +class FusedConvBiasAddAndHardSwishTest : public GrapplerTest { + public: + const string kAddOp = "Add"; + const string kAddV2Op = "AddV2"; + + template + void RunTest(const string& add_op, const bool is_depthwise) { + using ::tensorflow::ops::Placeholder; + + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3}); + auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 128}); + auto bias_shape = ops::Placeholder::Shape({is_depthwise ? 384 : 128}); + + auto input = Placeholder(s.WithOpName("input"), DType, input_shape); + auto filter = Placeholder(s.WithOpName("filter"), DType, filter_shape); + auto bias = Placeholder(s.WithOpName("bias"), DType, bias_shape); + const DataType const_dt = with_cast_op ? DT_FLOAT : DType; + typedef typename EnumToDataType::Type DT; + Tensor three(const_dt, TensorShape({})); + Tensor one_sixth(const_dt, TensorShape({})); + three.scalar
<DT>()() = static_cast<DT>(3.0f);
+    one_sixth.scalar<DT>()() = static_cast<DT>
(1.0f / 6.0f); + auto three_op = + with_cast_op + ? ops::Cast(s.WithOpName("three"), Input::Initializer(three), + DT_BFLOAT16) + : ops::Const(s.WithOpName("three"), Input::Initializer(three)); + auto one_sixth_op = + with_cast_op ? ops::Cast(s.WithOpName("one_sixth"), + Input::Initializer(one_sixth), DT_BFLOAT16) + : ops::Const(s.WithOpName("one_sixth"), + Input::Initializer(one_sixth)); + + std::vector strides = {1, 1, 1, 1}; + Output conv; + if (is_depthwise) { + conv = ops::DepthwiseConv2dNative( + s.WithOpName("conv"), input, filter, strides, "SAME", + ops::DepthwiseConv2dNative::Attrs().DataFormat("NHWC")); + } else { + conv = ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME", + ops::Conv2D::Attrs().DataFormat("NHWC")); + } + auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias, + ops::BiasAdd::Attrs().DataFormat("NHWC")); + + Output add; + if (add_op == kAddV2Op) { + add = ops::AddV2(s.WithOpName(add_op), three_op, bias_add); + } else { + add = ops::Add(s.WithOpName(add_op), three_op, bias_add); + } + + auto relu6 = ops::Relu6(s.WithOpName("relu_6"), add); + auto mul_one_sixth = + ops::Mul(s.WithOpName("mul_one_sixth"), one_sixth_op, bias_add); + auto mul_output = ops::Mul(s.WithOpName("output"), mul_one_sixth, relu6); + + auto fetch = ops::Identity(s.WithOpName("fetch"), mul_output); + + auto input_tensor = GenerateTensorWithSetRandom( + TensorShape(input_shape.shape_.dim_sizes())); + auto filter_tensor = GenerateTensorWithSetRandom( + TensorShape(filter_shape.shape_.dim_sizes())); + auto bias_tensor = GenerateTensorWithSetRandom( + TensorShape(bias_shape.shape_.dim_sizes())); + + GrapplerItem item; + item.fetch = {"fetch"}; + item.feed = {{"input", input_tensor}, + {"filter", filter_tensor}, + {"bias", bias_tensor}}; + + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + // Place all nodes on CPU. 
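+    // The HardSwish fusion is a oneDNN (CPU-only) rewrite, so the test pins
+    // every node to CPU before running the remapper.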
+ for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device("/device:CPU:0"); + } + + Remapper optimizer(RewriterConfig::ON); + GraphDef output; + TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output)); + + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "output") { + if (is_depthwise) { + EXPECT_EQ("_FusedDepthwiseConv2dNative", node.op()); + } else { + EXPECT_EQ("_FusedConv2D", node.op()); + } + EXPECT_EQ("input", node.input(0)); + EXPECT_EQ("filter", node.input(1)); + EXPECT_EQ("bias", node.input(2)); + EXPECT_EQ(1, node.attr().at("num_args").i()); + + const auto fused_ops = node.attr().at("fused_ops").list().s(); + EXPECT_EQ(2, fused_ops.size()); + EXPECT_EQ("BiasAdd", fused_ops[0]); + EXPECT_EQ("_FusedHardSwish", fused_ops[1]); + found++; + } + } + EXPECT_EQ(1, found); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + EXPECT_EQ(1, tensors.size()); + test::ExpectClose(tensors_expected[0], tensors[0], 1e-6); + } +}; + +TEST_F(FusedConvBiasAddAndHardSwishTest, Float32Conv2DBiasHardSwish) { + RunTest("AddV2", false); +} +TEST_F(FusedConvBiasAddAndHardSwishTest, Float32DWConv2DBiasHardSwish) { + RunTest("AddV2", true); +} +TEST_F(FusedConvBiasAddAndHardSwishTest, Bfloat16Conv2DBiasHardSwish) { + RunTest("Add", false); +} +TEST_F(FusedConvBiasAddAndHardSwishTest, Bfloat16DWConv2DBiasHardSwish) { + RunTest("Add", true); +} +TEST_F(FusedConvBiasAddAndHardSwishTest, Bfloat16Conv2DBiasHardSwishWithCast) { + RunTest("Add", false); +} +TEST_F(FusedConvBiasAddAndHardSwishTest, + Bfloat16DWConv2DBiasHardSwishWithCast) { + RunTest("Add", true); +} + } // namespace grappler } // namespace tensorflow #endif // INTEL_MKL && ENABLE_MKL diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 3c37150f496aa4..3723290691ce28 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -1060,7 +1060,8 @@ bool FindConv2DWithBatchNormAndActivation( bool FindContractionWithBiasInPort(const RemapperContext& ctx, const utils::MutableNodeView& add_node_view, const NodeDef& add_node_def, int port_id, - ContractionWithBiasAdd* base) { + ContractionWithBiasAdd* base, + const int allowed_fanouts = 1) { // Input to AddN must match ContractionWithBiasAdd pattern. 
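+  // `allowed_fanouts` bounds how many regular fanouts the BiasAdd output may
+  // have before the match is rejected. The default of 1 keeps the previous
+  // HasAtMostOneFanoutAtPort0 behavior; the HardSwish fusion later in this
+  // file passes 2 because the BiasAdd result is consumed twice in that
+  // pattern.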
if (add_node_view.NumRegularFanins() < port_id + 1) return false; const auto& bias_add_node_view = @@ -1071,7 +1072,7 @@ bool FindContractionWithBiasInPort(const RemapperContext& ctx, if (!FindContractionWithBias(ctx, bias_add_node_view->node_index(), base, /*check_device_compatible=*/false)) return false; - if (!HasAtMostOneFanoutAtPort0(*bias_add_node_view) || + if (bias_add_node_view->GetRegularFanout(0).size() > allowed_fanouts || !HaveSameDataType(&add_node_def, bias_add_node_def) || IsInPreserveSet(ctx, bias_add_node_def)) return false; @@ -2670,6 +2671,140 @@ bool FindTensorToHashBucket(const RemapperContext& ctx, int node_index, return true; } +// clang-format off +// HardSwish pattern +// input Const (value: 3) +// | \ / +// | Add or AddV2 +// | | +// | Relu6 +// | / +// | / +// Const (value: 0.1666) | / +// \ | / +// Mul / +// \ / +// Mul +// clang-format on +bool FindHardSwish(RemapperContext& ctx, int node_index, + std::map* matched_nodes_map, + std::set* remove_node_indices) { + if (!IsMKLEnabled()) return false; + + using utils::MatchingDirection; + using utils::NodeStatus; + // clang-format off + utils::OpTypePattern pattern {"Mul", "output", NodeStatus::kReplace, + { + {"Mul", "mul_one_sixth", NodeStatus::kRemove, + { + {"Const|Cast", "one_sixth", NodeStatus::kRemain}, + {"*", "input", NodeStatus::kRemain} + } + }, + {"Relu6", "relu6", NodeStatus::kRemove, + { + {"Add|AddV2", "add", NodeStatus::kRemove, + { + {"*", "input", NodeStatus::kRemain}, + {"Const|Cast", "three", NodeStatus::kRemain} + } + } + } + }, + } + }; + // clang-format on + bool found_match = false; + utils::SubGraphMatcher graph_matcher( + &(ctx.graph_view)); + + matched_nodes_map->clear(); + remove_node_indices->clear(); + + found_match = graph_matcher.GetMatchedNodes( + pattern, ctx.nodes_to_preserve, ctx.graph_view.GetNode(node_index), + matched_nodes_map, remove_node_indices); + + if (found_match) { + // Check if the values of Const nodes are as expected + std::map values_map = {{"three", 3.0}, + {"one_sixth", 0.16666}}; + if (!VerifyConstants(&ctx, matched_nodes_map, &values_map)) return false; + } + + return found_match; +} + +// clang-format off +// Contraction + BiasAdd + _FusedHardSwish activation +// input filter +// \ / +// Contraction bias +// | / +// BiasAdd +// | +// _FusedHardSwish +// clang-format on +bool FindContractionWithBiasAddAndHardSwish( + RemapperContext& ctx, int node_index, + std::map* matched_nodes_map, + std::set* remove_node_indices) { + if (!IsMKLEnabled()) return false; + + const auto* node_view = ctx.graph_view.GetNode(node_index); + if (HasControlFaninOrFanout(*node_view)) return false; + + // Check if HardSwish pattern is available + if (!FindHardSwish(ctx, node_index, matched_nodes_map, remove_node_indices)) + return false; + // Get handle of Add|AddV2 op that is the root of HardSwish pattern. + const auto* add_node_view = + ctx.graph_view.GetNode(matched_nodes_map->at("add")); + const auto* add_node_def = add_node_view->node(); + + // Check if ContractionWithBias pattern is feeding HardSwish + ContractionWithBiasAdd base; + int port_id = 0; + // BiasAdd node is expected to have 2 fanouts feeding the HardSwish pattern. 
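+  // In the matched subgraph the BiasAdd output feeds both the Add(x, 3) node
+  // and the outer Mul(1/6, x) node, i.e. the graph computes
+  // x * relu6(x + 3) * (1/6) == hard_swish(x), so two fanouts are allowed
+  // here instead of the usual one.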
+ if (!FindContractionWithBiasInPort(ctx, *add_node_view, *add_node_def, + port_id, &base, /*allowed_fanouts*/ 2)) { + port_id = 1; + if (!FindContractionWithBiasInPort(ctx, *add_node_view, *add_node_def, + port_id, &base, /*allowed_fanouts*/ 2)) { + VLOG(2) << "Contraction + BiasAdd pattern was not found although" + << " HardSwish pattern was found, so fusion failed."; + return false; + } + } + + // Get the BiasAdd node + const auto* bias_node_def = ctx.graph_view.GetNode(base.bias_add)->node(); + if (!HaveSameDataType(add_node_def, bias_node_def)) return false; + + // Get the contraction node + const auto* contraction_node_view = ctx.graph_view.GetNode(base.contraction); + const auto* contraction_node_def = contraction_node_view->node(); + + // Currently only Conv2D and DepthwiseConv2D contraction ops are supported + if (!IsConv2D(*contraction_node_def) && + !IsDepthwiseConv2dNative(*contraction_node_def)) + return false; + + // Check if contraction is compatible with CPU + if (!IsCpuCompatibleConv2D(ctx, contraction_node_def) && + !IsCpuCompatibleDepthwiseConv2dNative(contraction_node_def)) + return false; + + // We found a {Conv2D, DepthwiseConv2D}+BiasAdd+_FusedHardSwish pattern. + matched_nodes_map->insert({"contraction", base.contraction}); + matched_nodes_map->insert({"bias_add", base.bias_add}); + + remove_node_indices->insert(base.contraction); + remove_node_indices->insert(base.bias_add); + return true; +} + bool FindFusedBatchMatMul(RemapperContext* ctx, int node_index, std::map* matched_nodes_map, std::set* remove_node_indices, @@ -3537,6 +3672,47 @@ Status AddFusedContractionNode( return OkStatus(); } +Status FuseContractionWithBiasAddAndHardSwish( + RemapperContext* ctx, std::map* matched_nodes_map, + std::set* remove_node_indices, std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { + auto* output_node = + ctx->graph_view.GetNode(matched_nodes_map->at("output"))->node(); + auto* contraction_node = + ctx->graph_view.GetNode(matched_nodes_map->at("contraction"))->node(); + auto* bias_add_node = + ctx->graph_view.GetNode(matched_nodes_map->at("bias_add"))->node(); + + bool is_conv2d = IsConv2D(*contraction_node); + + NodeDef fused_node; + fused_node.set_name(output_node->name()); + fused_node.set_op(is_conv2d ? 
kFusedConv2D : kFusedDepthwiseConv2dNative); + fused_node.set_device(contraction_node->device()); + fused_node.add_input(contraction_node->input(0)); + fused_node.add_input(contraction_node->input(1)); + fused_node.add_input(bias_add_node->input(1)); + + if (is_conv2d) { + CopyConv2DAttributes(*contraction_node, &fused_node); + } else { + CopyDepthwiseConv2dNativeAttributes(*contraction_node, &fused_node); + } + SetFusedOpAttributes(&fused_node, {"BiasAdd", "_FusedHardSwish"}); + + utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); + Status status; + mutation->AddNode(std::move(fused_node), &status); + TF_RETURN_IF_ERROR(status); + TF_RETURN_IF_ERROR(mutation->Apply()); + (*invalidated_nodes)[matched_nodes_map->at("output")] = true; + + for (const auto& node_idx : *remove_node_indices) { + (*nodes_to_delete)[node_idx] = true; + } + return OkStatus(); +} + Status FuseConv2DSwish(RemapperContext* ctx, const std::map& matched_nodes_map, const std::set& remove_node_indices, @@ -4703,6 +4879,15 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, std::set remove_node_indices; std::vector input_node_names; + // Remap {Conv2D|DepthwiseConv2D} + BiasAdd + HardSwish subgraph + if (FindContractionWithBiasAddAndHardSwish(ctx, i, &matched_nodes_map, + &remove_node_indices)) { + TF_RETURN_IF_ERROR(FuseContractionWithBiasAddAndHardSwish( + &ctx, &matched_nodes_map, &remove_node_indices, &invalidated_nodes, + &nodes_to_delete)); + continue; + } + // Softplus + Tanh + Mul to Mish conversion matched_nodes_map.clear(); remove_node_indices.clear(); diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc index d3a6652589381b..76c7098361d6f2 100644 --- a/tensorflow/core/grappler/optimizers/remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/remapper_test.cc @@ -888,6 +888,163 @@ TEST_F(RemapperFuseConvWithBiasAndActivation, Conv3D_BF16) { RunTest<3, DT_BFLOAT16>(); } +class RemapperFuseConvWithBiasAndAddActivation : public RemapperTest { + public: + template + void RunTest() { + if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; + using ::tensorflow::ops::Placeholder; + + for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto input_shape = Placeholder::Shape({8, 32, 32, 3}); + auto filter_shape = Placeholder::Shape({1, 1, 3, 128}); + auto bias_shape = Placeholder::Shape({128}); + auto add_shape = ops::Placeholder::Shape({8, 32, 32, 128}); + + auto input_t = GenerateRandomTensor({8, 32, 32, 3}); + auto filter_t = GenerateRandomTensor({1, 1, 3, 128}); + auto bias_t = GenerateRandomTensor({128}); + auto add_t = GenerateRandomTensor({8, 32, 32, 128}); + + float leakyrelu_alpha = 0.5; + + std::vector strides = {1, 1, 1, 1}; + + if (dim == 3) { + input_shape = Placeholder::Shape({8, 4, 32, 32, 3}); + filter_shape = Placeholder::Shape({1, 1, 1, 3, 128}); + bias_shape = Placeholder::Shape({128}); + add_shape = ops::Placeholder::Shape({8, 4, 32, 32, 128}); + strides = {1, 1, 1, 1, 1}; + + input_t = GenerateRandomTensor({8, 4, 32, 32, 3}); + filter_t = GenerateRandomTensor({1, 1, 1, 3, 128}); + bias_t = GenerateRandomTensor({128}); + add_t = GenerateRandomTensor({8, 4, 32, 32, 128}); + } + + auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); + auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape); + auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape); + 
auto input_add = + Placeholder(s.WithOpName("input_add"), DT_FLOAT, add_shape); + + if (dim == 2) { + auto conv = + ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME"); + auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias); + auto add = ops::Add(s.WithOpName("add_op"), input_add, bias_add); + + ops::Identity fetch = [&]() -> ops::Identity { + auto activate = s.WithOpName("activation"); + auto fetch = s.WithOpName("fetch"); + + if (activation == "Relu") { + return ops::Identity(fetch, ops::Relu(activate, add)); + } else if (activation == "Relu6") { + return ops::Identity(fetch, ops::Relu6(activate, add)); + } else if (activation == "Elu") { + return ops::Identity(fetch, ops::Elu(activate, add)); + } else if (activation == "LeakyRelu") { + auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha); + return ops::Identity(fetch, + ops::internal::LeakyRelu(activate, add, attr)); + } + + return ops::Identity(fetch, bias); + }(); + } else if (dim == 3) { + auto conv = + ops::Conv3D(s.WithOpName("conv"), input, filter, strides, "SAME"); + auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias); + auto add = ops::Add(s.WithOpName("add_op"), input_add, bias_add); + + ops::Identity fetch = [&]() -> ops::Identity { + auto activate = s.WithOpName("activation"); + auto fetch = s.WithOpName("fetch"); + + if (activation == "Relu") { + return ops::Identity(fetch, ops::Relu(activate, add)); + } else if (activation == "Relu6") { + return ops::Identity(fetch, ops::Relu6(activate, add)); + } else if (activation == "Elu") { + return ops::Identity(fetch, ops::Elu(activate, add)); + } else if (activation == "LeakyRelu") { + auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha); + return ops::Identity(fetch, + ops::internal::LeakyRelu(activate, add, attr)); + } + + return ops::Identity(fetch, bias); + }(); + } + + GrapplerItem item; + item.fetch = {"fetch"}; + item.feed = {{"input", input_t}, + {"filter", filter_t}, + {"bias", bias_t}, + {"input_add", add_t}}; + TF_ASSERT_OK(s.ToGraphDef(&item.graph)); + + // Place all nodes on CPU. 
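+      // The graph built above is Conv2D/Conv3D -> BiasAdd -> Add -> <activation>;
+      // the checks below expect the remapper (run at RewriterConfig::AGGRESSIVE)
+      // to collapse it into a single _FusedConv2D/_FusedConv3D with
+      // fused_ops = {"BiasAdd", "Add", <activation>}. The fusion is CPU-only,
+      // hence the explicit CPU placement that follows.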
+ for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device("/device:CPU:0"); + } + + Remapper optimizer(RewriterConfig::AGGRESSIVE); + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "activation") { + if (dim == 2) { + EXPECT_EQ(node.op(), "_FusedConv2D"); + } else if (dim == 3) { + EXPECT_EQ(node.op(), "_FusedConv3D"); + } + ASSERT_GE(node.input_size(), 3); + EXPECT_EQ(node.input(0), "input"); + EXPECT_EQ(node.input(1), "filter"); + + EXPECT_EQ(node.attr().at("num_args").i(), 2); + EXPECT_EQ(node.input(2), "bias"); + + const auto fused_ops = node.attr().at("fused_ops").list().s(); + ASSERT_EQ(fused_ops.size(), 3); + EXPECT_EQ("BiasAdd", fused_ops[0]); + EXPECT_EQ("Add", fused_ops[1]); + EXPECT_EQ(activation, fused_ops[2]); + found++; + } + } + EXPECT_EQ(found, 1); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + ASSERT_EQ(tensors_expected.size(), 1); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + ASSERT_EQ(tensors.size(), 1); + test::ExpectTensorNear(tensors[0], tensors_expected[0], 1e-6); + } + } +}; + +TEST_F(RemapperFuseConvWithBiasAndAddActivation, Conv2D_F32) { + RunTest<2, DT_FLOAT>(); +} +TEST_F(RemapperFuseConvWithBiasAndAddActivation, Conv3D_F32) { + RunTest<3, DT_FLOAT>(); +} +TEST_F(RemapperFuseConvWithBiasAndAddActivation, Conv2D_BF16) { + RunTest<2, DT_BFLOAT16>(); +} +TEST_F(RemapperFuseConvWithBiasAndAddActivation, Conv3D_BF16) { + RunTest<3, DT_BFLOAT16>(); +} + class RemapperFuseConvWithSqueezeAndBias : public RemapperTest { public: template @@ -2255,102 +2412,6 @@ TEST_F(RemapperTest, FuseConv3DWithBiasAndAdd) { test::ExpectTensorNear(tensors[0], tensors_expected[0], 1e-6); } -TEST_F(RemapperTest, FuseConv3DWithBiasAndAddActivation) { - if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to oneDNN."; - using ::tensorflow::ops::Placeholder; - - for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - - auto input_shape = Placeholder::Shape({8, 4, 32, 32, 3}); - auto filter_shape = Placeholder::Shape({1, 1, 1, 3, 128}); - auto bias_shape = Placeholder::Shape({128}); - auto add_shape = ops::Placeholder::Shape({8, 4, 32, 32, 128}); - - auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); - auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape); - auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape); - auto input_add = - Placeholder(s.WithOpName("input_add"), DT_FLOAT, add_shape); - - float leakyrelu_alpha = 0.5; - - std::vector strides = {1, 1, 1, 1, 1}; - auto conv = - ops::Conv3D(s.WithOpName("conv"), input, filter, strides, "SAME"); - auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias); - auto add = ops::Add(s.WithOpName("add_op"), input_add, bias_add); - - ops::Identity fetch = [&]() -> ops::Identity { - auto activate = s.WithOpName("activation"); - auto fetch = s.WithOpName("fetch"); - - if (activation == "Relu") { - return ops::Identity(fetch, ops::Relu(activate, add)); - } else if (activation == "Relu6") { - return ops::Identity(fetch, ops::Relu6(activate, add)); - } else if (activation == "Elu") { - return ops::Identity(fetch, ops::Elu(activate, add)); - } else if (activation == "LeakyRelu") { - auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha); - return ops::Identity(fetch, - ops::internal::LeakyRelu(activate, 
add, attr)); - } - - return ops::Identity(fetch, bias); - }(); - - auto input_t = GenerateRandomTensor({8, 4, 32, 32, 3}); - auto filter_t = GenerateRandomTensor({1, 1, 1, 3, 128}); - auto bias_t = GenerateRandomTensor({128}); - auto add_t = GenerateRandomTensor({8, 4, 32, 32, 128}); - - GrapplerItem item; - item.fetch = {"fetch"}; - item.feed = {{"input", input_t}, - {"filter", filter_t}, - {"bias", bias_t}, - {"input_add", add_t}}; - TF_ASSERT_OK(s.ToGraphDef(&item.graph)); - - // Place all nodes on CPU. - for (int i = 0; i < item.graph.node_size(); ++i) { - item.graph.mutable_node(i)->set_device("/device:CPU:0"); - } - - Remapper optimizer(RewriterConfig::AGGRESSIVE); - GraphDef output; - TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - - int found = 0; - for (const NodeDef& node : output.node()) { - if (node.name() == "activation") { - EXPECT_EQ(node.op(), "_FusedConv3D"); - ASSERT_GE(node.input_size(), 3); - EXPECT_EQ(node.input(0), "input"); - EXPECT_EQ(node.input(1), "filter"); - - EXPECT_EQ(node.attr().at("num_args").i(), 2); - EXPECT_EQ(node.input(2), "bias"); - - const auto fused_ops = node.attr().at("fused_ops").list().s(); - ASSERT_EQ(fused_ops.size(), 3); - EXPECT_EQ("BiasAdd", fused_ops[0]); - EXPECT_EQ("Add", fused_ops[1]); - EXPECT_EQ(activation, fused_ops[2]); - found++; - } - } - EXPECT_EQ(found, 1); - - auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); - ASSERT_EQ(tensors_expected.size(), 1); - auto tensors = EvaluateNodes(output, item.fetch, item.feed); - ASSERT_EQ(tensors.size(), 1); - test::ExpectTensorNear(tensors[0], tensors_expected[0], 1e-6); - } -} - // Conv2D + Add {6,} + Conv2D + Biasadd fusion. TEST_F(RemapperTest, FuseConv2DWithSemanticAdd) { if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to MKL."; diff --git a/tensorflow/core/ir/dialect.td b/tensorflow/core/ir/dialect.td index d80ecfa70cc19a..f3530af1232a16 100644 --- a/tensorflow/core/ir/dialect.td +++ b/tensorflow/core/ir/dialect.td @@ -172,7 +172,6 @@ def TFGraphDialect : Dialect { let useDefaultAttributePrinterParser = 1; let hasNonDefaultDestructor = 1; let hasOperationInterfaceFallback = 1; - let usePropertiesForAttributes = 0; } #endif // TFG_DIALECT diff --git a/tensorflow/core/ir/types/dialect.h b/tensorflow/core/ir/types/dialect.h index 419d396cb7aaf0..e2dc8bef70a5d1 100644 --- a/tensorflow/core/ir/types/dialect.h +++ b/tensorflow/core/ir/types/dialect.h @@ -111,12 +111,13 @@ class TensorFlowRefType : public TensorFlowType { // Define a class for each individual TensorFlow type (dtype), see types.def // for the list. 
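+// The `name` constant added below stringizes the macro's third argument, so
+// each generated type class carries a compile-time identifier, mirroring the
+// explicit `name` members added to ResourceType and VariantType further down.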
-#define HANDLE_TF_TYPE(tftype, enumerant, name) \ +#define HANDLE_TF_TYPE(tftype, enumerant, name_marg) \ class tftype##Type : public detail::TensorFlowTypeImpl { \ public: \ using TFBase::TFBase; \ + static constexpr StringLiteral name = #name_marg; \ }; -#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) +#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name_marg) #include "tensorflow/core/ir/types/types.def" namespace detail { @@ -222,6 +223,7 @@ inline Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type) { class ResourceType : public detail::TypeWithSubtypeImpl { public: using TFBase::TFBase; + static constexpr ::mlir::StringLiteral name = "tf_type.resource"; static std::string getTypeName() { return "ResourceType"; } }; @@ -233,6 +235,7 @@ class ResourceType : public detail::TypeWithSubtypeImpl { class VariantType : public detail::TypeWithSubtypeImpl { public: using TFBase::TFBase; + static constexpr ::mlir::StringLiteral name = "tf_type.variant"; static std::string getTypeName() { return "VariantType"; } }; diff --git a/tensorflow/core/ir/types/dialect.td b/tensorflow/core/ir/types/dialect.td index e6a13969b6843b..417b870977782a 100644 --- a/tensorflow/core/ir/types/dialect.td +++ b/tensorflow/core/ir/types/dialect.td @@ -47,7 +47,6 @@ def TFTypeDialect : Dialect { void printType(::mlir::Type type, ::mlir::DialectAsmPrinter &printer) const; }]; let useDefaultAttributePrinterParser = 1; - let usePropertiesForAttributes = 0; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 504620a451c331..025a63d92354c2 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -52,10 +52,8 @@ package( default_visibility = ["//visibility:public"], features = if_google( [ - "-layering_check", "-parse_headers", ], - ["-layering_check"], ), licenses = ["notice"], ) @@ -183,6 +181,7 @@ tf_kernel_library( "collective_nccl_reducer.h", "collective_nccl_reducer.cc", ]), + features = ["-layering_check"], prefix = "collective_ops", deps = [ "//tensorflow/core:core_cpu", @@ -202,6 +201,7 @@ tf_cuda_cc_test( name = "collective_nccl_test", size = "small", srcs = ["collective_nccl_test.cc"], + features = if_cuda(["-layering_check"]), tags = tf_cuda_tests_tags() + [ "guitar", "multi_gpu", @@ -265,6 +265,7 @@ cc_library( tf_kernel_library( name = "conv_2d", hdrs = ["conv_2d.h"], + features = if_cuda(["-layering_check"]), gpu_copts = if_not_windows([ "-Wno-pass-failed", # clang misses #pragma loop optimizations ]), @@ -317,6 +318,7 @@ cc_library( tf_kernel_library( name = "fill_functor", + features = ["-layering_check"], prefix = "fill_functor", deps = [ "//tensorflow/core:framework", @@ -372,6 +374,7 @@ cc_library( "sparse_utils.cc", ], hdrs = ["sparse_utils.h"], + features = ["-layering_check"], deps = [ "//tensorflow/core:framework", "//tensorflow/core:framework_lite", @@ -382,6 +385,7 @@ cc_library( tf_cc_test( name = "sparse_utils_test", srcs = ["sparse_utils_test.cc"], + features = ["-layering_check"], deps = [ ":sparse_utils", "//tensorflow/core:framework", @@ -490,6 +494,7 @@ cc_library( tf_cuda_only_cc_test( name = "gpu_prim_helpers_test", srcs = ["gpu_prim_helpers_test.cu.cc"], + features = if_cuda(["-layering_check"]), tags = ["no_cuda_asan"], # TODO(b/183963619) deps = [ ":gpu_prim_helpers", @@ -538,6 +543,7 @@ tf_cuda_library( name = "gpu_utils", srcs = if_cuda_or_rocm(["gpu_utils.cc"]), hdrs = ["gpu_utils.h"], + features = ["-layering_check"], deps = [ 
":gpu_util_hdrs", "//tensorflow/core:lib", @@ -626,6 +632,7 @@ cc_library( name = "queue_base", srcs = ["queue_base.cc"], hdrs = ["queue_base.h"], + features = ["-layering_check"], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -637,6 +644,7 @@ cc_library( name = "queue_op", srcs = ["queue_op.cc"], hdrs = ["queue_op.h"], + features = ["-layering_check"], deps = [ ":queue_base", "//tensorflow/core:framework", @@ -877,6 +885,7 @@ tf_kernel_library( tf_kernel_library( name = "debug_ops", + features = ["-layering_check"], prefix = "debug_ops", deps = ARRAY_DEPS + [ "//tensorflow/core:gpu_runtime", @@ -948,6 +957,7 @@ tf_kernel_library( tf_kernel_library( name = "concat_op", + features = ["-layering_check"], prefix = "concat_op", deps = ARRAY_DEPS, ) @@ -1072,6 +1082,7 @@ tf_kernel_library( tf_kernel_library( name = "reshape_op", + features = ["-layering_check"], prefix = "reshape_op", deps = ARRAY_DEPS, ) @@ -1090,6 +1101,7 @@ tf_kernel_library( tf_kernel_library( name = "shape_ops", + features = ["-layering_check"], prefix = "shape_ops", deps = ARRAY_DEPS + ["//tensorflow/core/common_runtime:dma_helper"], ) @@ -1108,6 +1120,7 @@ tf_kernel_library( tf_kernel_library( name = "split_op", + features = ["-layering_check"], gpu_srcs = ["gpu_device_array.h"], prefix = "split_op", deps = ARRAY_DEPS + [ @@ -1118,6 +1131,7 @@ tf_kernel_library( tf_kernel_library( name = "split_v_op", + features = ["-layering_check"], gpu_srcs = ["gpu_device_array.h"], prefix = "split_v_op", deps = ARRAY_DEPS + [ @@ -1190,6 +1204,7 @@ tf_kernel_library( tf_kernel_library( name = "unique_op", + features = if_cuda(["-layering_check"]), prefix = "unique_op", deps = ARRAY_DEPS + [ "@com_google_absl//absl/container:flat_hash_map", @@ -1219,6 +1234,7 @@ tf_kernel_library( name = "where_op", srcs = ["where_op.cc"], hdrs = ["where_op.h"], + features = ["-layering_check"], gpu_srcs = [ "where_op.h", "where_op_gpu.cu.h", @@ -1290,6 +1306,17 @@ cc_library( ], ) +cc_library( + name = "ragged_utils", + hdrs = [ + "ragged_utils.h", + ], + deps = [ + "//tensorflow/core:framework", + "@com_google_absl//absl/status", + ], +) + tf_kernel_library( name = "ragged_gather_op", srcs = ["ragged_gather_op.cc"], @@ -1315,6 +1342,7 @@ tf_cc_test( tf_kernel_library( name = "ragged_range_op", srcs = ["ragged_range_op.cc"], + features = ["-layering_check"], deps = [ "//tensorflow/core:framework", ], @@ -1323,6 +1351,7 @@ tf_kernel_library( tf_cc_test( name = "ragged_range_op_test", srcs = ["ragged_range_op_test.cc"], + features = ["-layering_check"], deps = [ ":ops_testutil", ":ragged_range_op", @@ -1336,6 +1365,7 @@ tf_cc_test( tf_kernel_library( name = "ragged_tensor_to_sparse_kernel", srcs = ["ragged_tensor_to_sparse_kernel.cc"], + features = ["-layering_check"], deps = [ "//tensorflow/core:framework", ], @@ -1389,6 +1419,7 @@ cc_library( name = "ragged_tensor_variant", srcs = ["ragged_tensor_variant.cc"], hdrs = ["ragged_tensor_variant.h"], + features = ["-layering_check"], deps = [ ":cwise_op", "//tensorflow/core:framework", @@ -1401,6 +1432,7 @@ tf_kernel_library( deps = [ ":concat_lib", ":ragged_tensor_variant", + ":ragged_utils", "//tensorflow/core:framework", "//tensorflow/core:lib", ], @@ -1423,6 +1455,7 @@ tf_cc_test( "ragged_tensor_to_variant_op_test.cc", "ragged_tensor_to_variant_op_test.h", ], + features = ["-layering_check"], deps = [ ":ops_testutil", ":ragged_tensor_to_variant_op", @@ -1444,6 +1477,7 @@ tf_cc_test( "ragged_tensor_to_variant_op_large_data_test.cc", "ragged_tensor_to_variant_op_test.h", ], + 
features = ["-layering_check"], tags = [ "local", "manual", @@ -1482,7 +1516,9 @@ tf_cc_test( tf_kernel_library( name = "ragged_cross_op", srcs = ["ragged_cross_op.cc"], + features = ["-layering_check"], deps = [ + ":ragged_utils", "//tensorflow/core:framework", "//tensorflow/core:lib", ], @@ -1508,6 +1544,7 @@ tf_cc_test( name = "ragged_fill_empty_rows_op_test", size = "small", srcs = ["ragged_fill_empty_rows_op_test.cc"], + features = ["-layering_check"], deps = [ ":ops_testutil", ":ragged_fill_empty_rows_op", @@ -1570,10 +1607,10 @@ cc_library( testonly = 1, srcs = ["batch_kernel_test_util.cc"], hdrs = ["batch_kernel_test_util.h"], + features = ["-layering_check"], deps = [ ":batch_kernels", ":ops_testutil", - ":ops_util", "//tensorflow/core:test", "//tensorflow/core:testlib", ], @@ -1583,6 +1620,7 @@ tf_cc_test( name = "batch_kernels_test", size = "small", srcs = ["batch_kernels_test.cc"], + features = ["-layering_check"], deps = [ ":batch_kernel_test_util", ":batch_kernels", @@ -1606,8 +1644,10 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/platform:status", "//tensorflow/core/platform:status_matchers", "//tensorflow/core/protobuf:error_codes_proto_impl_cc", + "@com_google_googletest//:gtest", ], ) @@ -1910,6 +1950,7 @@ tf_cuda_cc_test( name = "fused_batch_norm_ex_op_test", size = "small", srcs = ["fused_batch_norm_ex_op_test.cc"], + features = if_cuda(["-layering_check"]), tags = ["no_cuda_on_cpu_tap"], deps = [ ":cwise_op", @@ -1957,6 +1998,7 @@ tf_cc_test( tf_kernel_library( name = "gather_functor", + features = ["-layering_check"], prefix = "gather_functor", visibility = [":friends"], deps = [ @@ -2357,6 +2399,7 @@ tf_kernel_library( tf_cc_test( name = "while_op_test", srcs = ["while_op_test.cc"], + features = ["-layering_check"], tags = [ "no_windows", ], # TODO(b/208697533): Re-enable after fixing. 
@@ -2411,8 +2454,8 @@ tf_cc_test( ":control_flow_ops", ":ops_testutil", ":ops_util", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -2611,6 +2654,7 @@ DYNAMIC_DEPS = [ tf_kernel_library( name = "dynamic_partition_op", + features = if_cuda(["-layering_check"]), prefix = "dynamic_partition_op", deps = DYNAMIC_DEPS + [ ":fill_functor", @@ -2634,6 +2678,7 @@ cc_library( name = "tensor_cord", srcs = ["tensor_cord.cc"], hdrs = ["tensor_cord.h"], + features = ["-layering_check"], deps = [ "//tensorflow/core:framework", "@com_google_absl//absl/strings", @@ -2806,6 +2851,7 @@ tf_kernel_library( name = "tensor_array", srcs = ["tensor_array.cc"], hdrs = ["tensor_array.h"], + features = ["-layering_check"], visibility = ["//visibility:private"], deps = [ ":aggregate_ops", @@ -2819,6 +2865,7 @@ tf_kernel_library( name = "resource_variable_ops", srcs = ["resource_variable_ops.cc"], hdrs = ["resource_variable_ops.h"], + features = ["-layering_check"], deps = [ ":dense_update_functor", ":gather_functor", @@ -2874,6 +2921,7 @@ tf_kernel_library( name = "list_kernels", srcs = ["list_kernels.cc"], hdrs = ["list_kernels.h"], + features = ["-layering_check"], gpu_srcs = [ "list_kernels.cu.cc", "list_kernels.h", @@ -2894,6 +2942,7 @@ cc_library( name = "tensor_map", srcs = ["tensor_map.cc"], hdrs = ["tensor_map.h"], + features = ["-layering_check"], deps = [ "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -2947,6 +2996,7 @@ tf_kernel_library( tf_kernel_library( name = "function_ops", + features = ["-layering_check"], prefix = "function_ops", deps = [ "//tensorflow/core:core_cpu", @@ -3024,6 +3074,7 @@ tf_cc_test( ":no_mkldnn_contraction_kernel": [], "//conditions:default": ["eigen_mkldnn_contraction_kernel_test.cc"], }), + features = ["-layering_check"], tags = ["mkldnn_contraction_kernel"], deps = [ "//tensorflow/core:test", @@ -3167,6 +3218,7 @@ tf_kernel_library( hdrs = [ "checkpoint_callback_manager.h", ], + features = ["-layering_check"], deps = [ "//tensorflow/core:framework", "//tensorflow/core/platform:regexp", @@ -3178,6 +3230,7 @@ tf_cc_tests( name = "checkpoint_callback_manager_test", size = "small", srcs = ["checkpoint_callback_manager_test.cc"], + features = ["-layering_check"], deps = [ ":checkpoint_callback_manager", ":io", @@ -3368,6 +3421,7 @@ tf_cc_test( name = "resource_ops_test", size = "small", srcs = ["resource_ops_test.cc"], + features = ["-layering_check"], deps = [ ":dense_update_functor", ":ops_testutil", @@ -3382,6 +3436,7 @@ tf_cc_test( name = "lookup_ops_test", size = "small", srcs = ["lookup_ops_test.cc"], + features = ["-layering_check"], deps = [ ":lookup_table_op", ":ops_testutil", @@ -3465,6 +3520,7 @@ tf_kernel_library( tf_kernel_library( name = "matmul_op", + features = ["-layering_check"], prefix = "matmul_op", textual_hdrs = ["matmul_op_impl.h"], deps = MATH_DEPS + [ @@ -3495,6 +3551,7 @@ cc_library( name = "matmul_util", srcs = ["matmul_util.cc"], hdrs = ["matmul_util.h"], + features = ["-layering_check"], local_defines = if_cuda(["GOOGLE_CUDA=1"]) + if_rocm(["TENSORFLOW_USE_ROCM=1"]), deps = if_cuda_or_rocm([ "@com_google_absl//absl/container:flat_hash_map", @@ -3520,6 +3577,7 @@ tf_kernel_library( tf_kernel_library( name = "bucketize_op", + features = if_cuda(["-layering_check"]), gpu_srcs = ["gpu_device_array.h"], prefix = "bucketize_op", deps = ARRAY_DEPS, @@ -3530,6 +3588,7 @@ tf_kernel_library( 
copts = if_mlir_generated_gpu_kernels_enabled( ["-DMLIR_GENERATED_GPU_KERNELS_ENABLED"], ), + features = ["-layering_check"], # *.cu.cc sources are compiled with gpu_copts instead of copts. gpu_copts = if_mlir_generated_gpu_kernels_enabled( ["-DMLIR_GENERATED_GPU_KERNELS_ENABLED"], @@ -3542,6 +3601,7 @@ tf_kernel_library( tf_kernel_library( name = "check_numerics_op", + features = ["-layering_check"], prefix = "check_numerics_op", deps = MATH_DEPS + ["//tensorflow/core:framework_internal"], ) @@ -3576,6 +3636,7 @@ tf_kernel_library( copts = if_mlir_generated_gpu_kernels_enabled( ["-DMLIR_GENERATED_GPU_KERNELS_ENABLED"], ), + features = if_cuda(["-layering_check"]), # *.cu.cc sources are compiled with gpu_copts instead of copts. gpu_copts = if_mlir_generated_gpu_kernels_enabled( ["-DMLIR_GENERATED_GPU_KERNELS_ENABLED"], @@ -3595,6 +3656,7 @@ tf_kernel_library( tf_kernel_library( name = "fft_ops", + features = ["-layering_check"], prefix = "fft_ops", deps = MATH_DEPS + if_cuda([ "@com_google_absl//absl/container:flat_hash_map", @@ -3605,6 +3667,7 @@ tf_kernel_library( tf_kernel_library( name = "reduction_ops", + features = if_cuda(["-layering_check"]), gpu_srcs = ["reduction_gpu_kernels.cu.h"], prefix = "reduction_ops", deps = MATH_DEPS + [ @@ -3615,6 +3678,7 @@ tf_kernel_library( tf_kernel_library( name = "segment_reduction_ops", + features = ["-layering_check"], prefix = "segment_reduction_ops", deps = MATH_DEPS + [ "//tensorflow/core/util:determinism_for_kernels", @@ -3631,6 +3695,7 @@ tf_kernel_library( name = "scan_ops", srcs = ["scan_ops.cc"], hdrs = ["scan_ops.h"], + features = if_cuda(["-layering_check"]), gpu_srcs = [ "scan_ops.h", "scan_ops_gpu.h", @@ -3999,6 +4064,7 @@ tf_kernel_library( defines = [ "EIGEN_NEON_GEBP_NR=4", ], + features = ["-layering_check"], prefix = "conv_ops", textual_hdrs = [ "autotune_conv_impl.h", @@ -4067,6 +4133,7 @@ tf_kernel_library( name = "depthwise_conv_op", srcs = ["depthwise_conv_op.cc"], hdrs = ["depthwise_conv_op.h"], + features = ["-layering_check"], gpu_copts = if_not_windows([ "-Wno-pass-failed", # clang misses #pragma loop optimizations ]), @@ -4102,6 +4169,7 @@ tf_kernel_library( hdrs = [ "depthwise_conv_op.h", ], + features = ["-layering_check"], prefix = "depthwise_conv_grad_op", deps = [ ":cast_op", @@ -4174,6 +4242,7 @@ tf_kernel_library( tf_kernel_library( name = "bias_op", + features = ["-layering_check"], prefix = "bias_op", deps = NN_DEPS + [ ":loose_headers", @@ -4194,6 +4263,7 @@ tf_kernel_library( tf_kernel_library( name = "fused_batch_norm_op", + features = ["-layering_check"], prefix = "fused_batch_norm_op", deps = NN_DEPS + [ ":cast_op", @@ -4208,12 +4278,14 @@ tf_kernel_library( tf_kernel_library( name = "in_topk_op", + features = if_cuda(["-layering_check"]), prefix = "in_topk_op", deps = NN_DEPS + [":reduction_ops"], ) tf_kernel_library( name = "lrn_op", + features = ["-layering_check"], prefix = "lrn_op", deps = NN_DEPS + if_rocm([":conv_ops_gpu_hdrs"]) + [":loose_headers"], ) @@ -4225,6 +4297,7 @@ tf_kernel_library( ) + if_mlir_generated_gpu_kernels_enabled( ["-DMLIR_GENERATED_GPU_KERNELS_ENABLED"], ), + features = if_cuda(["-layering_check"]), # *.cu.cc sources are compiled with gpu_copts instead of copts. 
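+    # if_cuda(["-layering_check"]) disables the check only when building with
+    # CUDA support; CPU-only builds keep strict layering for this target.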
gpu_copts = if_mlir_generated_experimental_kernels_enabled( ["-DMLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED"], @@ -4345,6 +4418,7 @@ tf_kernel_library( tf_kernel_library( name = "l2loss_op", + features = if_cuda(["-layering_check"]), prefix = "l2loss_op", deps = [ ":gpu_prim_hdrs", @@ -4440,6 +4514,7 @@ tf_kernel_library( "pooling_ops_3d.h", "pooling_ops_common.h", ], + features = ["-layering_check"], gpu_srcs = [ "avgpooling_op.h", "avgpooling_op_gpu.cu.cc", @@ -4686,6 +4761,7 @@ cc_library( tf_kernel_library( name = "random_op", + features = ["-layering_check"], prefix = "random_op", deps = RANDOM_OPS_DEPS, ) @@ -4771,6 +4847,7 @@ cc_library( tf_kernel_library( name = "stateful_random_ops", + features = if_cuda(["-layering_check"]), prefix = "stateful_random_ops", deps = [ ":dense_update_functor", @@ -4797,6 +4874,7 @@ tf_kernel_library( tf_kernel_library( name = "stateless_random_gamma_op", + features = ["-layering_check"], prefix = "stateless_random_gamma_op", deps = [ ":stateless_random_ops", @@ -4840,6 +4918,7 @@ tf_cc_test( tf_kernel_library( name = "random_index_shuffle_ops", + features = ["-layering_check"], prefix = "random_index_shuffle_ops", deps = [ ":random_index_shuffle", @@ -4940,6 +5019,7 @@ tf_kernel_library( tf_kernel_library( name = "sparse_concat_op", + features = ["-layering_check"], prefix = "sparse_concat_op", deps = SPARSE_DEPS + if_cuda_or_rocm([ ":gpu_device_array", @@ -4964,6 +5044,7 @@ tf_kernel_library( tf_kernel_library( name = "fill_empty_rows_functor", + features = if_cuda(["-layering_check"]), prefix = "fill_empty_rows_functor", deps = [ "//tensorflow/core:framework", @@ -4979,6 +5060,7 @@ tf_kernel_library( tf_kernel_library( name = "sparse_cross_op", + features = ["-layering_check"], prefix = "sparse_cross_op", deps = SPARSE_DEPS + [ "@eigen_archive//:eigen3", @@ -5011,6 +5093,7 @@ tf_kernel_library( tf_kernel_library( name = "sparse_reorder_op", + features = if_cuda(["-layering_check"]), prefix = "sparse_reorder_op", deps = SPARSE_DEPS + if_cuda_or_rocm([ ":gpu_prim_hdrs", @@ -5028,6 +5111,7 @@ tf_kernel_library( tf_kernel_library( name = "sparse_slice_grad_op", + features = if_cuda(["-layering_check"]), prefix = "sparse_slice_grad_op", deps = SPARSE_DEPS + if_cuda_or_rocm([ ":gpu_prim_hdrs", @@ -5036,6 +5120,7 @@ tf_kernel_library( tf_kernel_library( name = "sparse_slice_op", + features = if_cuda(["-layering_check"]), prefix = "sparse_slice_op", deps = SPARSE_DEPS + if_cuda_or_rocm([ ":gpu_device_array", @@ -5057,6 +5142,7 @@ tf_kernel_library( tf_kernel_library( name = "sparse_split_op", + features = if_cuda(["-layering_check"]), prefix = "sparse_split_op", deps = SPARSE_DEPS + if_cuda_or_rocm([ ":gpu_device_array", @@ -5090,6 +5176,7 @@ tf_kernel_library( tf_kernel_library( name = "sparse_to_dense_op", + features = ["-layering_check"], prefix = "sparse_to_dense_op", deps = SPARSE_DEPS + [ ":loose_headers", @@ -5102,6 +5189,7 @@ tf_kernel_library( tf_kernel_library( name = "sparse_xent_op", + features = if_cuda(["-layering_check"]), gpu_copts = tf_disable_ptxas_warning_flags(), prefix = "sparse_xent_op", deps = SPARSE_DEPS + [ @@ -5147,6 +5235,7 @@ tf_kernel_library( tf_kernel_library( name = "sparse_tensors_map_ops", + features = ["-layering_check"], prefix = "sparse_tensors_map_ops", deps = SPARSE_DEPS, ) @@ -5323,6 +5412,7 @@ cc_library( name = "scatter_nd_util", srcs = ["scatter_nd_util.cc"], hdrs = ["scatter_nd_util.h"], + features = ["-layering_check"], deps = [ "//tensorflow/core:framework", ], @@ -5344,6 +5434,7 @@ tf_kernel_library( hdrs = [ 
"scatter_nd_op.h", ], + features = ["-layering_check"], gpu_copts = if_not_windows([ "-Wno-pass-failed", # clang misses #pragma loop optimizations ]), @@ -5365,6 +5456,7 @@ tf_kernel_library( tf_kernel_library( name = "variable_ops", + features = ["-layering_check"], prefix = "variable_ops", deps = STATE_DEPS, ) @@ -5474,7 +5566,7 @@ tf_kernel_library( name = "tensor_to_hash_bucket_op", prefix = "tensor_to_hash_bucket_op", deps = STRING_DEPS + if_oss( - if_cuda(["@farmhash_gpu_archive//:farmhash_gpu"]), + if_cuda_or_rocm(["@farmhash_gpu_archive//:farmhash_gpu"]), tf_fingerprint_deps(), ), ) @@ -5656,6 +5748,7 @@ tf_cc_test( tf_kernel_library( name = "as_string_op", + features = ["-layering_check"], prefix = "as_string_op", deps = STRING_DEPS, ) @@ -5730,6 +5823,7 @@ tf_cc_test( tf_kernel_library( name = "multinomial_op", + features = if_cuda(["-layering_check"]), prefix = "multinomial_op", deps = [ ":gpu_prim_hdrs", @@ -5763,6 +5857,7 @@ tf_cuda_cc_test( tf_kernel_library( name = "parameterized_truncated_normal_op", + features = if_cuda(["-layering_check"]), gpu_copts = if_not_windows([ "-Wno-pass-failed", # clang misses #pragma loop optimizations ]), @@ -6011,6 +6106,7 @@ tf_cuda_cc_test( name = "spectrogram_op_test", size = "small", srcs = ["spectrogram_op_test.cc"], + features = ["-layering_check"], deps = [ ":ops_util", ":spectrogram_op", @@ -6381,6 +6477,7 @@ filegroup( "fill_empty_rows_functor.h", "function_ops.h", "fused_batch_norm_op.h", + "gpu_utils.h", "inplace_ops.cc", "inplace_ops_functor.h", "l2loss_op.h", @@ -6397,6 +6494,7 @@ filegroup( "partitioned_function_ops.h", "pooling_ops_3d.h", "ragged_tensor_variant.h", + "ragged_utils.h", "random_index_shuffle.h", "random_op.h", "random_poisson_op.h", @@ -6978,6 +7076,7 @@ cc_library( ]), copts = tf_copts() + tf_opts_nortti_if_lite_protos(), defines = ["EIGEN_NEON_GEBP_NR=4"], + features = ["-layering_check"], linkopts = if_android(["-ldl"]), tags = [ "manual", @@ -7071,6 +7170,7 @@ tf_kernel_library( "reshape_op.h", ], hdrs = ["reference_gemm.h"], + features = ["-layering_check"], deps = [ ":concat_lib_hdrs", ":conv_ops", @@ -7163,6 +7263,7 @@ tf_cc_binary( testonly = 1, srcs = ["quantization_utils_test.cc"], copts = tf_copts(), + features = ["-layering_check"], linkopts = select({ "//tensorflow:android": [ "-lm", @@ -7222,6 +7323,7 @@ cc_binary( testonly = 1, srcs = ["quantized_add_op_test.cc"], copts = tf_copts(), + features = ["-layering_check"], linkopts = select({ "//tensorflow:android": [ "-lm", @@ -7307,6 +7409,7 @@ cc_binary( testonly = 1, srcs = ["quantized_resize_bilinear_op_test.cc"], copts = tf_copts(), + features = ["-layering_check"], linkopts = select({ "//tensorflow:android": [ "-lm", @@ -7428,6 +7531,7 @@ cc_binary( name = "quantized_mul_op_test_android_only", testonly = 1, srcs = ["quantized_mul_op_test.cc"], + features = ["-layering_check"], linkopts = select({ "//tensorflow:android": [ "-pie", @@ -7622,6 +7726,7 @@ cc_library( name = "quantization_utils", srcs = ["quantization_utils.cc"], hdrs = ["quantization_utils.h"], + features = ["-layering_check"], deps = [ "//tensorflow/core:framework", "@gemmlowp", @@ -7734,6 +7839,7 @@ tf_kernel_library( tf_kernel_library( name = "sync_ops", + features = ["-layering_check"], prefix = "sync_ops", deps = [ "//tensorflow/core:framework", @@ -7816,6 +7922,7 @@ cc_library( tf_kernel_library( name = "stochastic_cast_op", + features = ["-layering_check"], prefix = "stochastic_cast_op", deps = [ ":stateless_random_ops_v2_util", @@ -7829,6 +7936,7 @@ tf_cc_test( name = 
"stochastic_cast_op_test", timeout = "moderate", srcs = ["stochastic_cast_op_test.cc"], + features = ["-layering_check"], shard_count = 48, deps = [ ":cwise_lib", diff --git a/tensorflow/core/kernels/batch_kernel_test_util.cc b/tensorflow/core/kernels/batch_kernel_test_util.cc index e7d35ec2e4779c..bda3c25b182973 100644 --- a/tensorflow/core/kernels/batch_kernel_test_util.cc +++ b/tensorflow/core/kernels/batch_kernel_test_util.cc @@ -15,51 +15,53 @@ limitations under the License. #include "tensorflow/core/kernels/batch_kernel_test_util.h" +#include + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/kernels/batch_kernels.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/status.h" + namespace tensorflow { -namespace internal { +namespace test_util { BatchFunctionKernelTestAccess::BatchFunctionKernelTestAccess( - BatchFunctionKernel* kernel) + const BatchFunctionKernel* kernel) : kernel_(kernel) {} bool BatchFunctionKernelTestAccess::enable_adaptive_batch_threads() const { return kernel_->enable_adaptive_batch_threads_; } -} // namespace internal - -bool BatchFunctionKernelTestBase::enable_adaptive_scheduler() const { - return GetParam(); -} - -Status BatchFunctionKernelTestBase::Init() { +Status BatchFunctionKernelTestBase::Init(bool enable_adaptive_scheduler) { std::vector input_dtypes({DataType::DT_INT64, DataType::DT_INT64}); std::vector inputs( {NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64}), NodeDefBuilder::NodeOut({"n2", 1, DataType::DT_INT64})}); NameAttrList f; f.set_name("func_to_batch"); - TF_CHECK_OK( - NodeDefBuilder("BatchTPUInput", "BatchFunction") - .Attr("max_batch_size", 32) - .Attr("num_batch_threads", enable_adaptive_scheduler() ? 0 : 8) - .Attr("allowed_batch_sizes", {2, 4, 8}) - .Attr("batch_timeout_micros", 1000) - .Attr("max_enqueued_batches", 100) - .Attr("enable_large_batch_splitting", true) - .Attr("low_priority_max_batch_size", 64) - .Attr("low_priority_batch_timeout_micros", 8000) - .Attr("low_priority_allowed_batch_sizes", {32, 64}) - .Attr("low_priority_max_enqueued_batches", 1000) - .Attr("Tcaptured", std::vector{DataType::DT_INT64}) - .Attr("Tin", input_dtypes) - .Input(inputs) - .Attr("Tcaptured", std::vector{DataType::DT_INT64}) - .Input(std::vector{ - NodeDefBuilder::NodeOut({"n3", 1, DataType::DT_INT64})}) - .Attr("Tout", std::vector(4, DataType::DT_INT64)) - .Attr("f", f) - .Finalize(node_def())); + TF_CHECK_OK(NodeDefBuilder("BatchTPUInput", "BatchFunction") + .Attr("max_batch_size", 32) + .Attr("num_batch_threads", enable_adaptive_scheduler ? 
0 : 8) + .Attr("allowed_batch_sizes", {2, 4, 8}) + .Attr("batch_timeout_micros", 1000) + .Attr("max_enqueued_batches", 100) + .Attr("enable_large_batch_splitting", true) + .Attr("low_priority_max_batch_size", 64) + .Attr("low_priority_batch_timeout_micros", 8000) + .Attr("low_priority_allowed_batch_sizes", {32, 64}) + .Attr("low_priority_max_enqueued_batches", 1000) + .Attr("Tcaptured", std::vector{DataType::DT_INT64}) + .Attr("Tin", input_dtypes) + .Input(inputs) + .Attr("Tcaptured", std::vector{DataType::DT_INT64}) + .Input(std::vector{ + NodeDefBuilder::NodeOut({"n3", 1, DataType::DT_INT64})}) + .Attr("Tout", std::vector(4, DataType::DT_INT64)) + .Attr("f", f) + .Finalize(node_def())); return InitOp(); } + +} // namespace test_util } // namespace tensorflow diff --git a/tensorflow/core/kernels/batch_kernel_test_util.h b/tensorflow/core/kernels/batch_kernel_test_util.h index e26f6c5d78914c..e6b37e635ac0bc 100644 --- a/tensorflow/core/kernels/batch_kernel_test_util.h +++ b/tensorflow/core/kernels/batch_kernel_test_util.h @@ -16,37 +16,33 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_BATCH_KERNEL_TEST_UTIL_H_ #define TENSORFLOW_CORE_KERNELS_BATCH_KERNEL_TEST_UTIL_H_ -#include "tensorflow/core/framework/node_def_builder.h" +#include #include "tensorflow/core/kernels/batch_kernels.h" #include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/kernels/ops_util.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { -namespace internal { +namespace test_util { + +// A test util for accessing private members of `BatchFunctionKernel`. class BatchFunctionKernelTestAccess { public: - explicit BatchFunctionKernelTestAccess(BatchFunctionKernel* kernel); + explicit BatchFunctionKernelTestAccess(const BatchFunctionKernel* kernel); bool enable_adaptive_batch_threads() const; private: - BatchFunctionKernel* const kernel_; + const BatchFunctionKernel* const kernel_; }; -} // namespace internal - class BatchFunctionKernelTestBase : public OpsTestBase, public ::testing::WithParamInterface { public: - bool enable_adaptive_scheduler() const; - // Init test fixture with a batch kernel instance. 
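+  // `enable_adaptive_scheduler` replaces the old GetParam()-based accessor:
+  // when true the node is built with num_batch_threads = 0, which switches
+  // the kernel to adaptive batch threads (what the tests assert on).
+  // Typical use from a parameterized test (see batch_kernels_test.cc):
+  //   const bool adaptive = GetParam();
+  //   TF_EXPECT_OK(Init(adaptive));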
- Status Init(); + Status Init(bool enable_adaptive_scheduler); }; +} // namespace test_util } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_BATCH_KERNEL_TEST_UTIL_H_ diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index 1763fcd3c15088..8862e5b0c98f58 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -218,7 +218,7 @@ class BatchResource : public serving::BatchResourceBase { has_process_batch_function, std::move(batcher), GetAdaptiveBatcherQueueOptions( max_batch_size, batch_timeout_micros, max_enqueued_batches, - true /* enable large batch split */, allowed_batch_sizes, + /*enable_large_batch_splitting=*/true, allowed_batch_sizes, /*disable_padding=*/false), allowed_batch_sizes)); return OkStatus(); @@ -302,9 +302,6 @@ BatchFunctionKernel::BatchFunctionKernel(OpKernelConstruction* c) OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting", &enable_large_batch_splitting_)); has_attribute_enable_large_batch_splitting_ = true; - } else { - enable_large_batch_splitting_ = false; - has_attribute_enable_large_batch_splitting_ = false; } // Helper function `SetAdaptiveBatchSchedulerOptions` calls diff --git a/tensorflow/core/kernels/batch_kernels.h b/tensorflow/core/kernels/batch_kernels.h index 9ea1b195a050f6..1c9c35356e3d2f 100644 --- a/tensorflow/core/kernels/batch_kernels.h +++ b/tensorflow/core/kernels/batch_kernels.h @@ -34,9 +34,9 @@ ABSL_CONST_INIT extern const int64_t kInitialInflightBatches; ABSL_CONST_INIT extern const int64_t kBatchesToAverageOver; ABSL_CONST_INIT extern const int64_t kMaxInflightBatches; -namespace internal { +namespace test_util { class BatchFunctionKernelTestAccess; -} +} // namespace test_util // Records the usage of attribute `enable_large_batch_splitting`. void RecordBatchSplitUsage( @@ -71,7 +71,7 @@ class BatchFunctionKernel : public AsyncOpKernel { void ComputeAsync(OpKernelContext* c, DoneCallback done) final; private: - friend class internal::BatchFunctionKernelTestAccess; + friend class test_util::BatchFunctionKernelTestAccess; // Validates 'allowed_batch_sizes_'. The entries must increase monotonically. // If large batch split is not enabled, the last one must equal @@ -111,8 +111,8 @@ class BatchFunctionKernel : public AsyncOpKernel { std::vector low_priority_allowed_batch_sizes_; NameAttrList func_; absl::optional fhandle_ TF_GUARDED_BY(mu_); - bool enable_large_batch_splitting_; - bool has_attribute_enable_large_batch_splitting_; + bool enable_large_batch_splitting_ = false; + bool has_attribute_enable_large_batch_splitting_ = false; bool enable_adaptive_batch_threads_ = false; mutex mu_; diff --git a/tensorflow/core/kernels/batch_kernels_env_test.cc b/tensorflow/core/kernels/batch_kernels_env_test.cc index 8b2819c0a6be3f..508c0e8699763c 100644 --- a/tensorflow/core/kernels/batch_kernels_env_test.cc +++ b/tensorflow/core/kernels/batch_kernels_env_test.cc @@ -13,29 +13,37 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include "tensorflow/core/kernels/batch_kernel_test_util.h" -#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status_matchers.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tsl/lib/core/status_test_util.h" namespace tensorflow { +namespace { // Tests that batch kernel initialization returns error when it's configured to // use adaptive scheduling yet batching thread pool creation fails. -class BatchFunctionKernelEnvTest : public BatchFunctionKernelTestBase {}; +class BatchFunctionKernelEnvTest + : public test_util::BatchFunctionKernelTestBase {}; TEST_P(BatchFunctionKernelEnvTest, Basic) { tensorflow::setenv("TF_NUM_BATCH_THREADS", "0", 1 /* overwrite */); - if (enable_adaptive_scheduler()) { - EXPECT_THAT(Init(), tensorflow::testing::StatusIs( + + const bool adaptive_scheduler_enabled = GetParam(); + Status status = Init(adaptive_scheduler_enabled); + if (adaptive_scheduler_enabled) { + EXPECT_THAT(status, tensorflow::testing::StatusIs( error::FAILED_PRECONDITION, "Failed to create batch threads pool")); } else { // Initialization is ok since batch kernel doesn't use adaptive // scheduler. - TF_EXPECT_OK(Init()); + TF_EXPECT_OK(status); } } INSTANTIATE_TEST_SUITE_P(Params, BatchFunctionKernelEnvTest, ::testing::Bool()); + +} // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/batch_kernels_test.cc b/tensorflow/core/kernels/batch_kernels_test.cc index 7b7810780872a7..af7546a062169d 100644 --- a/tensorflow/core/kernels/batch_kernels_test.cc +++ b/tensorflow/core/kernels/batch_kernels_test.cc @@ -24,30 +24,38 @@ limitations under the License. 
#include "absl/strings/match.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/batch_kernel_test_util.h" #include "tensorflow/core/kernels/batching_util/warmup.h" -#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/blocking_counter.h" +#include "tsl/platform/errors.h" #include "tsl/platform/status.h" namespace tensorflow { +namespace { using PerModelData = serving::WarmupStateRegistry::PerModelData; -class BatchFunctionKernelTest : public BatchFunctionKernelTestBase {}; +class BatchFunctionKernelTest : public test_util::BatchFunctionKernelTestBase { +}; TEST_P(BatchFunctionKernelTest, EnableAdaptiveScheduler) { - TF_EXPECT_OK(Init()); + const bool adaptive_scheduler_enabled = GetParam(); + + TF_EXPECT_OK(Init(adaptive_scheduler_enabled)); + BatchFunctionKernel *batch_kernel = dynamic_cast(op_kernel()); - EXPECT_EQ(internal::BatchFunctionKernelTestAccess(batch_kernel) - .enable_adaptive_batch_threads(), - enable_adaptive_scheduler()); + EXPECT_EQ(adaptive_scheduler_enabled, + test_util::BatchFunctionKernelTestAccess(batch_kernel) + .enable_adaptive_batch_threads()); } INSTANTIATE_TEST_SUITE_P(Params, BatchFunctionKernelTest, ::testing::Bool()); @@ -55,51 +63,68 @@ INSTANTIATE_TEST_SUITE_P(Params, BatchFunctionKernelTest, ::testing::Bool()); class BatchFunctionKernelParallelWarmupTestState : public OpsTestBase { public: // Init test fixture with a batch kernel instance. - Status Init(bool enable_splitting, bool check_output_shape = true) { + Status Init(bool enable_splitting, bool check_output_shape) { static auto *const cpu_device = []() { auto device = DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); return device.release(); }(); - // Overriding the per-test/per-op device with a global device so that it can + // Override the per-test/per-op device with a global device so that it can // be shared between ops. 
device_ = cpu_device; - std::vector input_dtypes({DataType::DT_INT64}); - std::vector inputs( - {NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64})}); - NameAttrList f; - f.set_name("func_to_batch"); - tensorflow::FunctionDefHelper::Node node_info = { - {"output1"}, "Identity", {"input1"}, {{"T", DT_INT64}}}; + f.set_name("BatchFunctionKernelParallelWarmupTestStateFunc"); + FunctionDef func; if (check_output_shape) { - node_info = {{"output1"}, - "EnsureShape", - {"input1"}, - {{"T", DT_INT64}, {"shape", TensorShape({2})}}}; + func = FunctionDefHelper::Create( + // function_name + f.name(), + // in_def + {"x:int64"}, + // out_def + {"o:int64"}, + // attr_def + {}, + // node_def + {{{"o"}, + "EnsureShape", + {"x"}, + {{"T", DataType::DT_INT64}, {"shape", TensorShape({2})}}}}, + // ret_def + {{"o", "o:output"}}); + } else { + func = FunctionDefHelper::Create( + // function_name + f.name(), + // in_def + {"x:int64"}, + // out_def + {"o:int64"}, + // attr_def + {}, + // node_def + {{{"o"}, "Identity", {"x"}, {{"T", DataType::DT_INT64}}}}, + // ret_def + {{"o", "o:output"}}); } - TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(FunctionDefHelper::Define( - /*Function*/ "func_to_batch", - /*Inputs*/ {"input1:int64"}, - /*Outputs*/ {"output1:int64"}, - /*Attribute*/ {}, - // Node info - {node_info}))); + TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(func)); pflr_ = std::make_unique( device_mgr_.get(), Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(), /*thread_pool=*/nullptr, /*parent=*/nullptr, /*session_metadata=*/nullptr, - Rendezvous::Factory{[](const int64, const DeviceMgr *device_mgr, + Rendezvous::Factory{[](const int64_t, const DeviceMgr *device_mgr, tsl::core::RefCountPtr *r) { *r = tsl::core::RefCountPtr( new IntraProcessRendezvous(device_mgr)); return OkStatus(); }}); + std::vector inputs( + {NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64})}); TF_CHECK_OK(NodeDefBuilder("BatchTPUInput", "BatchFunction") .Attr("max_batch_size", enable_splitting ? 16 : 8) .Attr("num_batch_threads", 8) @@ -111,7 +136,7 @@ class BatchFunctionKernelParallelWarmupTestState : public OpsTestBase { .Attr("low_priority_batch_timeout_micros", 8000) .Attr("low_priority_allowed_batch_sizes", {32, 64}) .Attr("low_priority_max_enqueued_batches", 1000) - .Attr("Tin", input_dtypes) + .Attr("Tin", {DataType::DT_INT64}) .Input(inputs) .Attr("Tcaptured", std::vector{}) .Input(std::vector{}) @@ -150,7 +175,8 @@ TEST_P(BatchFunctionKernelParallelWarmupTest, ParallelWarmup) { Env::Default()->SchedClosure([&]() { BatchFunctionKernelParallelWarmupTestState test; test.set_session_metadata(session_metadata); - TF_CHECK_OK(test.Init(enable_splitting)); + TF_CHECK_OK(test.Init(enable_splitting, + /*check_output_shape=*/true)); test.AddInputFromList(TensorShape({2}), {123, 456}); TF_CHECK_OK(test.RunOpKernel()); @@ -171,7 +197,8 @@ TEST_P(BatchFunctionKernelParallelWarmupTest, ParallelWarmup) { Env::Default()->SchedClosure([&]() { BatchFunctionKernelParallelWarmupTestState test; test.set_session_metadata(session_metadata); - TF_CHECK_OK(test.Init(enable_splitting)); + TF_CHECK_OK(test.Init(enable_splitting, + /*check_output_shape=*/true)); test.AddInputFromList(TensorShape({2}), {123, 456}); // We expect requests to be batched together when the warm-up mode is // turned off, which will make the execution fail at `EnsureShape`. 
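The warmup test above now builds its batched function with FunctionDefHelper::Create: with check_output_shape set, the body is an EnsureShape node pinned to shape {2}, so a request only succeeds while it runs unbatched during warmup and fails once requests are concatenated into a larger batch. As a quick recap of how Create()'s arguments map onto a function signature, here is a minimal sketch that builds a toy Identity function and dumps the proto; it assumes a TensorFlow build, and the ToyIdentity name is made up.

#include <iostream>
#include "tensorflow/core/framework/function.h"

int main() {
  tensorflow::FunctionDef func = tensorflow::FunctionDefHelper::Create(
      /*function_name=*/"ToyIdentity",
      /*in_def=*/{"x:int64"},    // one int64 input named x
      /*out_def=*/{"o:int64"},   // one int64 output named o
      /*attr_def=*/{},           // no function attributes
      /*node_def=*/
      {{{"o"}, "Identity", {"x"}, {{"T", tensorflow::DT_INT64}}}},
      /*ret_def=*/{{"o", "o:output"}});  // map output o to node o's output
  std::cout << func.DebugString() << std::endl;
  return 0;
}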
@@ -205,7 +232,7 @@ TEST_P(BatchFunctionKernelParallelWarmupTest, ParallelWarmupAutoBatch) { Env::Default()->SchedClosure([&]() { BatchFunctionKernelParallelWarmupTestState test; test.set_session_metadata(session_metadata); - TF_CHECK_OK(test.Init(enable_splitting)); + TF_CHECK_OK(test.Init(enable_splitting, /*check_output_shape=*/true)); test.AddInputFromList(TensorShape({2}), {123, 456}); auto status = test.RunOpKernel(); ASSERT_FALSE(status.ok()); @@ -250,5 +277,5 @@ TEST_P(BatchFunctionKernelParallelWarmupTest, ParallelWarmupAutoBatch) { INSTANTIATE_TEST_SUITE_P(BatchFunctionKernelParallelWarmupTestSuite, BatchFunctionKernelParallelWarmupTest, ::testing::Bool()); - +} // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/batch_norm_op_test.cc b/tensorflow/core/kernels/batch_norm_op_test.cc index 45ddc853295557..7b96122b521ae5 100644 --- a/tensorflow/core/kernels/batch_norm_op_test.cc +++ b/tensorflow/core/kernels/batch_norm_op_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include + #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -29,60 +30,50 @@ limitations under the License. namespace tensorflow { -class BatchNormOpTest : public OpsTestBase {}; - -TEST_F(BatchNormOpTest, Simple) { - TF_EXPECT_OK( - NodeDefBuilder("batch_norm_op", "BatchNormWithGlobalNormalization") - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Attr("scale_after_normalization", false) - .Attr("variance_epsilon", 0.001) - .Finalize(node_def())); - TF_EXPECT_OK(InitOpWithGraphVersion(8)); - AddInputFromArray(TensorShape({1, 1, 6, 2}), - {1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6}); - AddInputFromArray(TensorShape({2}), {10, 20}); - AddInputFromArray(TensorShape({2}), {0.25f, 0.5f}); - AddInputFromArray(TensorShape({2}), {0.1f, 0.6f}); - AddInputFromArray(TensorShape({2}), {0.0f, 0.0f}); - TF_ASSERT_OK(RunOpKernel()); - - Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 6, 2})); - test::FillValues( - &expected, {-17.86f, -22.00f, -15.87f, -20.59f, -13.87f, -19.18f, -21.86f, - -33.31f, -23.85f, -34.72f, -25.85f, -36.13f}); - test::ExpectTensorNear(expected, *GetOutput(0), 0.01); -} - -TEST_F(BatchNormOpTest, Fp16) { - TF_EXPECT_OK( - NodeDefBuilder("batch_norm_op", "BatchNormWithGlobalNormalization") - .Input(FakeInput(DT_HALF)) - .Input(FakeInput(DT_HALF)) - .Input(FakeInput(DT_HALF)) - .Input(FakeInput(DT_HALF)) - .Input(FakeInput(DT_HALF)) - .Attr("scale_after_normalization", false) - .Attr("variance_epsilon", 0.001) - .Finalize(node_def())); - TF_EXPECT_OK(InitOpWithGraphVersion(8)); - AddInputFromList(TensorShape({1, 1, 6, 2}), - {1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6}); - AddInputFromList(TensorShape({2}), {10, 20}); - AddInputFromList(TensorShape({2}), {0.25, 0.5}); - AddInputFromList(TensorShape({2}), {0.1, 0.6}); - AddInputFromList(TensorShape({2}), {0.0, 0.0}); - TF_ASSERT_OK(RunOpKernel()); - - Tensor expected(allocator(), DT_HALF, TensorShape({1, 1, 6, 2})); - test::FillValues( - &expected, {-17.86, -22.00, -15.87, -20.59, -13.87, -19.18, -21.86, - -33.31, -23.85, -34.72, -25.85, -36.13}); - test::ExpectTensorNear(expected, *GetOutput(0), 0.1); -} +template +struct BatchNormOpTest : public OpsTestBase { + static constexpr auto TValueType = DataTypeToEnum::value; + + void 
run_me() { + TF_EXPECT_OK( + NodeDefBuilder("batch_norm_op", "BatchNormWithGlobalNormalization") + .Input(FakeInput(TValueType)) + .Input(FakeInput(TValueType)) + .Input(FakeInput(TValueType)) + .Input(FakeInput(TValueType)) + .Input(FakeInput(TValueType)) + .Attr("scale_after_normalization", false) + .Attr("variance_epsilon", 0.001) + .Finalize(node_def())); + TF_EXPECT_OK(InitOpWithGraphVersion(8)); + + AddInputFromList(TensorShape({1, 1, 6, 2}), + {1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6}); + AddInputFromList(TensorShape({2}), {10, 20}); + AddInputFromList(TensorShape({2}), {0.25, 0.5}); + AddInputFromList(TensorShape({2}), {0.1, 0.6}); + AddInputFromList(TensorShape({2}), {0.0, 0.0}); + + TF_ASSERT_OK(RunOpKernel()); + + double atol = TValueType == DT_FLOAT ? 0.01 : 0.1; + + Tensor expected(allocator(), TValueType, TensorShape({1, 1, 6, 2})); + test::FillValues(&expected, + {-17.86f, -22.00f, -15.87f, -20.59f, -13.87f, -19.18f, + -21.86f, -33.31f, -23.85f, -34.72f, -25.85f, -36.13f}); + test::ExpectTensorNear(expected, *GetOutput(0), atol); + } +}; + +TYPED_TEST_SUITE_P(BatchNormOpTest); + +TYPED_TEST_P(BatchNormOpTest, Simple) { this->run_me(); } + +REGISTER_TYPED_TEST_SUITE_P(BatchNormOpTest, Simple); + +// TODO(ezhulenev): Add support for more data types. +using DataTypes = ::testing::Types; +INSTANTIATE_TYPED_TEST_SUITE_P(Test, BatchNormOpTest, DataTypes); } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index c1395ed464252c..fa900a9c87789c 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -234,6 +234,23 @@ void RecordBatchParamAllowedBatchSizes(const string& allowed_batch_sizes, cell->GetCell(model_name, op_name)->Set(allowed_batch_sizes); } +void RecordBatchCosts(const std::string& model_name, + const int64_t processed_size, + const absl::string_view cost_type, + const absl::Duration total_cost) { + static auto* cell = tensorflow::monitoring::Sampler<3>::New( + {"/tensorflow/serving/batching/costs", + "Tracks the batch costs (in microseconds) by model name and processed " + "size.", + "model_name", "processed_size", "cost_type"}, + // It's 27 buckets with the last bucket being 2^26 to DBL_MAX; + // so the limits are [1, 2, 4, 8, ..., 64 * 1024 * 1024 (~64s), DBL_MAX]. + monitoring::Buckets::Exponential(1, 2, 27)); + cell->GetCell(model_name, std::to_string(processed_size), + std::string(cost_type)) + ->Add(absl::ToDoubleMicroseconds(total_cost)); +} + const string& GetModelName(OpKernelContext* ctx) { static string* kModelNameUnset = new string("model_name_unset"); if (!ctx->session_metadata()) return *kModelNameUnset; @@ -485,6 +502,7 @@ BatchResourceBase::GetBatcherQueueOptions( *allowed_batch_sizes.rbegin(); batcher_queue_options.high_priority_queue_options .max_execution_batch_size = *allowed_batch_sizes.rbegin(); + batcher_queue_options.allowed_batch_sizes = allowed_batch_sizes; } if (low_priority_allowed_batch_sizes.empty()) { batcher_queue_options.low_priority_queue_options @@ -827,6 +845,7 @@ void BatchResourceBase::ProcessFuncBatch(std::unique_ptr batch) const { auto& last_task = batch->task(batch->num_tasks() - 1); OpKernelContext* last_task_context = last_task.context; + const std::string& model_name = GetModelName(last_task_context); // Regardless of the outcome, we need to propagate the status to the // individual tasks and signal that they are done. 
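The new /tensorflow/serving/batching/costs sampler declared above uses monitoring::Buckets::Exponential(1, 2, 27), i.e. the limits 1, 2, 4, ..., 2^26 microseconds described in its comment, with everything larger landing in the final bucket. The sketch below is plain arithmetic to show which limit a sample falls under; it does not use the monitoring API, and the cost value is made up.

#include <cfloat>
#include <cmath>
#include <iostream>
#include <vector>

int main() {
  // Bucket limits as described in the comment above: 1, 2, 4, ..., 2^26
  // microseconds, with DBL_MAX closing the final overflow bucket.
  std::vector<double> limits;
  for (int i = 0; i <= 26; ++i) limits.push_back(std::ldexp(1.0, i));  // 2^i
  limits.push_back(DBL_MAX);

  // A made-up 250ms batch cost, expressed in microseconds.
  const double cost_us = 250000.0;

  // The sample belongs to the first bucket whose upper limit exceeds it.
  int bucket = 0;
  while (cost_us >= limits[bucket]) ++bucket;
  std::cout << "cost " << cost_us << " us is below limit " << limits[bucket]
            << " (bucket index " << bucket << ")\n";  // limit 262144 = 2^18
  return 0;
}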
We use MakeCleanup() to @@ -838,8 +857,8 @@ void BatchResourceBase::ProcessFuncBatch(std::unique_ptr batch) const { if (cleanup_done) { return; } - SplitBatchCostsAndRecordMetrics(batch_cost_measurements, processed_size, - *batch); + SplitBatchCostsAndRecordMetrics(model_name, batch_cost_measurements, + processed_size, *batch); // Clear the measurements before unblocking the batch task, as measurements // are associated with the task's thread context. batch_cost_measurements.clear(); @@ -878,7 +897,6 @@ void BatchResourceBase::ProcessFuncBatch(std::unique_ptr batch) const { args.insert(args.end(), captured_inputs.begin(), captured_inputs.end()); uint64 current_time = EnvTime::NowNanos(); - const string& model_name = GetModelName(last_task_context); for (int i = 0; i < batch->num_tasks(); ++i) { RecordBatchDelayUs((current_time - batch->task(i).start_time) * 1e-3, model_name, last_task_context->op_kernel().name(), @@ -930,15 +948,17 @@ void BatchResourceBase::ProcessBatch(std::unique_ptr batch) const { CreateCostMeasurements(batching_context); int64_t processed_size = batch->size(); - auto batch_cost_split_cleanup = gtl::MakeCleanup([&] { - SplitBatchCostsAndRecordMetrics(batch_cost_measurements, processed_size, - *batch); - }); OpKernelContext* last_task_context = batch->task(batch->num_tasks() - 1).context; AsyncOpKernel::DoneCallback last_task_callback = batch->task(batch->num_tasks() - 1).done_callback; + const std::string& model_name = GetModelName(last_task_context); + + auto batch_cost_cleanup = gtl::MakeCleanup([&] { + SplitBatchCostsAndRecordMetrics(model_name, batch_cost_measurements, + processed_size, *batch); + }); OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch), last_task_callback); @@ -1056,6 +1076,7 @@ Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name, } void BatchResourceBase::SplitBatchCostsAndRecordMetrics( + const std::string& model_name, const std::vector>& batch_cost_measurements, const int64_t processed_size, BatchT& batch) { @@ -1078,6 +1099,15 @@ void BatchResourceBase::SplitBatchCostsAndRecordMetrics( const absl::string_view cost_type = batch_cost_measurement->GetCostType(); const absl::Duration total_cost = batch_cost_measurement->GetTotalCost(); + // Smeared batch cost: cost for processing this batch. + RecordBatchCosts(model_name, processed_size, + absl::StrCat(cost_type, kWithSmearSuffix), total_cost); + // Non-smeared batch cost: cost for processing inputs in this batch, i.e. + // cost for processing paddings is excluded. + RecordBatchCosts(model_name, processed_size, + absl::StrCat(cost_type, kNoSmearSuffix), + total_cost / processed_size * batch.size()); + for (int i = 0; i < batch.num_tasks(); i++) { RequestCost* request_cost = batch.task(i).request_cost; // Skip recording the cost if the request_cost is null. diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h index 5124e9f031733a..b86d25c097da39 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.h +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -238,6 +239,7 @@ class BatchResourceBase : public ResourceBase { // 2) the input size from this task; // 3) the padding amount. 
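SplitBatchCostsAndRecordMetrics now reports each batch cost twice: once with smear (the full measured cost, padding included) and once without smear, scaled by total_cost / processed_size * batch.size() so that only the share attributable to real inputs is counted. A small worked example with made-up numbers, using the same absl::Duration arithmetic as the code above:

#include <cstdint>
#include <iostream>
#include "absl/time/time.h"

int main() {
  // Hypothetical numbers: a padded batch of 16 was executed, but only 10
  // elements came from real requests; the measurement reported 80ms.
  const absl::Duration total_cost = absl::Milliseconds(80);
  const int64_t processed_size = 16;  // padded (executed) batch size
  const int64_t real_size = 10;       // batch.size(): real inputs only

  const absl::Duration with_smear = total_cost;
  const absl::Duration no_smear = total_cost / processed_size * real_size;

  std::cout << "with smear: " << absl::ToDoubleMicroseconds(with_smear)
            << " us\n";  // 80000 us
  std::cout << "no smear:   " << absl::ToDoubleMicroseconds(no_smear)
            << " us\n";  // 50000 us
  return 0;
}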
static void SplitBatchCostsAndRecordMetrics( + const std::string& model_name, const std::vector>& batch_cost_measurements, int64_t processed_size, BatchT& batch); diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc index dc75fde050cc6f..cd4ae4644ed62e 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc @@ -70,9 +70,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnNoCostMeasurement) { batch.Close(); std::vector> batch_cost_measurements; - BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/16, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/16, batch); EXPECT_TRUE(batch.task(0).request_cost->GetCosts().empty()); EXPECT_THAT(batch.task(0).request_cost->GetBatchMetrics(), ::testing::ElementsAre(::testing::FieldsAre( @@ -90,9 +89,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnZeroCost) { std::vector> batch_cost_measurements; batch_cost_measurements.push_back( CostMeasurementRegistry::CreateByNameOrNull("no_op", context)); - BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/16, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/16, batch); EXPECT_TRUE(batch.task(0).request_cost->GetCosts().empty()); EXPECT_THAT(batch.task(0).request_cost->GetBatchMetrics(), ::testing::ElementsAre(::testing::FieldsAre( @@ -108,9 +106,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnZeroBatchSize) { std::vector> batch_cost_measurements; batch_cost_measurements.push_back( CostMeasurementRegistry::CreateByNameOrNull("test_tpu", context)); - BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/0, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/0, batch); } TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnNoRequestCost) { @@ -123,9 +120,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnNoRequestCost) { std::vector> batch_cost_measurements; batch_cost_measurements.push_back( CostMeasurementRegistry::CreateByNameOrNull("test_tpu", context)); - BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/16, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/16, batch); EXPECT_EQ(batch.task(0).request_cost, nullptr); EXPECT_EQ(batch.task(1).request_cost, nullptr); @@ -142,9 +138,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitSingleCostType) { std::vector> batch_cost_measurements; batch_cost_measurements.push_back( CostMeasurementRegistry::CreateByNameOrNull("test_tpu", context)); - BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/20, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/20, batch); EXPECT_THAT( batch.task(0).request_cost->GetCosts(), @@ -179,9 +174,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitMultiCostTypes) { CostMeasurementRegistry::CreateByNameOrNull("test_tpu", context)); batch_cost_measurements.push_back( CostMeasurementRegistry::CreateByNameOrNull("test_gcu", context)); - 
BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/20, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/20, batch); EXPECT_THAT( batch.task(0).request_cost->GetCosts(), @@ -223,9 +217,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitOnlyNonZeroCostTypes) { CostMeasurementRegistry::CreateByNameOrNull("no_op", context)); batch_cost_measurements.push_back( CostMeasurementRegistry::CreateByNameOrNull("test_tpu", context)); - BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/20, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/20, batch); EXPECT_THAT( batch.task(0).request_cost->GetCosts(), diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h index 2b95b91b8f103e..4e218e83dd51c2 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h @@ -217,6 +217,9 @@ class SharedBatchScheduler // done by adding padding in the process-batch callback. size_t max_execution_batch_size = 1000; + // If non-empty, contains configured batch sizes. + std::vector allowed_batch_sizes; + // If true, the padding will not be appended. bool disable_padding = false; diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc index 454e7a77b32037..c21a0cc907bce7 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc @@ -671,7 +671,6 @@ DECLARE_GPU_SPEC(double); #undef DECLARE_GPU_SPEC } // namespace functor - // A dummy type to group backward filter autotune results together. struct Conv3dBackwardFilterAutotuneGroup { static string name() { return "Conv3dBwdFilter"; } @@ -702,8 +701,7 @@ void LaunchConvBackpropFilterOpImpl( OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); if (DataTypeToEnum::value == DT_BFLOAT16 && - !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { + !IsBF16SupportedInOps(stream)) { context->SetStatus(errors::Unimplemented( "Conv3DBackpropFilter for GPU with bfloat16 is only supported " "with cuDNN on Ampere GPUs or later.")); @@ -803,86 +801,86 @@ void LaunchConvBackpropFilterOpImpl( << padding_planes << ")"; #if GOOGLE_CUDA - const bool compute_in_nhwc = ComputeInNhwcEnabled( - DataTypeToEnum::value, stream, /*use_4d_tensor=*/false); + const bool compute_in_nhwc = ComputeInNhwcEnabled( + DataTypeToEnum::value, stream, /*use_4d_tensor=*/false); #else - // fast NDHWC implementation is a CUDA only feature - const bool compute_in_nhwc = false; + // fast NDHWC implementation is a CUDA only feature + const bool compute_in_nhwc = false; #endif - const TensorFormat compute_data_format = - (compute_in_nhwc && data_format == FORMAT_NHWC) ? 
FORMAT_NHWC - : FORMAT_NCHW; - - VLOG(3) << "Compute Conv3DBackpropFilter with cuDNN:" - << " data_format=" << ToString(data_format) - << " compute_data_format=" << ToString(compute_data_format); - - constexpr auto kComputeInNHWC = - std::make_tuple(se::dnn::DataLayout::kBatchYXDepth, - se::dnn::FilterLayout::kOutputYXInput); - constexpr auto kComputeInNCHW = - std::make_tuple(se::dnn::DataLayout::kBatchDepthYX, - se::dnn::FilterLayout::kOutputInputYX); - - se::dnn::DataLayout compute_data_layout; - se::dnn::FilterLayout filter_layout; - - std::tie(compute_data_layout, filter_layout) = - compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW; - - se::dnn::BatchDescriptor input_desc(3); - input_desc.set_count(dims.batch_size) - .set_spatial_dim(DimIndex::X, - GetTensorDim(compatible_input, data_format, '2')) - .set_spatial_dim(DimIndex::Y, - GetTensorDim(compatible_input, data_format, '1')) - .set_spatial_dim(DimIndex::Z, - GetTensorDim(compatible_input, data_format, '0')) - .set_feature_map_count(dims.in_depth) - .set_layout(compute_data_layout); - se::dnn::BatchDescriptor output_desc(3); - output_desc.set_count(dims.batch_size) - .set_spatial_dim(DimIndex::X, dims.output_size(2)) - .set_spatial_dim(DimIndex::Y, dims.output_size(1)) - .set_spatial_dim(DimIndex::Z, dims.output_size(0)) - .set_feature_map_count(dims.out_depth) - .set_layout(compute_data_layout); - se::dnn::FilterDescriptor filter_desc(3); - filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2)) - .set_spatial_dim(DimIndex::Y, dims.filter_size(1)) - .set_spatial_dim(DimIndex::Z, dims.filter_size(0)) - .set_input_feature_map_count(filter_shape.dim_size(3)) - .set_output_feature_map_count(filter_shape.dim_size(4)) - .set_layout(filter_layout); - se::dnn::ConvolutionDescriptor conv_desc(3); - conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2)) - .set_dilation_rate(DimIndex::Y, dims.dilation(1)) - .set_dilation_rate(DimIndex::Z, dims.dilation(0)) - .set_filter_stride(DimIndex::X, dims.stride(2)) - .set_filter_stride(DimIndex::Y, dims.stride(1)) - .set_filter_stride(DimIndex::Z, dims.stride(0)) - .set_zero_padding(DimIndex::X, padding_cols / 2) - .set_zero_padding(DimIndex::Y, padding_rows / 2) - .set_zero_padding(DimIndex::Z, padding_planes / 2) - .set_group_count(dims.in_depth / filter_shape.dim_size(3)); - - Tensor pre_transformed_filter_backprop; - auto dst_format = - compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI; - TensorShape dst_shape = - dst_format == FORMAT_OIHW - ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3), - dims.filter_size(0), dims.filter_size(1), - dims.filter_size(2)}) - : TensorShape({filter_shape.dim_size(4), dims.filter_size(0), - dims.filter_size(1), dims.filter_size(2), - filter_shape.dim_size(3)}); - OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum::value, dst_shape, - &pre_transformed_filter_backprop)); - - Tensor transformed_out_backprop; - if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { + const TensorFormat compute_data_format = + (compute_in_nhwc && data_format == FORMAT_NHWC) ? 
FORMAT_NHWC + : FORMAT_NCHW; + + VLOG(3) << "Compute Conv3DBackpropFilter with cuDNN:" + << " data_format=" << ToString(data_format) + << " compute_data_format=" << ToString(compute_data_format); + + constexpr auto kComputeInNHWC = + std::make_tuple(se::dnn::DataLayout::kBatchYXDepth, + se::dnn::FilterLayout::kOutputYXInput); + constexpr auto kComputeInNCHW = + std::make_tuple(se::dnn::DataLayout::kBatchDepthYX, + se::dnn::FilterLayout::kOutputInputYX); + + se::dnn::DataLayout compute_data_layout; + se::dnn::FilterLayout filter_layout; + + std::tie(compute_data_layout, filter_layout) = + compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW; + + se::dnn::BatchDescriptor input_desc(3); + input_desc.set_count(dims.batch_size) + .set_spatial_dim(DimIndex::X, + GetTensorDim(compatible_input, data_format, '2')) + .set_spatial_dim(DimIndex::Y, + GetTensorDim(compatible_input, data_format, '1')) + .set_spatial_dim(DimIndex::Z, + GetTensorDim(compatible_input, data_format, '0')) + .set_feature_map_count(dims.in_depth) + .set_layout(compute_data_layout); + se::dnn::BatchDescriptor output_desc(3); + output_desc.set_count(dims.batch_size) + .set_spatial_dim(DimIndex::X, dims.output_size(2)) + .set_spatial_dim(DimIndex::Y, dims.output_size(1)) + .set_spatial_dim(DimIndex::Z, dims.output_size(0)) + .set_feature_map_count(dims.out_depth) + .set_layout(compute_data_layout); + se::dnn::FilterDescriptor filter_desc(3); + filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2)) + .set_spatial_dim(DimIndex::Y, dims.filter_size(1)) + .set_spatial_dim(DimIndex::Z, dims.filter_size(0)) + .set_input_feature_map_count(filter_shape.dim_size(3)) + .set_output_feature_map_count(filter_shape.dim_size(4)) + .set_layout(filter_layout); + se::dnn::ConvolutionDescriptor conv_desc(3); + conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2)) + .set_dilation_rate(DimIndex::Y, dims.dilation(1)) + .set_dilation_rate(DimIndex::Z, dims.dilation(0)) + .set_filter_stride(DimIndex::X, dims.stride(2)) + .set_filter_stride(DimIndex::Y, dims.stride(1)) + .set_filter_stride(DimIndex::Z, dims.stride(0)) + .set_zero_padding(DimIndex::X, padding_cols / 2) + .set_zero_padding(DimIndex::Y, padding_rows / 2) + .set_zero_padding(DimIndex::Z, padding_planes / 2) + .set_group_count(dims.in_depth / filter_shape.dim_size(3)); + + Tensor pre_transformed_filter_backprop; + auto dst_format = + compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI; + TensorShape dst_shape = + dst_format == FORMAT_OIHW + ? 
TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3), + dims.filter_size(0), dims.filter_size(1), + dims.filter_size(2)}) + : TensorShape({filter_shape.dim_size(4), dims.filter_size(0), + dims.filter_size(1), dims.filter_size(2), + filter_shape.dim_size(3)}); + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, dst_shape, + &pre_transformed_filter_backprop)); + + Tensor transformed_out_backprop; + if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { VLOG(4) << "Convert the `out_backprop` tensor from NDHWC to NCDHW."; TensorShape nchw_shape = {dims.batch_size, dims.out_depth, dims.output_size(0), dims.output_size(1), @@ -897,11 +895,11 @@ void LaunchConvBackpropFilterOpImpl( } else { CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape)); } - } else { + } else { transformed_out_backprop = out_backprop; - } - Tensor transformed_input; - if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { + } + Tensor transformed_input; + if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { VLOG(4) << "Convert the `input` tensor from NDHWC to NCDHW."; TensorShape nchw_shape = { dims.batch_size, dims.in_depth, compatible_input.dim_size(1), @@ -917,96 +915,93 @@ void LaunchConvBackpropFilterOpImpl( } else { CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape)); } - } else { + } else { transformed_input = compatible_input; - } + } - auto out_backprop_ptr = - AsDeviceMemory(transformed_out_backprop.template flat().data(), - transformed_out_backprop.template flat().size()); - auto filter_backprop_ptr = AsDeviceMemory( - pre_transformed_filter_backprop.template flat().data(), - pre_transformed_filter_backprop.template flat().size()); - auto input_ptr = - AsDeviceMemory(transformed_input.template flat().data(), - transformed_input.template flat().size()); - - static int64_t ConvolveBackwardFilterScratchSize = - GetDnnWorkspaceLimitOrDefault(); - - const ConvParameters conv_parameters = { - stream->parent(), - dims.batch_size, - dims.in_depth, - {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}}, - compute_data_format, - dims.out_depth, - {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}}, - {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}}, - {{dims.stride(0), dims.stride(1), dims.stride(2)}}, - {{padding_planes, padding_rows, padding_cols}}, - input.dtype(), - conv_desc.group_count(), - }; - - using se::dnn::AlgorithmConfig; - using se::dnn::AlgorithmDesc; - using se::dnn::ProfileResult; - - auto entry_or = AutotuneUnfusedConv( - cudnn_use_autotune, AutotuneConv3dBwdFilter::GetInstance(), - conv_parameters, context, se::dnn::ConvolutionKind::BACKWARD_FILTER, - input_desc, input_ptr, filter_desc, filter_backprop_ptr, conv_desc, - output_desc, out_backprop_ptr, ConvolveBackwardFilterScratchSize); - OP_REQUIRES_OK(context, entry_or.status()); - auto autotune_entry = std::move(entry_or).value(); - - DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, - context); - Status cudnn_launch_status = LaunchAutotunedConv( - autotune_entry, &scratch_allocator, - se::dnn::ConvolutionKind::BACKWARD_FILTER, stream, input_desc, - input_ptr, filter_desc, filter_backprop_ptr, conv_desc, output_desc, - out_backprop_ptr); - if (!cudnn_launch_status.ok()) { - context->SetStatus(cudnn_launch_status); - return; - } + auto out_backprop_ptr = + AsDeviceMemory(transformed_out_backprop.template flat().data(), + transformed_out_backprop.template flat().size()); + auto 
filter_backprop_ptr = + AsDeviceMemory(pre_transformed_filter_backprop.template flat().data(), + pre_transformed_filter_backprop.template flat().size()); + auto input_ptr = AsDeviceMemory(transformed_input.template flat().data(), + transformed_input.template flat().size()); + + static int64_t ConvolveBackwardFilterScratchSize = + GetDnnWorkspaceLimitOrDefault(); + + const ConvParameters conv_parameters = { + stream->parent(), + dims.batch_size, + dims.in_depth, + {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}}, + compute_data_format, + dims.out_depth, + {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}}, + {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}}, + {{dims.stride(0), dims.stride(1), dims.stride(2)}}, + {{padding_planes, padding_rows, padding_cols}}, + input.dtype(), + conv_desc.group_count(), + }; + + using se::dnn::AlgorithmConfig; + using se::dnn::AlgorithmDesc; + using se::dnn::ProfileResult; + + auto entry_or = AutotuneUnfusedConv( + cudnn_use_autotune, AutotuneConv3dBwdFilter::GetInstance(), + conv_parameters, context, se::dnn::ConvolutionKind::BACKWARD_FILTER, + input_desc, input_ptr, filter_desc, filter_backprop_ptr, conv_desc, + output_desc, out_backprop_ptr, ConvolveBackwardFilterScratchSize); + OP_REQUIRES_OK(context, entry_or.status()); + auto autotune_entry = std::move(entry_or).value(); + + DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, + context); + Status cudnn_launch_status = LaunchAutotunedConv( + autotune_entry, &scratch_allocator, + se::dnn::ConvolutionKind::BACKWARD_FILTER, stream, input_desc, input_ptr, + filter_desc, filter_backprop_ptr, conv_desc, output_desc, + out_backprop_ptr); + if (!cudnn_launch_status.ok()) { + context->SetStatus(cudnn_launch_status); + return; + } - auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; - functor::ReverseTransformFilter()( - context->eigen_device(), /*src_filter_format=*/dst_format, - toConstTensor(pre_transformed_filter_backprop).template tensor(), - filter_backprop->tensor()); + auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; + functor::ReverseTransformFilter()( + context->eigen_device(), /*src_filter_format=*/dst_format, + toConstTensor(pre_transformed_filter_backprop).template tensor(), + filter_backprop->tensor()); } template struct LaunchConvBackpropFilterOp { - static void launch(OpKernelContext* context, bool cudnn_use_autotune, - const Tensor& input, const Tensor& out_backprop, - const std::vector& dilation, - const std::vector& stride, const Padding& padding, - Tensor* filter_backprop, TensorFormat data_format) { - LaunchConvBackpropFilterOpImpl(context, cudnn_use_autotune, input, - out_backprop, dilation, stride, padding, - filter_backprop, data_format); - } + static void launch(OpKernelContext* context, bool cudnn_use_autotune, + const Tensor& input, const Tensor& out_backprop, + const std::vector& dilation, + const std::vector& stride, const Padding& padding, + Tensor* filter_backprop, TensorFormat data_format) { + LaunchConvBackpropFilterOpImpl(context, cudnn_use_autotune, input, + out_backprop, dilation, stride, padding, + filter_backprop, data_format); + } }; template <> struct LaunchConvBackpropFilterOp { - static void launch(OpKernelContext* ctx, bool cudnn_use_autotune, - const Tensor& input, const Tensor& out_backprop, - const std::vector& dilation, - const std::vector& stride, const Padding& padding, - Tensor* filter_backprop, TensorFormat data_format) { - // Performant bfloat16 operations are 
supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. - auto* stream = ctx->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); - - if (cast_to_float) { + static void launch(OpKernelContext* ctx, bool cudnn_use_autotune, + const Tensor& input, const Tensor& out_backprop, + const std::vector& dilation, + const std::vector& stride, const Padding& padding, + Tensor* filter_backprop, TensorFormat data_format) { + auto* stream = ctx->op_device_context()->stream(); + + const bool cast_to_float = !IsBF16SupportedInOps(stream); + + if (cast_to_float) { Tensor casted_input = input; Tensor casted_out_backprop = out_backprop; Tensor casted_filter_backprop = *filter_backprop; @@ -1035,96 +1030,96 @@ struct LaunchConvBackpropFilterOp { cast_back(device, filter_backprop->template flat(), casted_filter_backprop_const.template flat()); return; - } - - LaunchConvBackpropFilterOpImpl( - ctx, cudnn_use_autotune, input, out_backprop, dilation, stride, - padding, filter_backprop, data_format); } + + LaunchConvBackpropFilterOpImpl( + ctx, cudnn_use_autotune, input, out_backprop, dilation, stride, padding, + filter_backprop, data_format); + } }; template class Conv3DBackpropFilterOp : public OpKernel { - public: - explicit Conv3DBackpropFilterOp(OpKernelConstruction* context) - : OpKernel(context), - data_format_(FORMAT_NHWC), - takes_shape_(type_string().find("V2") != std::string::npos) { - // data_format is only available in V2. - if (takes_shape_) { + public: + explicit Conv3DBackpropFilterOp(OpKernelConstruction* context) + : OpKernel(context), + data_format_(FORMAT_NHWC), + takes_shape_(type_string().find("V2") != std::string::npos) { + // data_format is only available in V2. 
+ if (takes_shape_) { string data_format; OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); OP_REQUIRES(context, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); - } - OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_)); - OP_REQUIRES(context, dilation_.size() == 5, - errors::InvalidArgument("Dilation rates field must " - "specify 5 dimensions")); - OP_REQUIRES(context, - (GetTensorDim(dilation_, data_format_, 'C') == 1 && - GetTensorDim(dilation_, data_format_, 'N') == 1), - errors::InvalidArgument( - "Current implementation does not yet support " - "dilation rates in the batch and depth dimensions.")); - OP_REQUIRES( - context, - (GetTensorDim(dilation_, data_format_, '0') > 0 && - GetTensorDim(dilation_, data_format_, '1') > 0 && - GetTensorDim(dilation_, data_format_, '2') > 0), - errors::InvalidArgument("Dilated rates should be larger than 0.")); - OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); - OP_REQUIRES(context, stride_.size() == 5, - errors::InvalidArgument("Sliding window strides field must " - "specify 5 dimensions")); - OP_REQUIRES(context, - (GetTensorDim(stride_, data_format_, 'C') == 1 && - GetTensorDim(stride_, data_format_, 'N') == 1), - errors::InvalidArgument( - "Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - OP_REQUIRES( - context, - (GetTensorDim(stride_, data_format_, '0') > 0 && - GetTensorDim(stride_, data_format_, '1') > 0 && - GetTensorDim(stride_, data_format_, '2') > 0), - errors::InvalidArgument("Spatial strides should be larger than 0.")); - OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - cudnn_use_autotune_ = CudnnUseAutotune(); } + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_)); + OP_REQUIRES(context, dilation_.size() == 5, + errors::InvalidArgument("Dilation rates field must " + "specify 5 dimensions")); + OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, 'C') == 1 && + GetTensorDim(dilation_, data_format_, 'N') == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilation rates in the batch and depth dimensions.")); + OP_REQUIRES( + context, + (GetTensorDim(dilation_, data_format_, '0') > 0 && + GetTensorDim(dilation_, data_format_, '1') > 0 && + GetTensorDim(dilation_, data_format_, '2') > 0), + errors::InvalidArgument("Dilated rates should be larger than 0.")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES( + context, + (GetTensorDim(stride_, data_format_, 'C') == 1 && + GetTensorDim(stride_, data_format_, 'N') == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES( + context, + (GetTensorDim(stride_, data_format_, '0') > 0 && + GetTensorDim(stride_, data_format_, '1') > 0 && + GetTensorDim(stride_, data_format_, '2') > 0), + errors::InvalidArgument("Spatial strides should be larger than 0.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + cudnn_use_autotune_ = CudnnUseAutotune(); + } - void Compute(OpKernelContext* context) override { - const Tensor& input = context->input(0); - const Tensor& out_backprop = context->input(2); + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& 
out_backprop = context->input(2); - TensorShape filter_shape; - if (takes_shape_) { + TensorShape filter_shape; + if (takes_shape_) { const Tensor& filter_sizes = context->input(1); OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()), errors::InvalidArgument( "filter_sizes shape must be rank 1 but is rank ", filter_sizes.shape().dims())); OP_REQUIRES_OK(context, tensor::MakeShape(filter_sizes, &filter_shape)); - } else { + } else { filter_shape = context->input(1).shape(); - } + } - Tensor* filter_backprop; - OP_REQUIRES_OK( - context, context->allocate_output(0, filter_shape, &filter_backprop)); + Tensor* filter_backprop; + OP_REQUIRES_OK(context, + context->allocate_output(0, filter_shape, &filter_backprop)); - LaunchConvBackpropFilterOp::launch( - context, cudnn_use_autotune_, input, out_backprop, dilation_, stride_, - padding_, filter_backprop, data_format_); - } + LaunchConvBackpropFilterOp::launch( + context, cudnn_use_autotune_, input, out_backprop, dilation_, stride_, + padding_, filter_backprop, data_format_); + } - private: - std::vector dilation_; - std::vector stride_; - Padding padding_; - TensorFormat data_format_; - bool takes_shape_; - bool cudnn_use_autotune_; + private: + std::vector dilation_; + std::vector stride_; + Padding padding_; + TensorFormat data_format_; + bool takes_shape_; + bool cudnn_use_autotune_; }; #define REGISTER_GPU_KERNEL(T) \ diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc index 1c1472ead96faa..e65e5995e92045 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc @@ -539,11 +539,8 @@ operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, const Padding& padding, const std::vector& explicit_paddings, Tensor* filter_backprop, TensorFormat data_format) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = ctx->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = !IsBF16SupportedInOps(stream); if (cast_to_float) { Tensor casted_input = input; diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index 327855646e1f60..cf805027cb5835 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -463,11 +463,8 @@ void LaunchConv2DBackpropInputOp::operator()( int col_dilation, int row_stride, int col_stride, const Padding& padding, const std::vector& explicit_paddings, Tensor* in_backprop, TensorFormat data_format) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. 
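Across the conv and pooling launchers in this change, the inline Ampere compute-capability check is replaced by the shared IsBF16SupportedInOps(stream) helper; the fallback itself is unchanged: when native bfloat16 is unavailable, inputs are cast up to float, the float kernel runs, and the result is cast back. The sketch below shows that cast-up/compute/cast-down shape in plain C++; the BF16 struct, ComputeF32, and ComputeBF16 are illustrative stand-ins, not the TF kernels.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Toy bfloat16 stand-in: keep only the upper 16 bits of a float's encoding.
struct BF16 {
  uint16_t bits;
};

BF16 FloatToBF16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return BF16{static_cast<uint16_t>(u >> 16)};
}

float BF16ToFloat(BF16 b) {
  uint32_t u = static_cast<uint32_t>(b.bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// The float path that is always available; here it just doubles each element.
std::vector<float> ComputeF32(const std::vector<float>& in) {
  std::vector<float> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) out[i] = 2.0f * in[i];
  return out;
}

// Fallback shape used by the kernels above: when the device lacks a native
// bf16 path, cast inputs up to float, run the float kernel, cast back down.
std::vector<BF16> ComputeBF16(bool native_bf16, const std::vector<BF16>& in) {
  std::vector<BF16> out(in.size());
  if (!native_bf16) {
    std::vector<float> in_f(in.size());
    for (size_t i = 0; i < in.size(); ++i) in_f[i] = BF16ToFloat(in[i]);
    std::vector<float> out_f = ComputeF32(in_f);
    for (size_t i = 0; i < out.size(); ++i) out[i] = FloatToBF16(out_f[i]);
    return out;
  }
  // A native bf16 kernel would run here instead.
  for (size_t i = 0; i < in.size(); ++i)
    out[i] = FloatToBF16(2.0f * BF16ToFloat(in[i]));
  return out;
}

int main() {
  std::vector<BF16> x = {FloatToBF16(1.5f), FloatToBF16(-3.0f)};
  std::vector<BF16> y = ComputeBF16(/*native_bf16=*/false, x);
  std::cout << BF16ToFloat(y[0]) << " " << BF16ToFloat(y[1]) << "\n";  // 3 -6
  return 0;
}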
auto* stream = ctx->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = !IsBF16SupportedInOps(stream); if (cast_to_float) { Tensor casted_out_backprop = out_backprop; diff --git a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc index 06cf67d0fc4b50..70311cbbd7a3d7 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc @@ -657,7 +657,6 @@ TF_CALL_double(REGISTER_CPU_KERNEL); TF_CALL_bfloat16(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL - // GPU definitions of both ops. #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. @@ -1025,11 +1024,8 @@ struct LaunchConvBackpropInputOp { const std::vector& dilation, const std::vector& strides, const Padding& padding, Tensor* in_backprop, TensorFormat data_format) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = ctx->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = !IsBF16SupportedInOps(stream); if (cast_to_float) { Tensor casted_out_backprop = out_backprop; @@ -1153,15 +1149,14 @@ class Conv3DBackpropInputOp : public OpKernel { bool cudnn_use_autotune_; }; - -#define REGISTER_GPU_KERNEL(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint("T"), \ - Conv3DBackpropInputOp); \ - REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .HostMemory("input_sizes"), \ +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint("T"), \ + Conv3DBackpropInputOp); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("input_sizes"), \ Conv3DBackpropInputOp); TF_CALL_half(REGISTER_GPU_KERNEL); TF_CALL_bfloat16(REGISTER_GPU_KERNEL); diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc index d932b57189b4ef..72bad756b4d0fd 100644 --- a/tensorflow/core/kernels/conv_ops_3d.cc +++ b/tensorflow/core/kernels/conv_ops_3d.cc @@ -227,11 +227,9 @@ struct LaunchConv3DOp { strides.end()); gtl::InlinedVector casted_dilations(dilations.begin(), dilations.end()); - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. + auto* stream = ctx->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = !IsBF16SupportedInOps(stream); if (cast_to_float) { Tensor casted_input = input_param; diff --git a/tensorflow/core/kernels/conv_ops_bfloat16.cc b/tensorflow/core/kernels/conv_ops_bfloat16.cc index 918c17c0f31b02..37507841647f0b 100644 --- a/tensorflow/core/kernels/conv_ops_bfloat16.cc +++ b/tensorflow/core/kernels/conv_ops_bfloat16.cc @@ -118,11 +118,8 @@ void LaunchConvOp::operator()( dilations_spatial[i] = GetTensorDim(dilations, data_format, static_cast(i + '0')); } - // Performant bfloat16 operations are supported for Ampere+ GPUs. 
For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = context->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = !IsBF16SupportedInOps(stream); if (cast_to_float) { Tensor casted_input = input; @@ -173,11 +170,8 @@ void LaunchConv2DOp::operator()( gtl::InlinedVector casted_dilations = {row_dilation, col_dilation}; - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = ctx->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = !IsBF16SupportedInOps(stream); if (cast_to_float) { Tensor casted_input = input_param; diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.cc b/tensorflow/core/kernels/cudnn_pooling_gpu.cc index dce0e995be7581..bd6e9ed054762a 100644 --- a/tensorflow/core/kernels/cudnn_pooling_gpu.cc +++ b/tensorflow/core/kernels/cudnn_pooling_gpu.cc @@ -149,11 +149,9 @@ void DnnPooling3dOp::Compute( const std::array& window, const std::array& stride, const std::array& padding, TensorFormat data_format, const Tensor& tensor_in, Tensor* output) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = context->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = !IsBF16SupportedInOps(stream); + if (cast_to_float) { Tensor casted_in; Tensor casted_output; @@ -348,11 +346,8 @@ void DnnPooling3dGradOp::Compute( const std::array& output_size, TensorFormat data_format, const Tensor& out_backprop, const TensorShape& tensor_in_shape, const Tensor* tensor_in, const Tensor* tensor_out, Tensor* input_backprop) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = context->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = !IsBF16SupportedInOps(stream); if (cast_to_float) { Tensor casted_out_backprop; Tensor casted_tensor_in; diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 06f1bf2bb3c531..16812fc1a59342 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -1,11 +1,11 @@ # Description: # OpKernels for tf.data -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") # Definitions are loaded separately so that copybara can pattern match (and modify) each definition. 
load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_kernel_library") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -996,6 +996,7 @@ tf_kernel_library( "//tensorflow/core/data:name_utils", "//tensorflow/core/data:split_utils", "@com_google_absl//absl/memory", + "@local_tsl//tsl/platform:types", ], ) @@ -1064,6 +1065,8 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/data:name_utils", + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:errors", ], ) @@ -1172,7 +1175,9 @@ tf_kernel_library( "//tensorflow/core/data:dataset_utils", "//tensorflow/core/data:name_utils", "//tensorflow/core/data:serialization_utils", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", ], ) @@ -1512,6 +1517,8 @@ filegroup( "//tensorflow/core/data:captured_function.h", "//tensorflow/core/data:compression_utils.h", "//tensorflow/core/data:dataset_utils.h", + "//tensorflow/core/data:file_logger_client_interface.h", + "//tensorflow/core/data:file_logger_client_no_op.h", "//tensorflow/core/data:finalization_utils.h", "//tensorflow/core/data:metric_utils.h", "//tensorflow/core/data:name_utils.h", diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 768fd7c6cf0e0d..3ffe6bfc78eb76 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -919,6 +919,8 @@ tf_kernel_library( ":data_service_dataset_op", ":data_service_ops", ":distributed_save_op", + "//tensorflow/core/data/service/snapshot:list_snapshot_chunks_dataset_op", + "//tensorflow/core/data/service/snapshot:snapshot_chunk_dataset_op", ], ) @@ -964,7 +966,6 @@ tf_kernel_library( ":to_tf_record_op", ":unbatch_dataset_op", ":unique_dataset_op", - "//tensorflow/core/data/service/snapshot:snapshot_chunk_dataset_op", ] + select({ "//tensorflow:fuchsia": [], "//conditions:default": [":lmdb_dataset_op"], diff --git a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc index 2a955123681ca6..05035da404cc99 100644 --- a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc @@ -495,9 +495,16 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { // time of the node. If restoring, pass nullptr to not record processing // time because iterator modeling is only used to model Iterator's // GetNext() resource usage. - TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run( + auto status = instantiated_reduce_func_->Run( ctx, std::move(args), &return_values, - ctx->is_restoring() ? nullptr : model_node())); + ctx->is_restoring() ? 
nullptr : model_node()); + if (!status.ok()) { + return absl::InternalError(absl::StrFormat( + "Got error code %s and message: {\n%s\n} \nfrom running " + "user-defined function %s: ", + absl::StatusCodeToString(status.code()), status.message(), + instantiated_reduce_func_->func_name())); + } if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT && diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index b6379c598fbd31..9abee3f1112296 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -139,6 +139,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { + // LINT.IfChange(GetNextInternal) mutex_lock l(mu_); do { if (!input_impl_) { @@ -149,28 +150,39 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { // We are currently processing a mapped element, so try to get the // next subelement. bool end_of_element; + // Create a new context so that we have a separate `checkpoint` + // different from `ctx->checkpoint()` auto nested_ctx = MakeNestedIteratorContext(ctx); TF_RETURN_IF_ERROR(current_element_iterator_->GetNext( &nested_ctx, out_tensors, &end_of_element)); + + // Merge the checkpoint so that the changes made to + // `current_element_iterator_` is propagated ctx->MergeCheckpoint(nested_ctx.checkpoint()); if (!end_of_element) { // Produce the subelement as output. *end_of_sequence = false; return OkStatus(); } + // Since this sub-iterator is done, + // we can commit `input_ckpt_` to `ctx->checkpoint()` ctx->MergeCheckpoint(input_ckpt_.get()); + // Also clean up this sub-iterator's checkpoint inside of + // `ctx->checkpoint()` since it has been consumed. + ctx->PurgeCheckpoint(current_element_iterator_->prefix()); // We have reached the end of the current element, so maybe move on // to the next element. - ctx->PurgeCheckpoint(current_element_iterator_->prefix()); current_element_iterator_.reset(); } - // Get the next element from the input dataset. inputs_.clear(); auto input_ctx = std::make_unique(*ctx); TF_RETURN_IF_ERROR( input_impl_->GetNext(input_ctx.get(), &inputs_, end_of_sequence)); + // Merge the checkpoint to `input_ckpt_` but do not commit to + // `ctx->checkpoint()` yet until the sub-iterator created from + // this `inputs_` is consumed. input_ckpt_->Merge(input_ctx->checkpoint()); if (*end_of_sequence) { input_impl_.reset(); @@ -180,10 +192,12 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/true)); } while (true); + // LINT.ThenChange(:SkipInternal) } Status SkipInternal(IteratorContext* ctx, int num_to_skip, bool* end_of_sequence, int* num_skipped) override { + // LINT.IfChange(SkipInternal) mutex_lock l(mu_); *num_skipped = 0; while (*num_skipped < num_to_skip) { @@ -191,33 +205,65 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { *end_of_sequence = true; return OkStatus(); } - if (!current_element_iterator_) { - // Get the next element from the input dataset. - inputs_.clear(); - TF_RETURN_IF_ERROR( - input_impl_->GetNext(ctx, &inputs_, end_of_sequence)); - if (*end_of_sequence) { - input_impl_.reset(); - *end_of_sequence = true; - return OkStatus(); + if (current_element_iterator_) { + // We are currently processing a mapped element, so try to get the + // next subelement. 
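The group_by_window change above stops returning the user-defined reduce function's status verbatim and instead wraps it in an InternalError that spells out the original code, message, and function name. Below is a standalone sketch of the same wrapping using only absl; RunUserFn and RunAndWrap are made-up names.

#include <iostream>
#include <string>
#include "absl/status/status.h"
#include "absl/strings/str_format.h"

// Stand-in for running a user-defined function that may fail.
absl::Status RunUserFn() {
  return absl::InvalidArgumentError("window reducer produced no dataset");
}

absl::Status RunAndWrap(const std::string& func_name) {
  absl::Status status = RunUserFn();
  if (!status.ok()) {
    // Preserve the original code and message, but say which function failed.
    return absl::InternalError(absl::StrFormat(
        "Got error code %s and message: {\n%s\n} from running "
        "user-defined function %s",
        absl::StatusCodeToString(status.code()), status.message(), func_name));
  }
  return status;
}

int main() {
  std::cout << RunAndWrap("my_reduce_fn").ToString() << std::endl;
  return 0;
}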
+ + bool end_of_element; + // Create a new context so that we have a separate `checkpoint` + // different from `ctx->checkpoint()` + auto nested_ctx = MakeNestedIteratorContext(ctx); + + // `last_num_skipped` stores how many elements + // we have actually skipped. + int last_num_skipped; + TF_RETURN_IF_ERROR(current_element_iterator_->Skip( + &nested_ctx, num_to_skip - *num_skipped, &end_of_element, + &last_num_skipped)); + *num_skipped += last_num_skipped; + + // Merge the checkpoint so that the changes made to + // `current_element_iterator_` is propagated + ctx->MergeCheckpoint(nested_ctx.checkpoint()); + if (!end_of_element) { + if (*num_skipped != num_to_skip) { + return absl::InternalError(absl::StrFormat( + "Expected `num_skipped` and `num_to_skip` to be the same. Got" + " %d(num_skipped) and %d(num_to_skip)", + *num_skipped, num_to_skip)); + } + continue; } - TF_RETURN_IF_ERROR( - BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/false)); - } - bool end_of_element; - int last_num_skipped; - TF_RETURN_IF_ERROR(current_element_iterator_->Skip( - MakeNestedIteratorContext(ctx), num_to_skip - *num_skipped, - &end_of_element, &last_num_skipped)); - *num_skipped += last_num_skipped; - if (end_of_element) { + // Since this sub-iterator is done, + // we can commit `input_ckpt_` to `ctx->checkpoint()` + ctx->MergeCheckpoint(input_ckpt_.get()); + // Also clean up this sub-iterator's checkpoint inside of + // `ctx->checkpoint()` since it has been consumed. + ctx->PurgeCheckpoint(current_element_iterator_->prefix()); // We have reached the end of the current element, so maybe move on // to the next element. current_element_iterator_.reset(); } + // Get the next element from the input dataset. + inputs_.clear(); + auto input_ctx = std::make_unique(*ctx); + TF_RETURN_IF_ERROR( + input_impl_->GetNext(input_ctx.get(), &inputs_, end_of_sequence)); + // Merge the checkpoint to `input_ckpt_` but do not commit to + // `ctx->checkpoint()` yet until the sub-iterator created from + // this `inputs_` is consumed. + input_ckpt_->Merge(input_ctx->checkpoint()); + if (*end_of_sequence) { + input_impl_.reset(); + *end_of_sequence = true; + return OkStatus(); + } + TF_RETURN_IF_ERROR( + BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/false)); } *end_of_sequence = false; return OkStatus(); + // LINT.ThenChange(:GetNextInternal) } protected: diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 2ae28298d81f27..41e98622f18e5f 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -226,7 +226,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { {"cycle_length", strings::Printf("%lld", static_cast(cycle_length))}, {"deterministic", - deterministic.IsNondeterministic() ? "false" : "true"}}) { + deterministic.IsNondeterministic() ? "false" : "true"}, + {"buffer_output_elements", + strings::Printf("%lld", + static_cast(buffer_output_elements_))}, + {"prefetch_input_elements", + strings::Printf( + "%lld", static_cast(prefetch_input_elements_))}}) { input_->Ref(); } diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index 35313d78f26d09..b6aa84aa7089f5 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/errors.h" +#include "tsl/platform/types.h" namespace tensorflow { namespace data { @@ -65,6 +66,25 @@ Status ConvertOutputTypes(const tensorflow::DataTypeVector& output_dtypes, int64_t sgn(int64_t val) { return (0 < val) - (val < 0); } +int64_t RangeCardinality(int64_t start, int64_t stop, int64_t step) { + // `enumerate` uses int max to simulate an infinite range dataset. + if (stop >= tsl::kint64max) { + return kInfiniteCardinality; + } + + // If the signs of `stop - start` and `step` are different or either of + // the values is zero, the range will be empty. + if (sgn(stop - start) * sgn(step) <= 0) { + return 0; + } else if (step > 0) { + // Invariant: stop - start > 0 && step > 0 + return (stop - start - 1) / step + 1; + } else { + // Invariant: start - stop > 0 && step < 0 + return (start - stop - 1) / -step + 1; + } +} + // Class which produces the elements of `range(start, stop, step)`. Threadsafe. class RangeCounter { public: @@ -100,6 +120,8 @@ class RangeCounter { next_ = value; } + int64_t Cardinality() const { return RangeCardinality(start_, stop_, step_); } + private: const int64_t start_; const int64_t stop_; @@ -147,6 +169,8 @@ class RangeDatasetOp::RangeSplitProvider : public SplitProvider { return OkStatus(); } + int64_t Cardinality() const override { return counter_.Cardinality(); } + private: RangeCounter counter_; }; @@ -185,17 +209,7 @@ class RangeDatasetOp::Dataset : public DatasetBase { } int64_t CardinalityInternal(CardinalityOptions options) const override { - // If the signs of `stop_ - start_` and `step_` are different or either of - // the values is zero, the range will be empty. - if (sgn(stop_ - start_) * sgn(step_) <= 0) { - return 0; - } else if (step_ > 0) { - // Invariant: stop_ - start_ > 0 && step_ > 0 - return (stop_ - start_ - 1) / step_ + 1; - } else { - // Invariant: start_ - stop_ > 0 && step_ < 0 - return (start_ - stop_ - 1) / -step_ + 1; - } + return RangeCardinality(start_, stop_, step_); } Status MakeSplitProviders(std::vector>* diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index 819f3bf087a66a..2e66c311d43b89 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -14,14 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/repeat_dataset_op.h" +#include +#include #include #include #include +#include "absl/status/status.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace data { @@ -74,6 +78,49 @@ bool HasDataServiceInput(const DatasetBase* dataset) { } } // namespace +// Updates an input split provider with the appropriate cardinality count based +// on how many times it is repeated. +class RepeatedSplitProvider : public SplitProvider { + public: + explicit RepeatedSplitProvider(std::unique_ptr split_provider, + int64_t count) + : split_provider_(std::move(split_provider)), count_(count) {} + + // Updates the cardinality based on the times the input dataset is repeated. 
+ int64_t Cardinality() const override { + if (split_provider_->Cardinality() == 0 || count_ == 0) { + return 0; + } + // From tensorflow/python/data/ops/repeat_op.py, the repeat op uses -1 for + // infinite repetitions. + if (count_ < 0) { + return kInfiniteCardinality; + } + if (split_provider_->Cardinality() < 0) { + return split_provider_->Cardinality(); + } + return split_provider_->Cardinality() * count_; + } + + // The following are the same as the input split provider. + absl::Status GetNext(Tensor* split, bool* end_of_splits) override { + return split_provider_->GetNext(split, end_of_splits); + } + absl::Status Reset() override { return split_provider_->Reset(); } + absl::Status Save(std::function full_name, + IteratorStateWriter* writer) override { + return split_provider_->Save(full_name, writer); + } + absl::Status Restore(std::function full_name, + IteratorStateReader* reader) override { + return split_provider_->Restore(full_name, reader); + } + + private: + const std::unique_ptr split_provider_; + const int64_t count_; +}; + class RepeatDatasetOp::Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, int64_t count, const DatasetBase* input) @@ -97,6 +144,19 @@ class RepeatDatasetOp::Dataset : public DatasetBase { } } + absl::Status MakeSplitProviders(std::vector>* + split_providers) const override { + std::vector> input_split_providers; + TF_RETURN_IF_ERROR(input_->MakeSplitProviders(&input_split_providers)); + + split_providers->clear(); + for (auto& split_provider : input_split_providers) { + split_providers->push_back(std::make_unique( + std::move(split_provider), count_)); + } + return absl::OkStatus(); + } + const DataTypeVector& output_dtypes() const override { return input_->output_dtypes(); } diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 5143182f0b1a90..cb2c28dbf1ea58 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/data/serialization_utils.h" @@ -203,6 +205,12 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { mutex_lock l(mu_); seed_generator_->GenerateSeeds(&seed_, &seed2_); ResetRngs(); + // Initialize checkpoint_indices_ to the entire buffer. 
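The RepeatedSplitProvider::Cardinality combination above can be read as a small decision table: an empty input or a count of 0 gives 0; a negative count (the -1 "repeat forever" encoding from repeat_op.py) gives infinite; an unknown or infinite input cardinality passes through unchanged; otherwise the two multiply. A hedged standalone restatement with local constants instead of TensorFlow's cardinality sentinels:

#include <cassert>
#include <cstdint>

constexpr int64_t kInfinite = -2;  // stand-ins for TF's cardinality constants
constexpr int64_t kUnknown = -1;

int64_t RepeatedCardinality(int64_t input_cardinality, int64_t count) {
  if (input_cardinality == 0 || count == 0) return 0;
  if (count < 0) return kInfinite;                       // repeat(-1): forever
  if (input_cardinality < 0) return input_cardinality;   // unknown or infinite
  return input_cardinality * count;
}

int main() {
  assert(RepeatedCardinality(10, 3) == 30);
  assert(RepeatedCardinality(10, -1) == kInfinite);
  assert(RepeatedCardinality(kUnknown, 5) == kUnknown);
  assert(RepeatedCardinality(0, -1) == 0);  // empty input stays empty
}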
+ if (ctx->symbolic_checkpoint()) { + for (int64_t i = 0; i < buffer_->size(); ++i) { + checkpoint_indices_.insert(i); + } + } return OkStatus(); } @@ -229,6 +237,8 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { this->RecordBufferDequeue(ctx, *out_tensors); std::swap(buffer_->at(index), buffer_->at(slices_.front()->start % buffer_->size())); + checkpoint_indices_.insert(index); + checkpoint_indices_.insert(slices_.front()->start % buffer_->size()); slices_.front()->start++; num_elements_--; return OkStatus(); @@ -273,8 +283,20 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kEpoch, epoch_)); TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kNumElements, num_elements_)); - TF_RETURN_IF_ERROR(WriteElementsToCheckpoint( - writer, absl::StrCat(prefix(), kColon, "buffer"), *buffer_)); + const std::string key_prefix = absl::StrCat(prefix(), kColon, "buffer"); + if (ctx->symbolic_checkpoint()) { + // When symbolic checkpointing is turned on, `writer` + // already contains checkpoint of the shuffle buffer created by the + // previous invocation of this instance and the indices that need to be + // updated are stored in `checkpoint_indices`. + TF_RETURN_IF_ERROR(UpdateCheckpointElements( + writer, key_prefix, *buffer_, checkpoint_indices_)); + checkpoint_indices_.clear(); + } else { + TF_RETURN_IF_ERROR( + WriteElementsToCheckpoint(writer, key_prefix, *buffer_)); + } + TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kSlicesSize, slices_.size())); for (size_t i = 0; i < slices_.size(); ++i) { @@ -339,6 +361,12 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR(ReadElementsFromCheckpoint( ctx, reader, absl::StrCat(prefix(), kColon, "buffer"), buffer_.get())); + if (ctx->symbolic_checkpoint()) { + DCHECK(checkpoint_indices_.empty()); + for (size_t i = 0; i < buffer_->size(); ++i) { + checkpoint_indices_.insert(i); + } + } for (const auto& element : *buffer_) { RecordBufferEnqueue(ctx, element); } @@ -502,9 +530,11 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { this->RecordBufferEnqueue(ctx, element); if (num_elements_ == buffer_->size()) { DCHECK(IsShuffleAll()); + checkpoint_indices_.insert(buffer_->size()); buffer_->push_back(element); } else { size_t index = slices_.back()->end % buffer_->size(); + checkpoint_indices_.insert(index); buffer_->at(index) = std::move(element); } num_elements_++; @@ -530,6 +560,10 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { SeedGenerator* const seed_generator_ TF_GUARDED_BY(mu_); // Not owned. std::unique_ptr>> buffer_ TF_GUARDED_BY(mu_); + // Holds the indices of `buffer_` that have changed since the previous + // `SaveInternal()` and need to be updated in the MemoryCheckpoint + // (if symbolic checkpointing is used) in the next `SaveInternal()`. 
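The shuffle-buffer changes above turn symbolic checkpointing into a delta scheme: rather than re-serializing the whole buffer on every SaveInternal, only the slots recorded in checkpoint_indices_ since the previous save are rewritten, and the set is cleared afterwards. A minimal standalone sketch of that dirty-slot idea; UpdateCheckpointElements here is a simplified stand-in for the real serialization helper, not its actual signature:

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Stand-in for the real UpdateCheckpointElements(): rewrite only dirty slots.
void UpdateCheckpointElements(std::vector<std::string>& saved,
                              const std::vector<std::string>& buffer,
                              const std::unordered_set<size_t>& dirty) {
  saved.resize(buffer.size());
  for (size_t i : dirty) saved[i] = buffer[i];
}

int main() {
  std::vector<std::string> buffer = {"a", "b", "c", "d"};
  std::vector<std::string> saved = buffer;        // state after the last save
  std::unordered_set<size_t> checkpoint_indices;  // dirty slots since then

  buffer[2] = "x";                // element swapped into slot 2
  checkpoint_indices.insert(2);   // record the slot, not the whole buffer

  UpdateCheckpointElements(saved, buffer, checkpoint_indices);
  checkpoint_indices.clear();     // next save starts from a clean delta

  std::cout << saved[2] << "\n";  // "x": only one slot was rewritten
}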
+ absl::flat_hash_set checkpoint_indices_ TF_GUARDED_BY(mu_); std::unique_ptr input_impl_ TF_GUARDED_BY(mu_) = nullptr; int64_t epoch_ TF_GUARDED_BY(mu_) = 0; int64_t num_elements_ TF_GUARDED_BY(mu_) = 0; diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc index 0708e70481c594..c15b014815705c 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_op.cc @@ -58,16 +58,14 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; bool UseCudnnWith16BitFloat(OpKernelContext* ctx, DataType dtype) { -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (dtype == DT_HALF) { return true; } else if (dtype == DT_BFLOAT16) { auto* stream = ctx->op_device_context()->stream(); - if (!stream) return false; - return stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + return IsBF16SupportedInOps(stream); } -#endif +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM return false; } diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index e360d5f5a7d653..009710f113fd7a 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -25,6 +25,7 @@ limitations under the License. #endif // GOOGLE_CUDA #include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/gpu_utils.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/util/stream_executor_util.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -1045,11 +1046,9 @@ struct FusedBatchNorm { Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean, Tensor* saved_inv_var, TensorFormat tensor_format, bool use_reserved_space) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = context->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = !IsBF16SupportedInOps(stream); + if (cast_to_float) { Tensor casted_x = x; Tensor casted_side_input; @@ -1311,11 +1310,8 @@ struct FusedBatchNormGrad { Tensor* x_backprop, Tensor* scale_backprop, Tensor* offset_backprop, Tensor* side_input_backprop, bool use_reserved_space, TensorFormat tensor_format) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = context->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = !IsBF16SupportedInOps(stream); if (cast_to_float) { Tensor casted_y_backprop = y_backprop; Tensor casted_x = x; diff --git a/tensorflow/core/kernels/fused_eigen_output_kernels.h b/tensorflow/core/kernels/fused_eigen_output_kernels.h index c264925055286f..9a50882a1016c7 100644 --- a/tensorflow/core/kernels/fused_eigen_output_kernels.h +++ b/tensorflow/core/kernels/fused_eigen_output_kernels.h @@ -26,6 +26,8 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_ #define TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_ +#include + #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -103,6 +105,22 @@ struct Relu6 { }; }; +// Applies `Tanh` to the passed input expression. +struct Tanh { + template + static auto apply(XprType expr) -> decltype(expr.tanh()) { + return expr.tanh(); + }; +}; + +// Applies `Sigmoid` to the passed input expression. +struct Sigmoid { + template + static auto apply(XprType expr) -> decltype(expr.sigmoid()) { + return expr.sigmoid(); + }; +}; + // Applies `Elu` to the passed input expression. struct Elu { template @@ -142,6 +160,8 @@ struct BiasAddArgs { return fusion == FusedComputationType::kBiasAdd || fusion == FusedComputationType::kBiasAddWithRelu || fusion == FusedComputationType::kBiasAddWithRelu6 || + fusion == FusedComputationType::kBiasAddWithTanh || + fusion == FusedComputationType::kBiasAddWithSigmoid || fusion == FusedComputationType::kBiasAddWithElu || fusion == FusedComputationType::kBiasAddWithLeakyRelu; } @@ -219,10 +239,16 @@ struct BiasAddOutputKernel { typename TTypes::UnalignedConstTensor bias(bias_base, num_rows); for (int col = 0; col < num_cols; ++col) { - T* output_base = &output_mapper(0, col); - typename TTypes::UnalignedTensor output(output_base, num_rows); - const auto expr = output + bias; - output = Activation::template apply(expr); + Scalar* output_base = &output_mapper(0, col); + typename TTypes::UnalignedTensor output(output_base, num_rows); + if constexpr (std::is_same_v) { + const auto expr = output + bias; + output = Activation::template apply(expr); + } else { + const auto bias_expr = bias.template cast(); + const auto expr = output + bias_expr; + output = Activation::template apply(expr); + } } } @@ -246,10 +272,18 @@ struct BiasAddOutputKernel { typename TTypes::UnalignedConstTensor bias(bias_base, num_rows); for (int col = 0; col < num_cols; ++col) { - T* output_base = &output_mapper(0, col); - typename TTypes::UnalignedTensor output(output_base, num_rows); - const auto expr = output + bias; - output = LeakyRelu::template apply(expr, leakyrelu_alpha); + Scalar* output_base = &output_mapper(0, col); + typename TTypes::UnalignedTensor output(output_base, num_rows); + if constexpr (std::is_same_v) { + const auto expr = output + bias; + output = + LeakyRelu::template apply(expr, leakyrelu_alpha); + } else { + const auto bias_expr = bias.template cast(); + const auto expr = output + bias_expr; + output = + LeakyRelu::template apply(expr, leakyrelu_alpha); + } } } @@ -356,6 +390,10 @@ using WithBiasAddAndRelu = BiasAddOutputKernel; template using WithBiasAddAndRelu6 = BiasAddOutputKernel; template +using WithBiasAddAndTanh = BiasAddOutputKernel; +template +using WithBiasAddAndSigmoid = BiasAddOutputKernel; +template using WithBiasAddAndElu = BiasAddOutputKernel; template using WithBiasAddAndLeakyRelu = BiasAddOutputKernel; diff --git a/tensorflow/core/kernels/gpu_utils.cc b/tensorflow/core/kernels/gpu_utils.cc index f9b9868579a4af..6f578f5f7d124d 100644 --- a/tensorflow/core/kernels/gpu_utils.cc +++ b/tensorflow/core/kernels/gpu_utils.cc @@ -37,6 +37,21 @@ using xla::AutotuningLog; using xla::ComputeCapability; using xla::CudnnVersion; +bool IsBF16SupportedInOps(se::Stream* stream) { + if (!stream) { + return false; // No stream: don't know whether it's supported. 
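The new Tanh and Sigmoid functors above plug into the same BiasAddOutputKernel scheme as Relu and Relu6: the kernel forms a column-wise "output + bias" Eigen expression and then applies Activation::apply to it (casting the float bias to the output scalar type first when the output is half). A small self-contained Eigen sketch of that add-bias-then-activate expression shape, without the contraction output-mapper plumbing (Sigmoid is used the same way as Tanh):

#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

// Activation functors in the style of the ones added above.
struct Tanh {
  template <typename Xpr>
  static auto apply(Xpr expr) -> decltype(expr.tanh()) { return expr.tanh(); }
};
struct Sigmoid {
  template <typename Xpr>
  static auto apply(Xpr expr) -> decltype(expr.sigmoid()) { return expr.sigmoid(); }
};

int main() {
  Eigen::Tensor<float, 2> out(3, 2);  // one block of matmul output
  out.setConstant(0.5f);
  Eigen::Tensor<float, 1> bias(3);
  bias.setConstant(1.0f);

  // Column-wise bias add followed by the activation, as one Eigen expression.
  for (int col = 0; col < out.dimension(1); ++col) {
    out.chip(col, 1) = Tanh::apply(out.chip(col, 1) + bias);
  }
  std::cout << out(0, 0) << "\n";  // tanh(1.5) ~= 0.905
}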
+ } +#if GOOGLE_CUDA + // Performant bfloat16 operations are supported for Ampere+ GPUs. For + // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. + return stream->GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE); +#elif TENSORFLOW_USE_ROCM + // So far, we return false meaning that the conversion to float is needed. + return false; +#endif +} + bool RedzoneCheckDisabled() { const char* disable_rz_str = std::getenv("TF_DISABLE_RZ_CHECK"); return disable_rz_str != nullptr && std::strcmp(disable_rz_str, "1") == 0; diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h index 96af46697a859b..8d511859ac5768 100644 --- a/tensorflow/core/kernels/gpu_utils.h +++ b/tensorflow/core/kernels/gpu_utils.h @@ -42,6 +42,10 @@ class AutotuneResult; namespace tensorflow { +// Returns true if bfloat16 is directly supported in Ops and inputs shall not be +// casted to floats to perform the computations and then back. +bool IsBF16SupportedInOps(se::Stream* stream); + class NodeDef; using xla::AutotuneResult; diff --git a/tensorflow/core/kernels/linalg/einsum_op_impl.h b/tensorflow/core/kernels/linalg/einsum_op_impl.h index 99c13063933250..da5c6718f4a271 100644 --- a/tensorflow/core/kernels/linalg/einsum_op_impl.h +++ b/tensorflow/core/kernels/linalg/einsum_op_impl.h @@ -471,6 +471,7 @@ struct EinsumHelper { ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); LaunchBatchMatMul::Launch(ctx, lhs, rhs, /*adj_x=*/false, /*adj_y=*/false, trans_x, trans_y, + /*grad_x=*/false, /*grad_y=*/false, bcast, &output_reshaped); return OkStatus(); } diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc index 872aa9247bcb51..f937a06016dc8c 100644 --- a/tensorflow/core/kernels/matmul_op_fused.cc +++ b/tensorflow/core/kernels/matmul_op_fused.cc @@ -32,15 +32,20 @@ limitations under the License. #define EIGEN_USE_GPU #endif // GOOGLE_CUDA +#include #include +#include #include #include +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/kernels/fused_eigen_output_kernels.h" #include "tensorflow/core/platform/errors.h" @@ -85,14 +90,16 @@ struct LaunchFusedMatMulOp { template struct LaunchFusedMatMulOp { + // Use F32 compute for F16 inputs on CPU to preserve precision and reduce + // excessive casting during intermediate computations. 
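IsBF16SupportedInOps centralizes the device check that FusedBatchNorm, DnnPooling and the depthwise conv previously made inline: on CUDA it requires Ampere or newer, on ROCm it currently reports false, and callers negate the result to decide whether to cast bfloat16 tensors through float. A simplified standalone sketch of the caller pattern; FakeDevice and its fields are placeholders for the StreamExecutor stream query, not real API:

#include <iostream>

// Placeholder for the device query normally answered by StreamExecutor.
struct FakeDevice {
  int cuda_major = 0;   // 8 corresponds to Ampere
  bool is_rocm = false;
};

bool IsBF16SupportedInOps(const FakeDevice* device) {
  if (device == nullptr) return false;  // no stream: assume unsupported
  if (device->is_rocm) return false;    // ROCm: always cast to float for now
  return device->cuda_major >= 8;       // Ampere and newer run bf16 natively
}

int main() {
  FakeDevice turing;
  turing.cuda_major = 7;
  FakeDevice ampere;
  ampere.cuda_major = 8;

  // Callers such as FusedBatchNorm and DnnPooling use the negation as a
  // "cast inputs to float, cast outputs back to bfloat16" flag.
  std::cout << std::boolalpha << !IsBF16SupportedInOps(&turing) << " "  // true
            << !IsBF16SupportedInOps(&ampere) << "\n";                  // false
}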
+ using ComputeType = + std::conditional_t::value == DT_HALF, float, T>; + void operator()( OpKernelContext* context, const Tensor& a, const Tensor& b, const Eigen::array, 1>& dim_pair, FusedComputationType fusion, const FusedComputationArgs& fusion_args, Tensor* output, bool use_autotune) { - OP_REQUIRES(context, DataTypeToEnum::value != DT_HALF, - errors::InvalidArgument("_FusedMatMul doesn't support DT_HALF " - "data type on CPU devices.")); auto lhs = a.matrix(); auto rhs = b.matrix(); auto out = output->matrix(); @@ -104,13 +111,21 @@ struct LaunchFusedMatMulOp { auto executeWithOutputKernel = [&](auto output_kernel) { OutputKernelWrapper output_kernel_wrapper( [&output_kernel]( - const ContractionOutputMapper& output_mapper, + const ContractionOutputMapper& + output_mapper, const Eigen::TensorContractionParams& params, Eigen::Index i, Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) { output_kernel(output_mapper, params, i, j, num_rows, num_cols); }); - out.device(d) = lhs.contract(rhs, dim_pair, output_kernel_wrapper); + if constexpr (std::is_same_v) { + out.device(d) = lhs.contract(rhs, dim_pair, output_kernel_wrapper); + } else { + out.device(d) = lhs.template cast() + .contract(rhs.template cast(), + dim_pair, output_kernel_wrapper) + .template cast(); + } }; BiasAddArgs bias_add_args; @@ -133,6 +148,12 @@ struct LaunchFusedMatMulOp { case FusedComputationType::kBiasAddWithRelu6: executeWithOutputKernel(WithBiasAddAndRelu6(bias_add_args)); break; + case FusedComputationType::kBiasAddWithTanh: + executeWithOutputKernel(WithBiasAddAndTanh(bias_add_args)); + break; + case FusedComputationType::kBiasAddWithSigmoid: + executeWithOutputKernel(WithBiasAddAndSigmoid(bias_add_args)); + break; case FusedComputationType::kBiasAddWithElu: executeWithOutputKernel(WithBiasAddAndElu(bias_add_args)); break; @@ -155,16 +176,16 @@ struct LaunchFusedMatMulOp { // We do not pass std::function directly as an output kernel because it blows // up the binary size in debug mode with super long symbol names. 
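With the changes above, the CPU _FusedMatMul kernel accepts DT_HALF by contracting in float: when T is Eigen::half, both operands are cast to ComputeType (float), the contraction and output kernel run in float, and the result is cast back to half. A tiny standalone Eigen sketch of that cast-contract-cast round trip (no output kernel, illustrative only):

#include <iostream>
#include <type_traits>
#include <unsupported/Eigen/CXX11/Tensor>

template <typename T>
using ComputeType = std::conditional_t<std::is_same_v<T, Eigen::half>, float, T>;

template <typename T>
Eigen::Tensor<T, 2> MatMul(const Eigen::Tensor<T, 2>& a,
                           const Eigen::Tensor<T, 2>& b) {
  Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
  if constexpr (std::is_same_v<T, ComputeType<T>>) {
    return a.contract(b, dims);  // contract natively
  } else {
    // Cast to float, contract in float, cast back: mirrors the half path.
    Eigen::Tensor<ComputeType<T>, 2> out =
        a.template cast<ComputeType<T>>().contract(
            b.template cast<ComputeType<T>>(), dims);
    return out.template cast<T>();
  }
}

int main() {
  Eigen::Tensor<Eigen::half, 2> a(2, 3), b(3, 2);
  a.setConstant(Eigen::half(1.0f));
  b.setConstant(Eigen::half(0.5f));
  Eigen::Tensor<Eigen::half, 2> c = MatMul(a, b);
  std::cout << static_cast<float>(c(0, 0)) << "\n";  // 1.5
}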
struct OutputKernelWrapper { - using OutputKernelFn = - std::function&, - const Eigen::TensorContractionParams&, Eigen::Index, - Eigen::Index, Eigen::Index, Eigen::Index)>; + using OutputKernelFn = std::function&, + const Eigen::TensorContractionParams&, Eigen::Index, Eigen::Index, + Eigen::Index, Eigen::Index)>; explicit OutputKernelWrapper(OutputKernelFn fn) : output_kernel_fn(std::move(fn)) {} void operator()( - const ContractionOutputMapper& output_mapper, + const ContractionOutputMapper& output_mapper, const Eigen::TensorContractionParams& params, Eigen::Index i, Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) const { output_kernel_fn(output_mapper, params, i, j, num_rows, num_cols); @@ -611,6 +632,8 @@ class FusedMatMulOp : public OpKernel { {FCT::kBiasAdd, {"BiasAdd"}}, {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}}, {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}}, + {FCT::kBiasAddWithTanh, {"BiasAdd", "Tanh"}}, + {FCT::kBiasAddWithSigmoid, {"BiasAdd", "Sigmoid"}}, {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}}, {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}}, }; @@ -711,6 +734,7 @@ class FusedMatMulOp : public OpKernel { FusedMatMulOp); TF_CALL_float(REGISTER_FUSED_CPU_MATMUL); +TF_CALL_half(REGISTER_FUSED_CPU_MATMUL); #undef REGISTER_FUSED_CPU_MATMUL diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h index 71f338f266bf55..7180fb1d4e35f9 100644 --- a/tensorflow/core/kernels/matmul_op_impl.h +++ b/tensorflow/core/kernels/matmul_op_impl.h @@ -21,12 +21,15 @@ limitations under the License. #define EIGEN_USE_THREADS #include +#include #include #include #include #include +#include "Eigen/Core" // from @eigen_archive #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bfloat16.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -36,6 +39,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/matmul_autotune.h" @@ -410,7 +414,8 @@ template struct LaunchBatchMatMul { static void Launch(OpKernelContext* context, const Tensor& in_x, const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, - bool trans_y, const MatMulBCast& bcast, Tensor* out) { + bool trans_y, bool grad_x, bool grad_y, + const MatMulBCast& bcast, Tensor* out) { typedef ParallelMatMulKernel::IsComplex> ParallelMatMulKernel; bool conjugate_result = false; @@ -539,7 +544,8 @@ template struct LaunchBatchMatMul { static void Launch(OpKernelContext* context, const Tensor& in_x, const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, - bool trans_y, const MatMulBCast& bcast, Tensor* out) { + bool trans_y, bool grad_x, bool grad_y, + const MatMulBCast& bcast, Tensor* out) { se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose, se::blas::Transpose::kTranspose, se::blas::Transpose::kConjugateTranspose}; @@ -582,6 +588,16 @@ struct LaunchBatchMatMul { std::is_same_v; using Coefficient = std::conditional_t; + se::blas::CallContext call_context = se::blas::CallContext::kNone; + OP_REQUIRES(context, grad_x == false || grad_y == false, + errors::InvalidArgument( + "At least 1 of grad_x and grad_y shall be false")); + if (grad_x) { + call_context = se::blas::CallContext::kBackpropInput1; + } + if (grad_y) { + call_context = se::blas::CallContext::kBackpropInput2; + } #if GOOGLE_CUDA || TF_HIPBLASLT static const bool use_autotune = MatmulAutotuneEnable(); bool bCublasLtSupport = true; @@ -711,8 +727,7 @@ struct LaunchBatchMatMul { static_cast(1.0), b_ptrs, adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k, static_cast(0.0), c_ptrs, n, batch_size, - GetNumericOptions(), &scratch_allocator, - se::blas::CallContext::kNone) + GetNumericOptions(), &scratch_allocator, call_context) .ok(); if (!blas_launch_status) { context->SetStatus(errors::Internal( @@ -811,17 +826,16 @@ struct LaunchBatchMatMul { blas_transpose_b, blas_transpose_a, n, m, k, *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]), adj_x || trans_x ? m : k, c_ptrs[0], n, - GetNumericOptions(), se::blas::CallContext::kNone)); + GetNumericOptions(), call_context)); } else if (use_strided_batched) { OP_REQUIRES_OK( - context, - stream->ThenBlasGemmStridedBatched( - blas_transpose_b, blas_transpose_a, n, m, k, - static_cast(1.0), *b_ptrs[0], - adj_y || trans_y ? k : n, b_stride, *a_ptrs[0], - adj_x || trans_x ? m : k, a_stride, - static_cast(0.0), c_ptrs[0], n, c_stride, - batch_size, GetNumericOptions(), se::blas::CallContext::kNone)); + context, stream->ThenBlasGemmStridedBatched( + blas_transpose_b, blas_transpose_a, n, m, k, + static_cast(1.0), *b_ptrs[0], + adj_y || trans_y ? k : n, b_stride, *a_ptrs[0], + adj_x || trans_x ? m : k, a_stride, + static_cast(0.0), c_ptrs[0], n, c_stride, + batch_size, GetNumericOptions(), call_context)); } else { BlasScratchAllocator scratch_allocator(context); bool blas_launch_status = @@ -831,8 +845,7 @@ struct LaunchBatchMatMul { static_cast(1.0), b_ptrs, adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? 
m : k, static_cast(0.0), c_ptrs, n, batch_size, - GetNumericOptions(), &scratch_allocator, - se::blas::CallContext::kNone) + GetNumericOptions(), &scratch_allocator, call_context) .ok(); if (!blas_launch_status) { context->SetStatus(errors::Internal( @@ -850,6 +863,32 @@ struct LaunchBatchMatMul { #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +template +inline void FastConvertToFloat(const T* src, float* dst, int64_t size) { + Eigen::Map> src_eigen(src, size); + Eigen::Map dst_eigen(dst, size); + dst_eigen = src_eigen.template cast(); +} + +template +inline void FastConvertFromFloat(const float* src, T* dst, int64_t size) { + Eigen::Map src_eigen(src, size); + Eigen::Map> dst_eigen(dst, size); + dst_eigen = src_eigen.template cast(); +} + +template <> +inline void FastConvertToFloat(const bfloat16* src, float* dst, + int64_t size) { + BFloat16ToFloat(src, dst, size); +} + +template <> +inline void FastConvertFromFloat(const float* src, bfloat16* dst, + int64_t size) { + FloatToBFloat16(src, dst, size); +} + template class BaseBatchMatMulOp : public OpKernel { public: @@ -862,11 +901,15 @@ class BaseBatchMatMulOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_)); adj_x_ = false; adj_y_ = false; + OP_REQUIRES_OK(context, context->GetAttr("grad_a", &grad_input_1_)); + OP_REQUIRES_OK(context, context->GetAttr("grad_b", &grad_input_2_)); } else { OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_)); OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_)); trans_x_ = false; trans_y_ = false; + OP_REQUIRES_OK(context, context->GetAttr("grad_x", &grad_input_1_)); + OP_REQUIRES_OK(context, context->GetAttr("grad_y", &grad_input_2_)); } } @@ -931,8 +974,17 @@ class BaseBatchMatMulOp : public OpKernel { out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})), errors::Internal("Failed to reshape output from ", out->shape().DebugString())); - if (std::is_same_v && std::is_same_v && - std::is_same_v) { + + // b/307285203: There seems to be an overly aggressive compiler optimization + // that optimizes away these data pointers unless we explicitly check them. + OP_REQUIRES(ctx, + in0_reshaped.data() != nullptr && + in1_reshaped.data() != nullptr && + out_reshaped.data() != nullptr, + absl::InternalError("Null data pointer encountered.")); + if constexpr (std::is_same_v && std::is_same_v && + (std::is_same_v || + std::is_same_v)) { Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float; OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(), &in0_reshaped_float)); @@ -941,31 +993,32 @@ class BaseBatchMatMulOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(), &out_reshaped_float)); - // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU. - BFloat16ToFloat(in0_reshaped.flat().data(), - in0_reshaped_float.flat().data(), - in0_reshaped.NumElements()); - BFloat16ToFloat(in1_reshaped.flat().data(), - in1_reshaped_float.flat().data(), - in1_reshaped.NumElements()); + // TODO: Avoid extra copy to make (b)float16 matmul efficient on CPU. 
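FastConvertToFloat and FastConvertFromFloat above are mapped Eigen casts over contiguous memory, with bfloat16 specializations that defer to the existing vectorized BFloat16ToFloat/FloatToBFloat16 helpers. A standalone sketch of the generic (Eigen::half) path, the bfloat16 specializations omitted:

#include <cstdint>
#include <iostream>
#include <vector>
#include <Eigen/Core>

template <typename T>
void FastConvertToFloat(const T* src, float* dst, int64_t size) {
  Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>> src_map(src, size);
  Eigen::Map<Eigen::ArrayXf> dst_map(dst, size);
  dst_map = src_map.template cast<float>();
}

template <typename T>
void FastConvertFromFloat(const float* src, T* dst, int64_t size) {
  Eigen::Map<const Eigen::ArrayXf> src_map(src, size);
  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> dst_map(dst, size);
  dst_map = src_map.template cast<T>();
}

int main() {
  std::vector<Eigen::half> h = {Eigen::half(1.5f), Eigen::half(-2.0f)};
  std::vector<float> f(h.size());
  FastConvertToFloat(h.data(), f.data(), h.size());

  std::vector<Eigen::half> back(h.size());
  FastConvertFromFloat(f.data(), back.data(), back.size());
  std::cout << f[0] << " " << static_cast<float>(back[1]) << "\n";  // 1.5 -2
}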
+ FastConvertToFloat(in0_reshaped.flat().data(), + in0_reshaped_float.flat().data(), + in0_reshaped.NumElements()); + FastConvertToFloat(in1_reshaped.flat().data(), + in1_reshaped_float.flat().data(), + in1_reshaped.NumElements()); LaunchBatchMatMul::Launch( ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_, - trans_y_, bcast, &out_reshaped_float); - FloatToBFloat16(out_reshaped_float.flat().data(), - out_reshaped.flat().data(), out->NumElements()); + trans_y_, grad_input_1_, grad_input_2_, bcast, &out_reshaped_float); + FastConvertFromFloat(out_reshaped_float.flat().data(), + out_reshaped.flat().data(), + out->NumElements()); } else { // Cast tensor to desired type to reuse Eigen. // TODO(b/178749687): remove this cast if Eigen supports this natively. - if (!std::is_same::value) { + if constexpr (!std::is_same::value) { in0_reshaped = CastTensor(in0_reshaped); } - if (!std::is_same::value) { + if constexpr (!std::is_same::value) { in1_reshaped = CastTensor(in1_reshaped); } - LaunchBatchMatMul::Launch(ctx, in0_reshaped, in1_reshaped, - adj_x_, adj_y_, trans_x_, - trans_y_, bcast, &out_reshaped); + LaunchBatchMatMul::Launch( + ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, trans_x_, trans_y_, + grad_input_1_, grad_input_2_, bcast, &out_reshaped); } } @@ -979,6 +1032,8 @@ class BaseBatchMatMulOp : public OpKernel { bool adj_y_ = false; bool trans_x_ = false; bool trans_y_ = false; + bool grad_input_1_ = false; + bool grad_input_2_ = false; // Cast `t` from `SrcT` to `DstT`. template diff --git a/tensorflow/core/kernels/matmul_op_test.cc b/tensorflow/core/kernels/matmul_op_test.cc index 9d4276c39c2d10..96c37ac97817b8 100644 --- a/tensorflow/core/kernels/matmul_op_test.cc +++ b/tensorflow/core/kernels/matmul_op_test.cc @@ -416,12 +416,7 @@ REGISTER_TYPED_TEST_SUITE_P(FusedMatMulWithBiasOpTest, // MatMul1x256x1WithActivation); // TODO(ezhulenev): Add support for more data types. -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM using FusedBiasAddDataTypes = ::testing::Types; -#else -// CPU doesn't support more data types. 
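The MKL change above maps the "BiasAdd + _FusedHardSwish" pattern onto oneDNN's eltwise_hardswish with alpha = 1/6 and beta = 0.5, i.e. hard_swish(x) = x * max(0, min(1, x/6 + 0.5)), which is the familiar x * relu6(x + 3) / 6 written in oneDNN's alpha/beta form. A quick standalone check that the two formulations agree:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>

// oneDNN-style hard swish: x * clamp(alpha * x + beta, 0, 1) with alpha=1/6, beta=0.5.
float HardSwishDnnl(float x) {
  return x * std::clamp(x / 6.0f + 0.5f, 0.0f, 1.0f);
}

// The common relu6 formulation: x * relu6(x + 3) / 6.
float HardSwishRelu6(float x) {
  return x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
}

int main() {
  for (float x : {-4.0f, -1.0f, 0.0f, 2.5f, 5.0f}) {
    assert(std::fabs(HardSwishDnnl(x) - HardSwishRelu6(x)) < 1e-6f);
  }
  std::cout << HardSwishDnnl(2.5f) << "\n";  // 2.5 * (2.5/6 + 0.5) ~= 2.2917
}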
-using FusedBiasAddDataTypes = ::testing::Types; -#endif INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedMatMulWithBiasOpTest, FusedBiasAddDataTypes); diff --git a/tensorflow/core/kernels/matmul_util.cc b/tensorflow/core/kernels/matmul_util.cc index b91db8d4cd3273..930de6e25ed604 100644 --- a/tensorflow/core/kernels/matmul_util.cc +++ b/tensorflow/core/kernels/matmul_util.cc @@ -166,6 +166,8 @@ StatusOr GetPlanAndAlgorithms( .beta = 0.0, .compute_precision = se::blas::kDefaultComputePrecision, .algorithm = {}, + .grad_x = false, + .grad_y = false, .compute_type = computation_type, }; diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 84486cc1abb9d3..ead6367f5bc3ce 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -1660,6 +1660,10 @@ class MklFusedConvOp OP_REQUIRES(context, num_args == 1, absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); + } else if (fused_ops == std::vector{"BiasAdd", "_FusedHardSwish"}) { + this->set_fuse_biasadd(true); + this->set_fuse_activation(true, dnnl::algorithm::eltwise_hardswish, + 1.0 / 6.0, 0.5); } else if (fused_ops == std::vector{"BiasAdd", "Add"}) { this->set_fuse_biasadd(true); this->set_fuse_add(true); @@ -1831,6 +1835,10 @@ class MklFusedDepthwiseConvOp } else if (fused_ops == std::vector{"BiasAdd", "Elu"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); + } else if (fused_ops == std::vector{"BiasAdd", "_FusedHardSwish"}) { + this->set_fuse_biasadd(true); + this->set_fuse_activation(true, dnnl::algorithm::eltwise_hardswish, + 1.0 / 6.0, 0.5); } else { OP_REQUIRES(context, false, absl::InvalidArgumentError( diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index db697ffe26f6b2..cb43658770bc9d 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -527,7 +527,6 @@ tf_cuda_cc_test( shard_count = 20, tags = tf_cuda_tests_tags() + [ "no_cuda_asan", # b/173033461 - "no_rocm", # failed since 7de9cf4 ], deps = [ ":base_binary_ops_test", diff --git a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc index 9ea628a8673c94..561ca57c67ca11 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc @@ -696,11 +696,13 @@ GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES( /*test_name=*/UInt64, uint64_t, uint64_t, test::DefaultInput(), test::DefaultInputNonZero(), baseline_floor_mod, test::OpsTestConfig().ExpectStrictlyEqual()); +#if !TENSORFLOW_USE_ROCM GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES( FloorMod, /*test_name=*/Half, Eigen::half, Eigen::half, test::DefaultInput(), test::DefaultInputNonZero(), baseline_floor_mod, test::OpsTestConfig()); +#endif GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES( FloorMod, /*test_name=*/Float, float, float, test::DefaultInput(), diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index 407d6991608c7e..8d1220de4d4e0e 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -462,8 +462,7 @@ void DnnPoolingOp::Compute( context->allocate_output(0, tensor_out_shape, &tensor_out)); auto* stream = context->op_device_context()->stream(); - const bool 
cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = !IsBF16SupportedInOps(stream); if (cast_to_float) { Tensor casted_tensor_in; Tensor casted_tensor_out; @@ -876,8 +875,7 @@ void DnnPoolingGradOp::Compute( OP_REQUIRES_OK(context, context->allocate_output(0, tensor_in_shape, &input_backprop)); auto* stream = context->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = !IsBF16SupportedInOps(stream); if (cast_to_float) { Tensor casted_tensor_in; Tensor casted_tensor_out; diff --git a/tensorflow/core/kernels/ragged_cross_op.cc b/tensorflow/core/kernels/ragged_cross_op.cc index 71deb58c3c12d0..c8f27051b449cc 100644 --- a/tensorflow/core/kernels/ragged_cross_op.cc +++ b/tensorflow/core/kernels/ragged_cross_op.cc @@ -22,10 +22,12 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/ragged_utils.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/util/util.h" #include "tensorflow/core/util/work_sharder.h" +#include "tsl/platform/errors.h" namespace tensorflow { @@ -392,28 +394,10 @@ class RaggedCrossOp : public OpKernel { return absl::InvalidArgumentError( "tf.ragged.cross only supports inputs with rank=2."); } - if (ragged_splits_list[i].NumElements() == 0) { - return absl::InvalidArgumentError( - "Invalid RaggedTensor: Ragged splits must be non-empty."); - } - auto flat_row_splits = ragged_splits_list[i].flat(); - if (flat_row_splits(0) != 0) { - return absl::InvalidArgumentError( - "Invalid RaggedTensor: Ragged splits must start from 0."); - } + int64_t num_values = ragged_values_list[i].NumElements(); - if (flat_row_splits(flat_row_splits.size() - 1) != num_values) { - return absl::InvalidArgumentError( - "Invalid RaggedTensor: " - "Ragged splits must end with the number of values."); - } - for (int i = 1; i < flat_row_splits.size(); ++i) { - if (flat_row_splits(i - 1) > flat_row_splits(i)) { - return absl::InvalidArgumentError( - "Invalid RaggedTensor: " - "Ragged splits must be sorted in ascending order."); - } - } + TF_RETURN_IF_ERROR(RaggedTensorVerifySplits( + ragged_splits_list[i], true, num_values)); } for (int i = 0; i < num_sparse; ++i) { if (!TensorShapeUtils::IsMatrix(sparse_indices_list[i].shape()) || diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc index f9d45627bc109c..153fd5a98fea1e 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/kernels/ragged_tensor_variant.h" +#include "tensorflow/core/kernels/ragged_utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/errors.h" @@ -191,34 +192,6 @@ class RaggedTensorToVariantOp : public OpKernel { // Validate nested_row_splits. 
for (int i = ragged_nested_splits_len - 1; i >= 0; --i) { - OP_REQUIRES(context, ragged_nested_splits_in[i].dims() == 1, - errors::InvalidArgument("Requires nested_row_splits[", i, "]", - " to be rank 1 but is rank ", - ragged_nested_splits_in[i].dims())); - OP_REQUIRES( - context, ragged_nested_splits_in[i].dim_size(0) >= 1, - errors::InvalidArgument("Requires nested_row_splits[", i, "]", - " has at least one splits, but is empty.")); - OP_REQUIRES(context, - ragged_nested_splits_in[i].flat()(0) == - static_cast(0), - errors::InvalidArgument( - "Requires the first element of nested_row_splits[", i, - "]", " to be 0 but is ", - ragged_nested_splits_in[i].flat()(0))); - - SPLIT_TYPE last_split = 0; - for (int j = 1; j < ragged_nested_splits_in[i].dim_size(0); j++) { - auto split = ragged_nested_splits_in[i].flat()(j); - OP_REQUIRES( - context, split >= last_split, - errors::InvalidArgument("Requires splits to be monotonically " - "increasing, but nested_row_splits[", - i, "][", j, "]=", split, - " is smaller than nested_row_splits[", i, - "][", j - 1, "]=", last_split)); - last_split = split; - } SPLIT_TYPE nvals; if (i == ragged_nested_splits_len - 1) { OP_REQUIRES(context, batched_ragged_input.values().dims() >= 1, @@ -230,12 +203,8 @@ class RaggedTensorToVariantOp : public OpKernel { nvals = ragged_nested_splits_in[i + 1].dim_size(0) - 1; } - OP_REQUIRES(context, last_split == nvals, - errors::InvalidArgument("Requires nested_row_splits[", i, - "][-1]=", last_split, - " to be equal with the number of " - "values in this dimension, which is ", - nvals, ".")); + OP_REQUIRES_OK(context, RaggedTensorVerifySplits( + ragged_nested_splits_in[i], true, nvals)); } for (int i = 0; i < ragged_nested_splits_len; i++) { @@ -290,6 +259,15 @@ class RaggedTensorToVariantGradientOp : public OpKernel { TensorShapeUtils::MakeShape(context->input(2).vec(), &dense_values_shape)); + // Validate row_splits. + // Note rank of the row_splits can be 0. Besides, the number of ragged + // values corresponding to the outermost splits are unknown when calculating + // the gradient so we don't check the last element of `row_splits` + if (row_splits.dims()) { + OP_REQUIRES_OK( + context, RaggedTensorVerifySplits(row_splits, false, 0)); + } + const auto& flat_variants = encoded_variant.flat(); // Get a Tensor containing the flat_values for each variant. 
diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc index adfe17667315c8..ac04580f4dec11 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc @@ -390,8 +390,8 @@ TEST_F(RaggedTensorToVariantKernelTest, true); EXPECT_THAT(RunOpKernel(), testing::StatusIs(error::INVALID_ARGUMENT, - "Requires the first element of " - "nested_row_splits[0] to be 0 but is 1")); + "Invalid ragged splits: first element of " + "ragged splits must be 0 but is 1")); } TEST_F(RaggedTensorToVariantKernelTest, NestedRowSplitsIncreasingError) { @@ -400,9 +400,10 @@ TEST_F(RaggedTensorToVariantKernelTest, NestedRowSplitsIncreasingError) { true); EXPECT_THAT(RunOpKernel(), testing::StatusIs(error::INVALID_ARGUMENT, - "Requires splits to be monotonically " - "increasing, but nested_row_splits[0][2]=-1 is " - "smaller than nested_row_splits[0][1]=2")); + "Invalid ragged splits: ragged splits must be " + "monotonically increasing, but " + "ragged_splits[2]=-1 is smaller than " + "row_splits[1]=2")); } TEST_F(RaggedTensorToVariantKernelTest, NestedRowSplitsSizeMismatchError) { @@ -412,8 +413,8 @@ TEST_F(RaggedTensorToVariantKernelTest, NestedRowSplitsSizeMismatchError) { EXPECT_THAT( RunOpKernel(), testing::StatusIs(error::INVALID_ARGUMENT, - "Requires nested_row_splits[0][-1]=3 to be equal with " - "the number of values in this dimension, which is 5.")); + "Invalid ragged splits: last element of ragged splits " + "must be the number of ragged values(5) but is 3")); } TEST_F(RaggedTensorToVariantKernelTest, @@ -425,8 +426,8 @@ TEST_F(RaggedTensorToVariantKernelTest, EXPECT_THAT( RunOpKernel(), testing::StatusIs(error::INVALID_ARGUMENT, - "Requires nested_row_splits[1][-1]=4 to be equal with " - "the number of values in this dimension, which is 5.")); + "Invalid ragged splits: last element of ragged splits " + "must be the number of ragged values(5) but is 4")); } TEST_F(RaggedTensorToVariantKernelTest, @@ -438,8 +439,8 @@ TEST_F(RaggedTensorToVariantKernelTest, EXPECT_THAT( RunOpKernel(), testing::StatusIs(error::INVALID_ARGUMENT, - "Requires nested_row_splits[0][-1]=2 to be equal with " - "the number of values in this dimension, which is 3.")); + "Invalid ragged splits: last element of ragged splits " + "must be the number of ragged values(3) but is 2")); } TEST_F(RaggedTensorToVariantKernelTest, NestedRowSplitsEmptySplitsError) { @@ -448,8 +449,8 @@ TEST_F(RaggedTensorToVariantKernelTest, NestedRowSplitsEmptySplitsError) { {0, 1, 2, 3, 4}, true); EXPECT_THAT(RunOpKernel(), testing::StatusIs(error::INVALID_ARGUMENT, - "Requires nested_row_splits[0] has at least " - "one splits, but is empty.")); + "Invalid ragged splits: ragged splits must " + "have at least one splits, but is empty")); } TEST_F(RaggedTensorToVariantKernelTest, NestedRowSplitsScalarValueError) { @@ -462,5 +463,83 @@ TEST_F(RaggedTensorToVariantKernelTest, NestedRowSplitsScalarValueError) { "nested_row_splits is not empty, but is 0.")); } +TEST_F(RaggedTensorToVariantGradientKernelTest, RowSplitsMatch) { + // encoded_variant_grad= + // [ [1, 2, 3], + // [ ], + // [4, 5 ], + // [6 ]] + auto encoded_variant_grad_1 = + CreateVariantFromRagged({}, {3}, {1, 2, 3}); + auto encoded_variant_grad_2 = + CreateVariantFromRagged({}, {0}, {}); + auto encoded_variant_grad_3 = + CreateVariantFromRagged({}, {2}, {4, 5}); + auto encoded_variant_grad_4 = + CreateVariantFromRagged({}, {1}, {6}); + + 
BuildEncodeRaggedTensorGradientGraph( + {encoded_variant_grad_1, encoded_variant_grad_2, encoded_variant_grad_3, + encoded_variant_grad_4}, + {0, 3, 3, 5, 6}, {6}); + + TF_ASSERT_OK(RunOpKernel()); +} + +TEST_F(RaggedTensorToVariantGradientKernelTest, + RowSplitsFirstElementNotZeroError) { + // encoded_variant_grad= + // [ [1, 2, 3], + // [ ], + // [4, 5 ], + // [6 ]] + auto encoded_variant_grad_1 = + CreateVariantFromRagged({}, {3}, {1, 2, 3}); + auto encoded_variant_grad_2 = + CreateVariantFromRagged({}, {0}, {}); + auto encoded_variant_grad_3 = + CreateVariantFromRagged({}, {2}, {4, 5}); + auto encoded_variant_grad_4 = + CreateVariantFromRagged({}, {1}, {6}); + + BuildEncodeRaggedTensorGradientGraph( + {encoded_variant_grad_1, encoded_variant_grad_2, encoded_variant_grad_3, + encoded_variant_grad_4}, + {1, 3, 3, 5, 6}, {6}); + + EXPECT_THAT(RunOpKernel(), + testing::StatusIs(error::INVALID_ARGUMENT, + "Invalid ragged splits: first element of " + "ragged splits must be 0 but is 1")); +} + +TEST_F(RaggedTensorToVariantGradientKernelTest, RowSplitsIncreasingError) { + // encoded_variant_grad= + // [ [1, 2, 3], + // [ ], + // [4, 5 ], + // [6 ]] + auto encoded_variant_grad_1 = + CreateVariantFromRagged({}, {3}, {1, 2, 3}); + auto encoded_variant_grad_2 = + CreateVariantFromRagged({}, {0}, {}); + auto encoded_variant_grad_3 = + CreateVariantFromRagged({}, {2}, {4, 5}); + auto encoded_variant_grad_4 = + CreateVariantFromRagged({}, {1}, {6}); + + BuildEncodeRaggedTensorGradientGraph( + {encoded_variant_grad_1, encoded_variant_grad_2, encoded_variant_grad_3, + encoded_variant_grad_4}, + {0, 3, 2, 5, 6}, {6}); + + EXPECT_THAT(RunOpKernel(), + testing::StatusIs(error::INVALID_ARGUMENT, + "Invalid ragged splits: ragged splits must be " + "monotonically increasing, but " + "ragged_splits[2]=2 is smaller than " + "row_splits[1]=3")); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.h b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.h index 0b71a308b2c503..7dc63ac8fbf7f8 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.h +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include @@ -129,6 +130,60 @@ class RaggedTensorToVariantKernelTest : public ::tensorflow::OpsTestBase { } }; +class RaggedTensorToVariantGradientKernelTest + : public ::tensorflow::OpsTestBase { + protected: + // Builds the tensorflow test graph for the RaggedTensorToVariantGradient op, + // and populates the `encoded_ragged_grad`, `row_splits` and + // `dense_values_shape` input with the given values. 
+ template + void BuildEncodeRaggedTensorGradientGraph( + const std::vector& encoded_ragged_grad, + const std::vector& row_splits, + const std::vector& dense_values_shape) { + const auto values_dtype = DataTypeToEnum::v(); + const auto splits_dtype = DataTypeToEnum::v(); + + TF_ASSERT_OK(NodeDefBuilder("tested_op", "RaggedTensorToVariantGradient") + .Input(FakeInput(DT_VARIANT)) // encoded_ragged_grad + .Input(FakeInput(splits_dtype)) // row_splits + .Input(FakeInput(DT_INT32)) // dense_values_shape + .Attr("Tvalues", values_dtype) + .Attr("Tsplits", splits_dtype) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + + int64_t encoded_ragged_grad_size = encoded_ragged_grad.size(); + AddInputFromArray(TensorShape({encoded_ragged_grad_size}), + encoded_ragged_grad); + + int64_t splits_size = row_splits.size(); + AddInputFromArray(TensorShape({splits_size}), row_splits); + + int64_t dense_values_shape_size = dense_values_shape.size(); + AddInputFromArray(TensorShape({dense_values_shape_size}), + dense_values_shape); + } + + template + RaggedTensorVariant CreateVariantFromRagged( + const std::vector>& ragged_splits, + const TensorShape& ragged_values_shape, + const std::vector& ragged_values) { + RaggedTensorVariant encoded; + for (auto ragged_split : ragged_splits) { + int splits_size = ragged_split.size(); + Tensor splits(DataTypeToEnum::v(), + TensorShape({splits_size})); + test::FillValues(&splits, ragged_split); + encoded.append_splits(splits); + } + Tensor values(DataTypeToEnum::v(), ragged_values_shape); + test::FillValues(&values, ragged_values); + encoded.set_values(values); + return encoded; + } +}; } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_TO_VARIANT_OP_TEST_H_ diff --git a/tensorflow/core/kernels/ragged_utils.h b/tensorflow/core/kernels/ragged_utils.h new file mode 100644 index 00000000000000..f91f1da343993f --- /dev/null +++ b/tensorflow/core/kernels/ragged_utils.h @@ -0,0 +1,77 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_RAGGED_UTILS_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// Utility functions for RaggedTensor + +// Verifies that the splits are valid for ragged tensor +template +Status RaggedTensorVerifySplits(const Tensor& ragged_splits, + bool check_last_element, + int64_t num_ragged_values) { + auto flat_ragged_splits = ragged_splits.flat(); + + if (ragged_splits.dims() != 1) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid ragged splits: ragged splits must be rank 1 but is rank ", + ragged_splits.dims())); + } + + if (ragged_splits.NumElements() < 1) { + return absl::InvalidArgumentError( + "Invalid ragged splits: ragged splits must have at least one splits, " + "but is empty"); + } + + if (flat_ragged_splits(0) != static_cast(0)) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid ragged splits: first element of ragged splits " + " must be 0 but is ", + flat_ragged_splits(0))); + } + + SPLIT_TYPE last_split = 0; + for (int j = 1; j < ragged_splits.dim_size(0); j++) { + auto split = flat_ragged_splits(j); + if (split < last_split) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid ragged splits: ragged splits must be " + "monotonically increasing, but ragged_splits[", + j, "]=", split, " is smaller than row_splits[", j - 1, + "]=", last_split)); + } + last_split = split; + } + + if (check_last_element & last_split != num_ragged_values) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid ragged splits: last element of ragged splits must be ", + "the number of ragged values(", num_ragged_values, ") but is ", + last_split)); + } + + return absl::OkStatus(); +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RAGGED_UTILS_H_ diff --git a/tensorflow/core/kernels/roll_op.cc b/tensorflow/core/kernels/roll_op.cc index ac5a410a261ed9..ae9169df96ac3a 100644 --- a/tensorflow/core/kernels/roll_op.cc +++ b/tensorflow/core/kernels/roll_op.cc @@ -15,14 +15,19 @@ limitations under the License. 
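RaggedTensorVerifySplits above consolidates the row-splits validation that RaggedCross and RaggedTensorToVariant previously duplicated: the splits tensor must be rank 1, non-empty, start at 0, be monotonically non-decreasing, and (when check_last_element is set) end at the number of ragged values. A simplified stand-in over a plain std::vector, handy for mapping sample inputs to the new error messages (the strings below paraphrase, not reproduce, the kernel's messages):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for RaggedTensorVerifySplits operating on a std::vector.
std::string VerifySplits(const std::vector<int64_t>& splits,
                         bool check_last_element, int64_t num_values) {
  if (splits.empty()) return "ragged splits must have at least one split";
  if (splits.front() != 0) return "first element of ragged splits must be 0";
  for (size_t j = 1; j < splits.size(); ++j) {
    if (splits[j] < splits[j - 1]) {
      return "ragged splits must be monotonically increasing";
    }
  }
  if (check_last_element && splits.back() != num_values) {
    return "last element of ragged splits must be the number of ragged values";
  }
  return "ok";
}

int main() {
  std::cout << VerifySplits({0, 3, 3, 5}, true, 5) << "\n";  // ok
  std::cout << VerifySplits({1, 3, 5}, true, 5) << "\n";     // first element...
  std::cout << VerifySplits({0, 3, 2, 5}, true, 5) << "\n";  // monotonically...
  std::cout << VerifySplits({0, 3}, true, 5) << "\n";        // last element...
}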
#include "tensorflow/core/kernels/roll_op.h" +#include +#include + #include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/register_types_traits.h" -#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/work_sharder.h" @@ -191,9 +196,9 @@ void DoRollWithMemcpy(const OpKernelContext* context, int64_t start, int64_t end) { // the number of indices over in the flattened tensor you need to skip in // order to make it over from one side of the isd to the other - const int64_t isd_range = std::max(dim_range[isd], 1); - // the distance along the flattend tensor to the next element in the isd - const int64_t isd_stride = isd_range / std::max(dim_size[isd], 1); + const int64_t isd_range = std::max(dim_range[isd], 1); + // the distance along the flattened tensor to the next element in the isd + const int64_t isd_stride = isd_range / std::max(dim_size[isd], 1); // start and end represent the i-th group currently so we will convert // them into numbers representing the i-th elements. @@ -295,9 +300,10 @@ void DoRollWithMemcpy(const OpKernelContext* context, // Shard auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); const int64_t ave_group_size = dim_range[isd] / 2; - const int total_work = 2 * num_elements / std::max(dim_range[isd], 1); + const int64_t total_work = + 2 * num_elements / std::max(dim_range[isd], 1); // 25000 - experimentally determined with float and bool types - const int cost_per_group = 25000 * sizeof(T) * ave_group_size; + const int64_t cost_per_group = 25000 * sizeof(T) * ave_group_size; Shard(worker_threads->num_threads, worker_threads->workers, total_work, cost_per_group, std::move(work)); } diff --git a/tensorflow/core/kernels/tensor_to_hash_bucket_op.cc b/tensorflow/core/kernels/tensor_to_hash_bucket_op.cc index d031461318df3e..eb58830bff17d2 100644 --- a/tensorflow/core/kernels/tensor_to_hash_bucket_op.cc +++ b/tensorflow/core/kernels/tensor_to_hash_bucket_op.cc @@ -74,7 +74,7 @@ TF_CALL_INTEGRAL_TYPES(REGISTER_CPU_KERNELS); #undef REGISTER_CPU_KERNELS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU_KERNELS(type) \ REGISTER_KERNEL_BUILDER(Name("_TensorToHashBucketFast") \ @@ -86,6 +86,6 @@ TF_CALL_INTEGRAL_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/tensor_to_hash_bucket_op.h b/tensorflow/core/kernels/tensor_to_hash_bucket_op.h index 6c75b8cffccd10..cdf7dab23947d9 100644 --- a/tensorflow/core/kernels/tensor_to_hash_bucket_op.h +++ b/tensorflow/core/kernels/tensor_to_hash_bucket_op.h @@ -66,13 +66,13 @@ struct LaunchTensorToHashBucket { } }; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM template struct LaunchTensorToHashBucket { void operator()(OpKernelContext* c, const int64_t num_buckets, const T* input, const int num_elems, 
int64_t* output); }; -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/tensor_to_hash_bucket_op_gpu.cu.cc b/tensorflow/core/kernels/tensor_to_hash_bucket_op_gpu.cu.cc index 8e79c7929f013c..8b6b0d48ecc461 100644 --- a/tensorflow/core/kernels/tensor_to_hash_bucket_op_gpu.cu.cc +++ b/tensorflow/core/kernels/tensor_to_hash_bucket_op_gpu.cu.cc @@ -10,7 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive @@ -119,4 +119,4 @@ TF_CALL_INTEGRAL_TYPES(REGISTER_FUNCTORS); #undef REGISTER_FUNCTORS } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/topk_op_gpu.h b/tensorflow/core/kernels/topk_op_gpu.h index 152dbcd96a0e63..26162abc2f8f80 100644 --- a/tensorflow/core/kernels/topk_op_gpu.h +++ b/tensorflow/core/kernels/topk_op_gpu.h @@ -483,25 +483,16 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows, bool ran_nonsegmented_version = false; if (num_rows == 1) { -#if GOOGLE_CUDA - constexpr bool is_supported = true; -#else - // GpuRadixSortDescending is not supported on ROCm for fp16/bf16. - constexpr bool is_supported = !std::is_same::value && - !std::is_same::value; -#endif - if constexpr (is_supported) { - // Note: DeviceSegmentedRadixSort is very slow when num_segments=1 because - // it only uses 1 SM per segment. Calling the un-segmented version is much - // faster in this case. - TF_RETURN_IF_ERROR( - GpuRadixSortDescending(ctx, num_cols, /*keys_in=*/input, - /*keys_out=*/sorted_values_ptr, - /*indices_in=*/input_indices_t.data(), - /*indices_out=*/sorted_indices_ptr, - /*num_bits=*/sizeof(T) * 8)); - ran_nonsegmented_version = true; - } + // Note: DeviceSegmentedRadixSort is very slow when num_segments=1 because + // it only uses 1 SM per segment. Calling the un-segmented version is much + // faster in this case. + TF_RETURN_IF_ERROR( + GpuRadixSortDescending(ctx, num_cols, /*keys_in=*/input, + /*keys_out=*/sorted_values_ptr, + /*indices_in=*/input_indices_t.data(), + /*indices_out=*/sorted_indices_ptr, + /*num_bits=*/sizeof(T) * 8)); + ran_nonsegmented_version = true; } if (!ran_nonsegmented_version) { auto err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending( diff --git a/tensorflow/core/lib/png/BUILD b/tensorflow/core/lib/png/BUILD index 46f167b5e60302..cdd3491276c9a3 100644 --- a/tensorflow/core/lib/png/BUILD +++ b/tensorflow/core/lib/png/BUILD @@ -15,6 +15,7 @@ cc_library( name = "png_io", srcs = ["png_io.cc"], hdrs = ["png_io.h"], + features = ["-layering_check"], deps = [ "//tensorflow/core/platform:byte_order", "//tensorflow/core/platform:logging", @@ -22,6 +23,7 @@ cc_library( "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:types", "@com_google_absl//absl/base", + "@png", "@zlib", ], ) diff --git a/tensorflow/core/lib/png/png_io.cc b/tensorflow/core/lib/png/png_io.cc index 2bfbe4470d94a8..f07861a0be6808 100644 --- a/tensorflow/core/lib/png/png_io.cc +++ b/tensorflow/core/lib/png/png_io.cc @@ -25,6 +25,7 @@ limitations under the License. // provokes a compile error. We instead let png.h include what is needed. 
#include "absl/base/casts.h" +#include "png.h" // from @png #include "tensorflow/core/lib/png/png_io.h" #include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/logging.h" @@ -77,7 +78,7 @@ static void Convert8to16(const uint8* p8, int num_comps, int p8_row_bytes, void ErrorHandler(png_structp png_ptr, png_const_charp msg) { DecodeContext* const ctx = - absl::bit_cast(png_get_io_ptr(png_ptr)); + absl::bit_cast(png_get_error_ptr(png_ptr)); ctx->error_condition = true; // To prevent log spam, errors are logged as VLOG(1) instead of ERROR. VLOG(1) << "PNG error: " << msg; @@ -354,8 +355,9 @@ bool WriteImageToBuffer( png_string->resize(0); png_infop info_ptr = nullptr; - png_structp png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, nullptr, - ErrorHandler, WarningHandler); + DecodeContext decode_context; + png_structp png_ptr = png_create_write_struct( + PNG_LIBPNG_VER_STRING, &decode_context, ErrorHandler, WarningHandler); if (png_ptr == nullptr) return false; if (setjmp(png_jmpbuf(png_ptr))) { png_destroy_write_struct(&png_ptr, info_ptr ? &info_ptr : nullptr); diff --git a/tensorflow/core/ops/compat/ops_history_v2/BatchMatMul.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/BatchMatMul.pbtxt index 747bd5ed03115c..9d7ac3ca8e2a33 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/BatchMatMul.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/BatchMatMul.pbtxt @@ -174,3 +174,62 @@ op { } } } +op { + name: "BatchMatMul" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + attr { + name: "adj_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "adj_y" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_y" + type: "bool" + default_value { + b: false + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/BatchMatMulV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/BatchMatMulV2.pbtxt index aa446c0a492155..4769d8220f53e1 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/BatchMatMulV2.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/BatchMatMulV2.pbtxt @@ -139,3 +139,67 @@ op { } } } +op { + name: "BatchMatMulV2" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_UINT16 + type: DT_UINT32 + type: DT_UINT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + attr { + name: "adj_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "adj_y" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_y" + type: "bool" + default_value { + b: false + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/BatchMatMulV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/BatchMatMulV3.pbtxt index 332af934efbb23..1bcfdb937064ca 100644 --- 
a/tensorflow/core/ops/compat/ops_history_v2/BatchMatMulV3.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/BatchMatMulV3.pbtxt @@ -164,3 +164,101 @@ op { } } } +op { + name: "BatchMatMulV3" + input_arg { + name: "x" + type_attr: "Ta" + } + input_arg { + name: "y" + type_attr: "Tb" + } + output_arg { + name: "output" + type_attr: "Tout" + } + attr { + name: "Ta" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + attr { + name: "Tb" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + attr { + name: "Tout" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + attr { + name: "adj_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "adj_y" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_y" + type: "bool" + default_value { + b: false + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/GlobalIterId.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/GlobalIterId.pbtxt new file mode 100644 index 00000000000000..5fa2302622c9ac --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/GlobalIterId.pbtxt @@ -0,0 +1,8 @@ +op { + name: "GlobalIterId" + output_arg { + name: "iter_id" + type: DT_INT64 + } + is_stateful: true +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/ListSnapshotChunksDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/ListSnapshotChunksDataset.pbtxt new file mode 100644 index 00000000000000..be35470141fb28 --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/ListSnapshotChunksDataset.pbtxt @@ -0,0 +1,44 @@ +op { + name: "ListSnapshotChunksDataset" + input_arg { + name: "snapshot_path" + type: DT_STRING + } + output_arg { + name: "handle" + type: DT_VARIANT + experimental_full_type { + type_id: TFT_DATASET + args { + type_id: TFT_FOR_EACH + args { + type_id: TFT_PRODUCT + } + args { + type_id: TFT_TENSOR + args { + type_id: TFT_VAR + s: "output_types" + } + } + args { + type_id: TFT_VAR + s: "output_types" + } + } + } + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/MatMul.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/MatMul.pbtxt index 369f763e9472c5..8f79fa11000f7f 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/MatMul.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/MatMul.pbtxt @@ -223,3 +223,66 @@ op { } } } +op { + name: "MatMul" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "b" + type_attr: "T" + } + output_arg { + name: "product" + type_attr: "T" + } + attr { + name: "transpose_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "transpose_b" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: 
DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_UINT16 + type: DT_UINT32 + type: DT_UINT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + attr { + name: "grad_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_b" + type: "bool" + default_value { + b: false + } + } +} diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 1203d36a8951c0..396e1720aaf2fd 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -1252,6 +1252,21 @@ REGISTER_OP("SnapshotNestedDatasetReader") "output_types")) .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("ListSnapshotChunksDataset") + .Input("snapshot_path: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() + .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, + "output_types")) + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // `snapshot_path` should be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + return shape_inference::ScalarShape(c); + }); + REGISTER_OP("SqlDataset") .Input("driver_name: string") .Input("data_source_name: string") diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index a5356ab9d38911..d54750253f32e3 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -126,6 +126,8 @@ REGISTER_OP("BatchMatMul") "complex128}") .Attr("adj_x: bool = false") .Attr("adj_y: bool = false") + .Attr("grad_x: bool = false") + .Attr("grad_y: bool = false") .SetShapeFn(shape_inference::BatchMatMulShape); REGISTER_OP("BatchMatMulV2") @@ -137,6 +139,8 @@ REGISTER_OP("BatchMatMulV2") "uint16, uint32, uint64, complex64, complex128}") .Attr("adj_x: bool = false") .Attr("adj_y: bool = false") + .Attr("grad_x: bool = false") + .Attr("grad_y: bool = false") .SetShapeFn(shape_inference::BatchMatMulV2Shape); REGISTER_OP("BatchMatMulV3") @@ -154,6 +158,8 @@ REGISTER_OP("BatchMatMulV3") "complex128}") .Attr("adj_x: bool = false") .Attr("adj_y: bool = false") + .Attr("grad_x: bool = false") + .Attr("grad_y: bool = false") .SetShapeFn(shape_inference::BatchMatMulV2Shape); #ifdef INTEL_MKL @@ -164,6 +170,8 @@ REGISTER_OP("_MklBatchMatMul") .Attr("T: {bfloat16, float}") .Attr("adj_x: bool = false") .Attr("adj_y: bool = false") + .Attr("grad_x: bool = false") + .Attr("grad_y: bool = false") .SetShapeFn(shape_inference::BatchMatMulShape); REGISTER_OP("_MklBatchMatMulV2") @@ -173,6 +181,8 @@ REGISTER_OP("_MklBatchMatMulV2") .Attr("T: {bfloat16, float}") .Attr("adj_x: bool = false") .Attr("adj_y: bool = false") + .Attr("grad_x: bool = false") + .Attr("grad_y: bool = false") .SetShapeFn(shape_inference::BatchMatMulV2Shape); #endif // INTEL_MKL @@ -953,6 +963,8 @@ REGISTER_OP("MatMul") .Attr( "T: {bfloat16, half, float, double, int32, int64, uint8, " "uint16, uint32, uint64, complex64, complex128}") + .Attr("grad_a: bool = false") + .Attr("grad_b: bool = false") .SetShapeFn(shape_inference::MatMulShape); #ifdef INTEL_MKL @@ -963,6 +975,8 @@ REGISTER_OP("_MklMatMul") .Attr("transpose_a: bool = false") .Attr("transpose_b: bool = false") .Attr("T: {bfloat16, float}") + .Attr("grad_a: bool = false") + .Attr("grad_b: bool = false") .SetShapeFn(shape_inference::MatMulShape); #endif // INTEL_MKL diff --git 
a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 2bcdc8b109329d..95a20f13f87522 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -4204,6 +4204,20 @@ op { b: false } } + attr { + name: "grad_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_y" + type: "bool" + default_value { + b: false + } + } } op { name: "BatchMatMulV2" @@ -4254,6 +4268,20 @@ op { b: false } } + attr { + name: "grad_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_y" + type: "bool" + default_value { + b: false + } + } } op { name: "BatchMatMulV3" @@ -4338,6 +4366,20 @@ op { b: false } } + attr { + name: "grad_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_y" + type: "bool" + default_value { + b: false + } + } } op { name: "BatchMatrixBandPart" @@ -22034,6 +22076,14 @@ op { } is_stateful: true } +op { + name: "GlobalIterId" + output_arg { + name: "iter_id" + type: DT_INT64 + } + is_stateful: true +} op { name: "Greater" input_arg { @@ -25210,6 +25260,50 @@ op { } } } +op { + name: "ListSnapshotChunksDataset" + input_arg { + name: "snapshot_path" + type: DT_STRING + } + output_arg { + name: "handle" + type: DT_VARIANT + experimental_full_type { + type_id: TFT_DATASET + args { + type_id: TFT_FOR_EACH + args { + type_id: TFT_PRODUCT + } + args { + type_id: TFT_TENSOR + args { + type_id: TFT_VAR + s: "output_types" + } + } + args { + type_id: TFT_VAR + s: "output_types" + } + } + } + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "LoadAllTPUEmbeddingParameters" input_arg { @@ -27092,6 +27186,20 @@ op { } } } + attr { + name: "grad_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "grad_b" + type: "bool" + default_value { + b: false + } + } } op { name: "MatchingFiles" diff --git a/tensorflow/core/platform/build_config.default.bzl b/tensorflow/core/platform/build_config.default.bzl index 24421c6d6e8b87..80c7d25ad1dd9e 100644 --- a/tensorflow/core/platform/build_config.default.bzl +++ b/tensorflow/core/platform/build_config.default.bzl @@ -1,24 +1,23 @@ """OSS versions of Bazel macros that can't be migrated to TSL.""" +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load( + "@local_xla//xla:xla.bzl", + _xla_clean_dep = "clean_dep", +) load( "//tensorflow/core/platform:build_config_root.bzl", "if_static", ) load( - "@local_xla//xla:xla.bzl", - _xla_clean_dep = "clean_dep", + "//third_party/mkl:build_defs.bzl", + "if_mkl_ml", ) load( "@local_tsl//tsl:tsl.bzl", "if_libtpu", _tsl_clean_dep = "clean_dep", ) -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load( - "//third_party/mkl:build_defs.bzl", - "if_mkl_ml", -) def tf_tpu_dependencies(): return if_libtpu(["//tensorflow/core/tpu/kernels"]) @@ -34,11 +33,7 @@ def tf_additional_binary_deps(): # core. 
str(Label("//tensorflow/core/kernels:lookup_util")), str(Label("//tensorflow/core/util/tensor_bundle")), - ] + if_cuda( - [ - str(Label("@local_xla//xla/stream_executor:cuda_platform")), - ], - ) + if_rocm( + ] + if_rocm( [ str(Label("@local_xla//xla/stream_executor:rocm_platform")), str(Label("@local_xla//xla/stream_executor/rocm:rocm_rpath")), diff --git a/tensorflow/core/platform/host_info.h b/tensorflow/core/platform/host_info.h index 89d495d6e41229..caab7ae380b31b 100644 --- a/tensorflow/core/platform/host_info.h +++ b/tensorflow/core/platform/host_info.h @@ -22,6 +22,7 @@ limitations under the License. namespace tensorflow { namespace port { using tsl::port::Hostname; +using tsl::port::IOStatistics; using tsl::port::JobName; using tsl::port::JobUid; } // namespace port diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 3590422f92d5b5..27d9d649c84cb3 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -1015,6 +1015,7 @@ cc_library( srcs = ["xspace_to_dcn_slack_analysis.cc"], hdrs = ["xspace_to_dcn_slack_analysis.h"], deps = [ + "//tensorflow/core/profiler/protobuf:dcn_collective_info_proto_cc", "//tensorflow/core/profiler/protobuf:dcn_slack_analysis_proto_cc", "//tensorflow/core/profiler/protobuf:topology_proto_cc", "//tensorflow/core/profiler/utils:hlo_module_utils", diff --git a/tensorflow/core/profiler/convert/op_profile_builder.cc b/tensorflow/core/profiler/convert/op_profile_builder.cc index 2111ea4f56ac6e..124d4096518f95 100644 --- a/tensorflow/core/profiler/convert/op_profile_builder.cc +++ b/tensorflow/core/profiler/convert/op_profile_builder.cc @@ -185,9 +185,12 @@ void PopulateOpMetricsNode( // https://github.com/tensorflow/profiler/blob/master/frontend/app/common/utils/utils.ts metrics->set_raw_time(op_metrics.time_ps()); metrics->set_raw_flops(op_metrics.flops()); + metrics->set_occurrences(op_metrics.occurrences()); + metrics->set_avg_time_ps( + SafeDivide(op_metrics.time_ps(), op_metrics.occurrences())); // Hack to approximate utilization for INT8/4 convolution HLOs: - // Since MXU BW is 2x/4x for INT8/4, multiply peak BW by the factor detemrined + // Since MXU BW is 2x/4x for INT8/4, multiply peak BW by the factor determined // by the computation size if (GetComputationSize(*node) == 8) { peak_gigaflops_per_second_per_core *= 2; diff --git a/tensorflow/core/profiler/convert/process_megascale_dcn.cc b/tensorflow/core/profiler/convert/process_megascale_dcn.cc index dc740ffd9090c0..947c5e54a19568 100644 --- a/tensorflow/core/profiler/convert/process_megascale_dcn.cc +++ b/tensorflow/core/profiler/convert/process_megascale_dcn.cc @@ -47,6 +47,8 @@ void ProcessMegascaleDcn(XSpace* space) { for (XPlane* device_xplane : device_xplanes) { dcn_events_processor.AddTpuCollectiveDcnTrafficToXPlane(device_xplane); } + + SortXSpace(space); } } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc index 53f94371d1a96d..cfb4f2ec20cf4b 100644 --- a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc +++ b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc @@ -207,6 +207,7 @@ void ConvertXPlaneToTraceEventsContainer(uint64_t device_id, } plane.ForEachLine([&](const XLineVisitor& line) { + if (line.DisplayName() == tsl::profiler::kXlaAsyncOpLineName) return; if (line.NumEvents() == 0) return; // Capture a copy of XLineVisitor because it will 
go out of scope. uint32_t device_id = resource_grouper->GetDeviceId(line.DisplayId()); @@ -241,7 +242,7 @@ void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname, for (const XPlane* custom_plane : FindPlanesWithPrefix(space, tsl::profiler::kCustomPlanePrefix)) { ConvertXPlaneToTraceEventsContainer( - tsl::profiler::kCustomPlaneDeviceId + custom_plane->id(), hostname, + tsl::profiler::kFirstCustomPlaneDeviceId + custom_plane->id(), hostname, *custom_plane, container); } } diff --git a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc index 95b6342c525133..f85d0f92904c80 100644 --- a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc +++ b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/side_effect_util.h" #include "xla/xla_data.pb.h" +#include "tensorflow/core/profiler/protobuf/dcn_collective_info.pb.h" #include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" #include "tensorflow/core/profiler/protobuf/topology.pb.h" #include "tensorflow/core/profiler/utils/hlo_module_utils.h" @@ -74,7 +75,7 @@ using xla::HloOpcode; // TODO: Identify mechanism to maintain consistency between producer and // consumer here. const char kHostEventRegex[] = { - "device_[0-9][0-9][0-9]([0-9][0-9][0-9])_gid_(.*)"}; + "device_[0-9]+([0-9][0-9][0-9][0-9][0-9])_gid_(.*)"}; std::optional GetAttributeFromInstr( const xla::HloInstruction* instr, std::string_view attribute) { @@ -110,6 +111,21 @@ std::string HostCollectiveKey(int index_on_host, return absl::StrCat(index_on_host, "_", rendezvous_name); } +DcnCollectiveInfoProto GetDcnCollectiveInfoProto(const XEventVisitor& xevent) { + DcnCollectiveInfoProto dcn_collective_info; + xevent.Metadata().ForEachStat([&](const XStatVisitor& xstat) { + if (static_cast(*xstat.Type()) == StatType::kDcnCollectiveInfo) { + absl::string_view byte_value = xstat.BytesValue(); + if (!dcn_collective_info.ParseFromArray(byte_value.data(), + byte_value.size())) { + LOG(WARNING) << "Could not parse DcnCollectiveInfoProto from metadata."; + } + } + }); + + return dcn_collective_info; +} + } // namespace namespace dcn_analysis_internal { @@ -203,6 +219,55 @@ void DcnTracker::UpdateActiveOps(uint64_t duration) { } } +int DcnTracker::GetReplicaGroupSize(const std::string& rendezvous_name, + const XEventVisitor& visitor) { + if (rendezvous_to_replica_group_size_map_.contains(rendezvous_name)) { + return rendezvous_to_replica_group_size_map_[rendezvous_name]; + } + + DcnCollectiveInfoProto dcn_collective_info = + GetDcnCollectiveInfoProto(visitor); + + if (dcn_collective_info.one_to_one_groups_size() != 0) { + // OneToOneGroup has a source and a destination, which is one replica group + rendezvous_to_replica_group_size_map_[rendezvous_name] = 1; + } else if (dcn_collective_info.endpoint_groups_size() != 0) { + rendezvous_to_replica_group_size_map_[rendezvous_name] = + dcn_collective_info.endpoint_groups(0).endpoints().size(); + } else { + rendezvous_to_replica_group_size_map_[rendezvous_name] = 0; + } + + return rendezvous_to_replica_group_size_map_[rendezvous_name]; +} + +uint64_t DcnTracker::ComputeTransmittedDataSize( + const int64_t buffer_size, const int group_size, + const std::string& transfer_type) { + uint64_t transmitted_bytes = 0; + if (group_size == 0) { + LOG(ERROR) << "Replica group size is 0."; + return transmitted_bytes; + } + + if (transfer_type == 
"ONE_TO_ONE") { + transmitted_bytes = group_size * buffer_size; + } else if (transfer_type == "ALL_GATHER") { + transmitted_bytes = (group_size - 1) * buffer_size; + } else if (transfer_type == "ALL_REDUCE") { + // Since the reduced buffer now has to be sent back to the replicas, + // the total bytes transmitted over the network is 2x the shape of the op. + transmitted_bytes = + 2 * SafeDivide(group_size - 1, group_size) * buffer_size; + } else if (transfer_type == "ALL_TO_ALL" || + transfer_type == "REDUCE_SCATTER") { + transmitted_bytes = SafeDivide(group_size - 1, group_size) * buffer_size; + } else { + LOG(ERROR) << "Unsupported transfer type: " << transfer_type; + } + return transmitted_bytes; +} + void DcnTracker::VisitOp(const InstrMetadata& instr, const XEventVisitor& visitor) { std::string rendezvous_name; @@ -233,6 +298,8 @@ void DcnTracker::VisitOp(const InstrMetadata& instr, opState.send_op_name = visitor.DisplayName(); opState.send.set_duration_ps(visitor.DurationPs()); opState.send.set_start_time_ps(visitor.TimestampPs()); + opState.replica_group_size = + GetReplicaGroupSize(rendezvous_name, visitor); break; case HloOpcode::kRecv: opState.recv.set_duration_ps(visitor.DurationPs()); @@ -255,16 +322,8 @@ void DcnTracker::VisitOp(const InstrMetadata& instr, analysis->set_slack_us(NanoToMicro(visitor.TimestampNs() - opState.start_time - opState.overlapping_duration)); - // TODO(b/294584919): The current transmitted bytes measures the - // buffer size at the recv-done. This could include bytes that were not - // received over the network. Fix the calculation based on the number of - // replica groups. - // In case of ALL_REDUCE, Since the reduced buffer now - // has to be sent back to the replicas, the total bytes transmitted over - // the network is 2x the shape of the op. - analysis->set_bytes_transmitted_over_network( - analysis->transfer_type() == "ALL_REDUCE" ? 2 * instr.size - : instr.size); + analysis->set_bytes_transmitted_over_network(ComputeTransmittedDataSize( + instr.size, opState.replica_group_size, opState.transfer_type)); analysis->set_stall_duration_us(NanoToMicro(opState.stall_duration_ns)); analysis->set_recv_op_name(std::string(visitor.DisplayName())); analysis->set_send_op_name(opState.send_op_name); diff --git a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h index 3f38da65460346..daac70f634abca 100644 --- a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h +++ b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/core/profiler/protobuf/dcn_collective_info.pb.h" #include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" #include "tensorflow/core/profiler/protobuf/topology.pb.h" #include "tensorflow/core/profiler/utils/hlo_proto_map.h" @@ -53,6 +54,7 @@ struct DcnOpState { std::string transfer_type; uint64_t stall_duration_ns = 0; std::string send_op_name; + int replica_group_size = 0; OpInstance send; OpInstance send_done; @@ -125,6 +127,7 @@ class DcnTracker { absl::flat_hash_map global_chip_id_to_local_index_map_; absl::flat_hash_map> hlo_module_cache_; + absl::flat_hash_map rendezvous_to_replica_group_size_map_; bool is_megacore_ = true; absl::StatusOr GetInstrMetadataFromHloModule( @@ -140,6 +143,14 @@ class DcnTracker { // GetLocalIndex when available, else return the global_device_id itself. int GetLocalIndex(int dcn_device_id); + + // Get number of replica group + int GetReplicaGroupSize(const std::string& rendezvous_name, + const tsl::profiler::XEventVisitor& visitor); + + // Compute data transmitted size based on number of replica groups + uint64_t ComputeTransmittedDataSize(int64_t buffer_size, int group_size, + const std::string& transfer_type); }; } // namespace dcn_analysis_internal diff --git a/tensorflow/core/profiler/protobuf/BUILD b/tensorflow/core/profiler/protobuf/BUILD index 0e0ffa72b6ce71..c521c9ad89d392 100644 --- a/tensorflow/core/profiler/protobuf/BUILD +++ b/tensorflow/core/profiler/protobuf/BUILD @@ -270,3 +270,10 @@ tf_proto_library( cc_api_version = 2, visibility = [":friends"], ) + +tf_proto_library( + name = "dcn_collective_info_proto", + srcs = ["dcn_collective_info.proto"], + cc_api_version = 2, + visibility = [":friends"], +) diff --git a/tensorflow/core/profiler/protobuf/dcn_collective_info.proto b/tensorflow/core/profiler/protobuf/dcn_collective_info.proto new file mode 100644 index 00000000000000..5359a3dd54c1c6 --- /dev/null +++ b/tensorflow/core/profiler/protobuf/dcn_collective_info.proto @@ -0,0 +1,55 @@ +syntax = "proto3"; + +package tensorflow.profiler; + +// This proto is based on MegaScaleInfoProto and should be consistent with it. +message DcnCollectiveInfoProto { + enum TransferType { + UNKNOWN_TRANSFER_TYPE = 0; + + // XLA AllToAll transfer. + // Needs `endpoint_groups`. + ALL_TO_ALL = 1; + + // Peer-To-Peer DCN transfer from source to one destination. + // Needs one_to_one_groups. + ONE_TO_ONE = 2; + + // XLA reduce-scatter transfer. + // Needs `endpoint_groups`. + REDUCE_SCATTER = 3; + + // XLA AllGather transfer. + // Needs `endpoint_groups`. + ALL_GATHER = 4; + + // XLA all-reduce transfer. + // Needs `endpoint_groups`. + ALL_REDUCE = 5; + } + + message Endpoint { + int32 slice_id = 1; + int32 device_id = 2; + } + + message EndpointGroup { + repeated Endpoint endpoints = 1; + } + + message OneToOneGroup { + Endpoint source = 1; + Endpoint destination = 2; + } + + // The type of DCN transfer. + TransferType transfer_type = 1; + + // Groups of endpoints (in the form of slice id and device id) involved in + // `ALL_TO_ALL`, `REDUCE_SCATTER`, `ALL_REDUCE` and `ALL_GATHER` transfer. + repeated EndpointGroup endpoint_groups = 2; + + // Groups of endpoints (in the form of slice id and device id) involved in + // `ONE_TO_ONE` transfer. 
+ repeated OneToOneGroup one_to_one_groups = 3; +} diff --git a/tensorflow/core/profiler/protobuf/op_profile.proto b/tensorflow/core/profiler/protobuf/op_profile.proto index 9c29d1777eb38a..14ce2d203fb16a 100644 --- a/tensorflow/core/profiler/protobuf/op_profile.proto +++ b/tensorflow/core/profiler/protobuf/op_profile.proto @@ -82,6 +82,10 @@ message Metrics { // Total bytes accessed for each memory type. // Index into array using MemBwType enum. repeated double raw_bytes_accessed_array = 15; + // Number of executions. + uint32 occurrences = 16; + // Average "accumulated" time in picoseconds that the operation took. + double avg_time_ps = 17; reserved 1, 3, 4, 13, 14; } diff --git a/tensorflow/core/profiler/utils/tfstreamz_utils.cc b/tensorflow/core/profiler/utils/tfstreamz_utils.cc index af957c54843ec7..0b32f5712edba5 100644 --- a/tensorflow/core/profiler/utils/tfstreamz_utils.cc +++ b/tensorflow/core/profiler/utils/tfstreamz_utils.cc @@ -112,6 +112,9 @@ Status SerializeToXPlane(const std::vector& snapshots, xevent.AddStatValue(*metadata, *xplane.GetOrCreateStatMetadata( point->string_value)); break; + case monitoring::ValueType::kDouble: + xevent.AddStatValue(*metadata, point->double_value); + break; case monitoring::ValueType::kHistogram: xevent.AddStatValue(*metadata, point->histogram_value); break; diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index de23bad54d98de..f436e57dbf7995 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -673,6 +673,20 @@ message ConfigProto { // Whether runtime execution uses TFRT. bool use_tfrt = 18; + // If true, use Pathways with TFRT API for multi-host support. + bool enable_multi_host = 27; + + // Port for the Pathways server. Ignored if enable_multi_host=false. + int32 backend_server_port = 28; + + // If true, TFRT will use TPU-specific compiler passes and perform TPU + // specific initialization. + bool target_tpu = 29; + + // If true, TFRT will use GPU-specific compiler passes and perform GPU + // specific initialization. + bool target_gpu = 30; + // The field "coordination_service was previously specified as a string; // this has been replaced with a message below. reserved 19; @@ -711,7 +725,7 @@ message ConfigProto { reserved 25; - // Next: 27 + // Next: 31 } Experimental experimental = 16; diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index a36d3b961d93a9..63f8f6476f1696 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1672 // Updated: 2023/11/6 +#define TF_GRAPH_DEF_VERSION 1710 // Updated: 2023/12/14 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
// diff --git a/tensorflow/core/runtime_fallback/BUILD b/tensorflow/core/runtime_fallback/BUILD index 28bd4565cebecb..da33d97329ad0b 100644 --- a/tensorflow/core/runtime_fallback/BUILD +++ b/tensorflow/core/runtime_fallback/BUILD @@ -52,10 +52,7 @@ tf_cc_binary( "//conditions:default": [ "//tensorflow/core:all_kernels", ], - }) + if_cuda([ - "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_gpu_alwayslink", - "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_gpu_alwayslink", - ]), + }) + if_cuda([]), ) cc_library( diff --git a/tensorflow/core/runtime_fallback/test/BUILD b/tensorflow/core/runtime_fallback/test/BUILD index d8fd4850f1357f..e210a8296f3b85 100644 --- a/tensorflow/core/runtime_fallback/test/BUILD +++ b/tensorflow/core/runtime_fallback/test/BUILD @@ -1,4 +1,3 @@ -load("@tf_runtime//tools:mlir_to_bef.bzl", "mlir_to_bef") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_test", "tf_cc_test") # copybara:uncomment load("//third_party/tf_runtime_google/cpp_tests:gen_tests.bzl", "tfrt_cc_test_and_strict_benchmark") @@ -17,21 +16,6 @@ package_group( ], ) -mlir_to_bef( - name = "testdata/batch_function_fallback.mlir", - tfrt_translate = "//tensorflow/compiler/mlir/tfrt:tfrt_fallback_translate", -) - -mlir_to_bef( - name = "testdata/create_op.mlir", - tfrt_translate = "//tensorflow/compiler/mlir/tfrt:tfrt_fallback_translate", -) - -mlir_to_bef( - name = "testdata/custom_thread_pool.mlir", - tfrt_translate = "//tensorflow/compiler/mlir/tfrt:tfrt_fallback_translate", -) - cc_library( name = "forwarding_test_kernels", srcs = ["forwarding_test_kernels.cc"], @@ -142,43 +126,6 @@ cc_library( # ], # ) # -# # C++ benchmarks for batch function runtime fallback. -# tfrt_cc_test_and_strict_benchmark( -# name = "batch_function_fallback_benchmark", -# srcs = ["batch_function_fallback_benchmark_test.cc"], -# data = ["testdata/batch_function_fallback.mlir.bef"], -# enable_xprof = True, -# includes = ["third_party/tf_runtime/include"], -# owners = ["tf-runtime-testing"], -# tags = [ -# "need_main", -# "no_gpu", -# ], -# deps = [ -# "//base", -# "//devtools/build/runtime:get_runfiles_dir", -# "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", -# "//tensorflow/core/platform:env", -# "//tensorflow/core/platform:resource_loader", -# "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_op_handler", -# "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_tensor", -# "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink", -# "//tensorflow/core/runtime_fallback/util:fallback_test_util", -# "//tensorflow/core/runtime_fallback/util:tensor_util", -# "//tensorflow/core/tfrt/utils:fallback_tensor", -# "@eigen_archive//:eigen3", -# "@tf_runtime//:bef", -# "@tf_runtime//:befexecutor", -# "@tf_runtime//:core_runtime_alwayslink", -# "@tf_runtime//:hostcontext_alwayslink", -# "@tf_runtime//:mlirtobef", -# "@tf_runtime//:support", -# "@tf_runtime//:tensor", -# "@tf_runtime//backends/cpu:core_runtime_alwayslink", -# "@tf_runtime//backends/cpu:test_ops_alwayslink", -# ], -# ) -# # # C++ tests and benchmarks for runtime fallback. 
# tfrt_cc_test_and_strict_benchmark( # name = "c_api_tfrt", @@ -214,10 +161,10 @@ cc_library( # srcs = ["runtime_fallback_kernels_test.cc"], # deps = [ # ":coreruntime_driver", -# "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink", # "@com_google_googletest//:gtest", # "@com_google_googletest//:gtest_main", # "@llvm-project//llvm:Support", +# "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink", # "@tf_runtime//:core_runtime", # "@tf_runtime//backends/cpu:core_runtime_alwayslink", # ] + select({ @@ -241,11 +188,11 @@ cc_library( # includes = ["third_party/tf_runtime/include"], # deps = [ # ":coreruntime_driver", +# "@com_google_googletest//:gtest", # "//tensorflow/core/platform:test_benchmark", # "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_op_handler", # "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_tensor", # "@local_tsl//tsl/platform/default/build_config:test_main", -# "@com_google_googletest//:gtest", # "@tf_runtime//:core_runtime_alwayslink", # "@tf_runtime//:hostcontext", # "@tf_runtime//:tensor", @@ -266,11 +213,11 @@ cc_library( # srcs = ["kernel_fallback_compat_request_state_test.cc"], # includes = ["third_party/tf_runtime/include"], # deps = [ +# "@com_google_googletest//:gtest", # "//tensorflow/core/framework:tensor_testutil", # "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state", # "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_op_handler", # "@local_tsl//tsl/platform/default/build_config:test_main", -# "@com_google_googletest//:gtest", # "@tf_runtime//:core_runtime_alwayslink", # ], # ) @@ -323,32 +270,3 @@ cc_library( }), alwayslink = 1, ) - -tf_cc_shared_test( - name = "kernel_fallback_compat_test", - srcs = ["kernel_fallback_compat_test.cc"], - data = [ - "testdata/create_op.mlir.bef", - "testdata/custom_thread_pool.mlir.bef", - ], - tags = ["no_oss"], - deps = [ - "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", - "//tensorflow/core:all_kernels", - "//tensorflow/core:lib", - "//tensorflow/core/platform:resource_loader", - "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state", - "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink", - "//tensorflow/core/runtime_fallback/util:fallback_test_util", - "//tensorflow/core/tfrt/fallback:op_kernel_runner", - "//tensorflow/core/tfrt/runtime", - "//tensorflow/core/tfrt/utils:thread_pool", - "@com_google_googletest//:gtest_main", - "@tf_runtime//:bef", - "@tf_runtime//:befexecutor", - "@tf_runtime//:core_runtime", - "@tf_runtime//:hostcontext", - "@tf_runtime//:init_tfrt_dialects", - "@tf_runtime//:tracing", - ], -) diff --git a/tensorflow/core/runtime_fallback/util/BUILD b/tensorflow/core/runtime_fallback/util/BUILD index 60a8b009c2d9a2..92db3499ec3977 100644 --- a/tensorflow/core/runtime_fallback/util/BUILD +++ b/tensorflow/core/runtime_fallback/util/BUILD @@ -10,6 +10,7 @@ package_group( name = "internal", packages = [ "//learning/brain/experimental/tfrt/native_lowering/kernels/...", + "//tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/...", "//tensorflow/core/runtime_fallback/...", "//tensorflow/core/tfrt/utils/...", ], diff --git a/tensorflow/core/tfrt/BUILD b/tensorflow/core/tfrt/BUILD index ecf03260bba471..790641ff63b747 100644 --- a/tensorflow/core/tfrt/BUILD +++ b/tensorflow/core/tfrt/BUILD @@ -9,3 +9,13 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "ifrt_program_ops_op_lib", + visibility = ["//visibility:public"], + deps = [ + 
"//tensorflow/core/tfrt/kernels:ifrt_program_ops", + "//tensorflow/core/tfrt/ops:ifrt_program_ops_op_lib", + ], + alwayslink = 1, +) diff --git a/tensorflow/core/tfrt/common/BUILD b/tensorflow/core/tfrt/common/BUILD index f6928dc560b5f6..469bbff25d6f04 100644 --- a/tensorflow/core/tfrt/common/BUILD +++ b/tensorflow/core/tfrt/common/BUILD @@ -20,6 +20,7 @@ package_group( # copybara:uncomment "//learning/brain/google/xla/...", # copybara:uncomment "//learning/brain/tfrc/...", # copybara:uncomment "//learning/brain/tfrt/...", + # copybara:uncomment "//learning/serving/model_servers/...", # copybara:uncomment "//platforms/xla/megascale/tensorflow/...", "//tensorflow/c/...", "//tensorflow/compiler/jit/...", diff --git a/tensorflow/core/tfrt/common/pjrt_gpu_client_registration.cc b/tensorflow/core/tfrt/common/pjrt_gpu_client_registration.cc index 5369961b4a87da..7e383c57b6311b 100644 --- a/tensorflow/core/tfrt/common/pjrt_gpu_client_registration.cc +++ b/tensorflow/core/tfrt/common/pjrt_gpu_client_registration.cc @@ -27,12 +27,13 @@ namespace xla { StatusOr> GetGpuClient( const PjrtClientFactoryOptions& option) { + xla::GpuClientOptions gpu_client_options; + gpu_client_options.node_id = option.gpu_options.node_id; + gpu_client_options.num_nodes = 1; + gpu_client_options.allowed_devices = option.gpu_options.allowed_devices; + gpu_client_options.platform_name = option.gpu_options.platform_name; TF_ASSIGN_OR_RETURN(std::unique_ptr client, - xla::GetStreamExecutorGpuClient( - option.gpu_options.asynchronous, - /*allocator_config=*/{}, option.gpu_options.node_id, - /*num_nodes=*/1, option.gpu_options.allowed_devices, - option.gpu_options.platform_name)); + xla::GetStreamExecutorGpuClient(gpu_client_options)); return std::move(client); } diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index 38653fe09f1c0a..d99f3519d05206 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -205,6 +205,10 @@ tf_proto_library( name = "test_config_proto", testonly = True, srcs = ["test_config.proto"], + visibility = if_google( + [":friends"], + ["//visibility:public"], + ), ) tf_cc_test( diff --git a/tensorflow/core/tfrt/ifrt/BUILD b/tensorflow/core/tfrt/ifrt/BUILD index 092532a9a9df56..a48a74128e6212 100644 --- a/tensorflow/core/tfrt/ifrt/BUILD +++ b/tensorflow/core/tfrt/ifrt/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ @@ -17,16 +19,19 @@ cc_library( srcs = ["ifrt_serving_executable.cc"], hdrs = ["ifrt_serving_executable.h"], deps = [ - "//tensorflow/compiler/mlir/tfrt:tf2hlo", + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:tf2hlo", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", "@local_tsl//tsl/concurrency:ref_count", @@ -66,3 +71,94 @@ cc_library( "@local_xla//xla/python/ifrt", ], ) + +cc_library( + name = "sharding_utils", + srcs = [ + "sharding_utils.cc", + ], + hdrs = [ + "sharding_utils.h", + ], + 
deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:statusor", + "//tensorflow/core/tpu/kernels:sharding_utils", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla:executable_run_options", + "@local_xla//xla/python/ifrt", + "@local_xla//xla/python/pjrt_ifrt", + ], +) + +tf_cc_test( + name = "sharding_utils_test", + srcs = ["sharding_utils_test.cc"], + tags = ["no_oss"], + deps = [ + ":sharding_utils", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core/framework:tensor_matcher", + "//tensorflow/core/framework:tensor_testutil", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@eigen_archive//:eigen3", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/python/ifrt", + "@local_xla//xla/python/ifrt:test_util", + "@local_xla//xla/python/ifrt/ir", + "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", + "@local_xla//xla/python/pjrt_ifrt:xla_ifrt", + ], +) + +tf_cc_test( + name = "ifrt_serving_executable_test", + srcs = [ + "ifrt_serving_executable_test.cc", + ], + data = [ + "//tensorflow/core/tfrt/ifrt/testdata", + ], + tags = ["no_oss"], + deps = [ + ":ifrt_serving_executable", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core/framework:tensor", + "//tensorflow/core/framework:types_proto_cc", + "//tensorflow/core/platform:resource_loader", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/python/ifrt", + "@local_xla//xla/python/ifrt:test_util", + "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", + ], +) diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc index a03958bcbd644d..649e813a511ddd 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc @@ -1,4 +1,3 @@ - /* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +24,9 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -33,6 +34,7 @@ limitations under the License. 
#include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/future.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" @@ -84,24 +86,70 @@ IfrtServingExecutable::ConvertTensorToArray(const tensorflow::Tensor& tensor) { return single_array; } -absl::StatusOr> IfrtServingExecutable::Execute( +xla::ifrt::Future>> +IfrtServingExecutable::LookUpOrCreateExecutable( absl::Span inputs) { - // TODO(b/304839793): Build cache based on tensorshape etc - if (!ifrt_executable_) { - LOG(INFO) << "Cache missed. Building executable"; - - TF_ASSIGN_OR_RETURN(auto mlir_hlo_module, - CompileTfToHlo(*module_, inputs, signature_name(), - ifrt_client_->GetDefaultCompiler(), - shape_representation_fn_)); - - TF_ASSIGN_OR_RETURN( - ifrt_executable_, - ifrt_client_->GetDefaultCompiler()->Compile( - std::make_unique(mlir_hlo_module.get()), - std::make_unique())); + std::vector input_shapes; + for (const auto& tensor : inputs) { + input_shapes.push_back(tensor.shape()); + } + Key key(input_shapes); + + xla::ifrt::Promise< + absl::StatusOr>> + promise; + xla::ifrt::Future< + absl::StatusOr>> + future; + + { + absl::MutexLock lock(&mutex_); + + const auto it = ifrt_executables_.find(key); + if (it != ifrt_executables_.end()) { + return it->second; + } + + // Only create promise and future when cache missed. + promise = xla::ifrt::Future>>::CreatePromise(); + future = xla::ifrt::Future< + absl::StatusOr>>(promise); + + ifrt_executables_.emplace(key, future); + } + + LOG(INFO) << "Cache missed. Building executable"; + + absl::StatusOr> mlir_hlo_module = + CompileTfToHlo(*module_, inputs, signature_name(), + ifrt_client_->GetDefaultCompiler(), + shape_representation_fn_); + if (!mlir_hlo_module.ok()) { + promise.Set(mlir_hlo_module.status()); + return future; + } + + absl::StatusOr> ifrt_executable = + ifrt_client_->GetDefaultCompiler()->Compile( + std::make_unique(mlir_hlo_module->get()), + std::make_unique()); + if (!ifrt_executable.ok()) { + promise.Set(ifrt_executable.status()); + return future; } + promise.Set(std::shared_ptr( + std::move(*ifrt_executable))); + return future; +} + +absl::StatusOr> IfrtServingExecutable::Execute( + absl::Span inputs) { + TF_ASSIGN_OR_RETURN( + std::shared_ptr ifrt_executable, + LookUpOrCreateExecutable(inputs).Await()); + std::vector> args; args.reserve(inputs.size()); for (auto& tensor : inputs) { @@ -110,7 +158,7 @@ absl::StatusOr> IfrtServingExecutable::Execute( } TF_ASSIGN_OR_RETURN(auto execution_result, - ifrt_executable_->Execute( + ifrt_executable->Execute( absl::MakeSpan(args), /*options=*/{.untuple_result = true}, std::nullopt)); diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h index 64ce6580f8cfab..9b1d86cbcbbfd7 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h @@ -21,9 +21,12 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -32,6 +35,7 @@ limitations under the License. 
#include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tsl/concurrency/ref_count.h" @@ -65,7 +69,30 @@ class IfrtServingExecutable { absl::StatusOr> Execute( absl::Span inputs); + int num_executables() const { + absl::MutexLock lock(&mutex_); + return ifrt_executables_.size(); + } + private: + // In memory cache key. + struct Key { + std::vector input_shapes; + template + friend H AbslHashValue(H h, const Key& key) { + for (const auto& shape : key.input_shapes) { + for (auto size : shape.dim_sizes()) { + h = H::combine(std::move(h), size); + } + } + return h; + } + + friend bool operator==(const Key& x, const Key& y) { + return x.input_shapes == y.input_shapes; + } + }; + std::string model_name_; std::string signature_name_; @@ -76,10 +103,17 @@ class IfrtServingExecutable { tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_; - std::unique_ptr ifrt_executable_; + mutable absl::Mutex mutex_; + absl::flat_hash_map>>> + ifrt_executables_ ABSL_GUARDED_BY(mutex_); absl::StatusOr> ConvertTensorToArray( const tensorflow::Tensor& tensor); + + xla::ifrt::Future< + absl::StatusOr>> + LookUpOrCreateExecutable(absl::Span inputs); }; } // namespace ifrt_serving diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc new file mode 100644 index 00000000000000..a2de6e9a68e16e --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc @@ -0,0 +1,162 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h" + +#include +#include +#include +#include +#include + +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/test_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/resource_loader.h" +#include "tensorflow/core/platform/test.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { + +TEST(IfrtServingExecutableTest, Basic) { + // Create test input module + constexpr absl::string_view kDataDirectory = + "tensorflow/core/tfrt/ifrt/testdata"; + std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( + absl::StrCat(kDataDirectory, "/executable.mlir")); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + + mlir::MLIRContext context(registry); + + mlir::OwningOpRef mlir_module = + mlir::parseSourceFile(mlir_module_path, &context); + + ASSERT_TRUE(mlir_module); + + // Create contexts required for the compiler execution. + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + + IfrtServingExecutable executable("test", "main", std::move(mlir_module), + client, + tensorflow::IdentityShapeRepresentationFn()); + + tensorflow::Tensor x(tensorflow::DT_INT32, tensorflow::TensorShape({1, 3})); + tensorflow::Tensor y(tensorflow::DT_INT32, tensorflow::TensorShape({3, 1})); + for (int i = 0; i < 3; ++i) { + x.flat()(i) = i + 1; + y.flat()(i) = i + 1; + } + + std::vector inputs{x, y}; + TF_ASSERT_OK_AND_ASSIGN(auto result, + executable.Execute(absl::MakeSpan(inputs))); + + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0].dtype(), tensorflow::DT_INT32); + ASSERT_EQ(result[0].shape(), tensorflow::TensorShape({1, 1})); + ASSERT_EQ(result[0].flat()(0), 14); +} + +TEST(IfrtServingExecutableTest, MultipleShapes) { + // Create test input module + constexpr absl::string_view kDataDirectory = + "tensorflow/core/tfrt/ifrt/testdata"; + std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( + absl::StrCat(kDataDirectory, "/executable.mlir")); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + + mlir::MLIRContext context(registry); + + mlir::OwningOpRef mlir_module = + mlir::parseSourceFile(mlir_module_path, &context); + + ASSERT_TRUE(mlir_module); + + // Create contexts required for the compiler execution. 
+ TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + + IfrtServingExecutable executable("test", "main", std::move(mlir_module), + client, + tensorflow::IdentityShapeRepresentationFn()); + + constexpr int kDim1 = 3; + tensorflow::Tensor x1(tensorflow::DT_INT32, + tensorflow::TensorShape({1, kDim1})); + tensorflow::Tensor y1(tensorflow::DT_INT32, + tensorflow::TensorShape({kDim1, 1})); + for (int i = 0; i < kDim1; ++i) { + x1.flat()(i) = i + 1; + y1.flat()(i) = i + 1; + } + std::vector inputs1{x1, y1}; + + constexpr int kDim2 = 4; + tensorflow::Tensor x2(tensorflow::DT_INT32, + tensorflow::TensorShape({1, kDim2})); + tensorflow::Tensor y2(tensorflow::DT_INT32, + tensorflow::TensorShape({kDim2, 1})); + for (int i = 0; i < kDim2; ++i) { + x2.flat()(i) = i + 1; + y2.flat()(i) = i + 1; + } + std::vector inputs2{x2, y2}; + + std::vector outputs1, outputs2; + for (int i = 0; i < 3; i++) { + TF_ASSERT_OK_AND_ASSIGN(outputs1, + executable.Execute(absl::MakeSpan(inputs1))); + TF_ASSERT_OK_AND_ASSIGN(outputs2, + executable.Execute(absl::MakeSpan(inputs2))); + } + ASSERT_EQ(outputs1.size(), 1); + ASSERT_EQ(outputs1[0].dtype(), tensorflow::DT_INT32); + ASSERT_EQ(outputs1[0].shape(), tensorflow::TensorShape({1, 1})); + ASSERT_EQ(outputs1[0].flat()(0), 14); + + ASSERT_EQ(outputs2.size(), 1); + ASSERT_EQ(outputs2[0].dtype(), tensorflow::DT_INT32); + ASSERT_EQ(outputs2[0].shape(), tensorflow::TensorShape({1, 1})); + ASSERT_EQ(outputs2[0].flat()(0), 30); + + ASSERT_EQ(executable.num_executables(), 2); +} + +} // namespace +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils.cc b/tensorflow/core/tfrt/ifrt/sharding_utils.cc new file mode 100644 index 00000000000000..c61d528897bb1b --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/sharding_utils.cc @@ -0,0 +1,393 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/tfrt/ifrt/sharding_utils.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/btree_map.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/index_domain.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/tpu/kernels/sharding_utils.h" +#include "tsl/concurrency/ref_count.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { +absl::StatusOr ToIfrtDType( + tensorflow::DataType tensor_dtype) { + xla::PrimitiveType primitive_type; + TF_RETURN_IF_ERROR( + tensorflow::DataTypeToPrimitiveType(tensor_dtype, &primitive_type)); + return xla::ifrt::ToDType(primitive_type); +} + +// Shard the given `input_tensor` into slices of equal shape. +// +// `num_partitions_per_axis` specifies the number of partitions along +// each axis (dimension). +// +// `num_replicas` specifies the number of replications for each partitioned +// sliced buffer. +// +// `devices` contains a list of devices flattened into the following +// order: [slice0][replicate0], [slice0][replicate1], ..., [slice1][replicate0], +// [slice1][replicate1], ... +absl::StatusOr>> +SplitAndCreateArraysFromHostBuffer( + xla::ifrt::Client& ifrt_client, const tensorflow::Tensor& input_tensor, + const std::vector& num_partitions_per_axis, int num_replicas, + const std::vector& devices, + const Eigen::ThreadPoolDevice& thread_pool_device) { + int64_t num_slices = 1; + for (auto k : num_partitions_per_axis) { + num_slices *= k; + } + + tensorflow::DataType tensor_data_type = input_tensor.dtype(); + std::vector paddings(num_partitions_per_axis.size(), 0); + std::vector split_tensors; + split_tensors.resize(num_slices); + + auto allocate_output_fn = + [&](int i, const tensorflow::TensorShape& output_slice_shape, + tensorflow::Tensor** tensor) { + if (i < 0 || i >= split_tensors.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Index ", i, " out of range [0, ", split_tensors.size(), "]")); + } + split_tensors[i] = + tensorflow::Tensor(tensor_data_type, output_slice_shape); + *tensor = &split_tensors[i]; + return absl::OkStatus(); + }; + + // Fast path for output in the simple no-split case. + auto assign_or_copy_value_fn = + [&](const tensorflow::Tensor& input) -> Status { + split_tensors[0] = input; + return absl::OkStatus(); + }; + + // XlaNDSplitter only supports rank (0, 8] as there is no concept of split for + // rank 0 tensor.
+ if (input_tensor.shape().dims() == 0) { + if (split_tensors.size() != 1) { + return absl::InvalidArgumentError(absl::StrCat( + "Rank 0 tensor only expects 1 slice but got ", split_tensors.size())); + } + split_tensors[0] = input_tensor; + } else { + switch (input_tensor.dtype()) { +#define CASE(type) \ + case tensorflow::DataTypeToEnum::value: { \ + TF_ASSIGN_OR_RETURN(auto splitter, \ + (XlaNDSplitter::Create( \ + num_partitions_per_axis, num_slices, paddings, \ + /*has_paddings=*/false))); \ + TF_RETURN_IF_ERROR( \ + splitter.Split(&input_tensor, "input tensor", assign_or_copy_value_fn, \ + allocate_output_fn, thread_pool_device)); \ + } break; + TF_CALL_ALL_TYPES(CASE); + TF_CALL_quint8(CASE); +#undef CASE + default: + return absl::InvalidArgumentError("Unsupported data type"); + } + } + + if (split_tensors.size() * num_replicas != devices.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Expect ", devices.size(), " but got ", + split_tensors.size(), " x ", num_replicas)); + } + + std::vector> arrays; + arrays.reserve(devices.size()); + TF_ASSIGN_OR_RETURN(xla::ifrt::DType dtype, ToIfrtDType(tensor_data_type)); + auto device_iter = devices.begin(); + for (int slice_idx = 0; slice_idx < split_tensors.size(); ++slice_idx) { + auto& tensor = split_tensors[slice_idx]; + + for (int i = 0; i < num_replicas; ++i) { + VLOG(2) << "Make array for buffer slice " << slice_idx << " at " + << tensor.data(); + if (device_iter == devices.end()) { + return absl::InternalError( + absl::StrCat("Missing Device ", i, " for slice ", slice_idx)); + } + auto single_device_sharding = xla::ifrt::SingleDeviceSharding::Create( + *device_iter, xla::ifrt::MemoryKind()); + + TF_ASSIGN_OR_RETURN( + auto array, + ifrt_client.MakeArrayFromHostBuffer( + tensor.data(), dtype, + xla::ifrt::Shape(tensor.shape().dim_sizes()), + /*byte_strides=*/{}, std::move(single_device_sharding), + xla::ifrt::Client::HostBufferSemantics:: + kImmutableUntilTransferCompletes, + [tensor, slice_idx]() { + // Keep tensor alive + LOG(INFO) << "Done with host buffer for slice " << slice_idx + << " at " << tensor.data(); + })); + arrays.push_back(std::move(array)); + device_iter++; + } + } + return arrays; +} + +absl::StatusOr VerifyIndexDomainsAndGetReplicas( + absl::Span index_domains, + const tensorflow::TensorShape& tensor_shape) { + if (index_domains.size() <= 1) { + return absl::InvalidArgumentError(absl::StrCat( + "Expect multiple index domains but got ", index_domains.size())); + } + + for (auto index_domain = index_domains.begin(); + index_domain < index_domains.end(); ++index_domain) { + if (index_domain->shape().dims().size() != tensor_shape.dims()) { + return absl::InvalidArgumentError( + absl::StrCat("Expect equal rank of ", tensor_shape.dims(), + " but got ", index_domain->shape().dims().size())); + } + } + + // Only support equal shape for all index domains + auto first_index_domain = index_domains.begin(); + for (auto index_domain = index_domains.begin() + 1; + index_domain < index_domains.end(); ++index_domain) { + if (first_index_domain->shape() != index_domain->shape()) { + return absl::UnimplementedError(absl::StrCat( + "Expect equal shape of ", first_index_domain->shape().DebugString(), + " but got ", index_domain->shape().DebugString())); + } + } + + // Verify that each `IndexDomain` appear the same `num_replica` times. Since + // shapes are the same for all `IndexDomain`, this also implies each `origin` + // appear `num_replica` times. 
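// Editorial illustration (not part of the original patch): for a [2, 2] tensor
// sharded across 4 devices with 2 partitions along dim 1 and 2-way
// replication, IndexDomains() yields
//   {origin = [0, 0], shape = [2, 1]} twice and
//   {origin = [0, 1], shape = [2, 1]} twice,
// so each distinct IndexDomain is counted twice and num_replicas resolves to 2.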
+ auto index_domain_lexicographical_comparator = + [](const xla::ifrt::IndexDomain& a, const xla::ifrt::IndexDomain& b) { + return std::lexicographical_compare( + a.origin().elements().begin(), a.origin().elements().end(), + b.origin().elements().begin(), b.origin().elements().end()); + }; + absl::btree_map + index_domain_counts; + for (const auto& index_domain : index_domains) { + index_domain_counts[index_domain]++; + } + + std::vector unique_index_domains; + unique_index_domains.reserve(index_domain_counts.size()); + int num_replicas = index_domain_counts.begin()->second; + for (const auto& [index_domain, count] : index_domain_counts) { + if (count != num_replicas) { + return absl::FailedPreconditionError(absl::StrCat( + "Expected ", num_replicas, " replicas for ", + index_domain.DebugString(), " but got ", count, " replicas")); + } + unique_index_domains.push_back(index_domain); + } + + // Verify that distances of between origins of neighbouring `IndexDomain` + // bounded by shape. Note that unique_indexx_domains are already in sorted + // order. + auto prev_iter = unique_index_domains.begin(); + auto next_iter = unique_index_domains.begin() + 1; + const auto& bounded_box = first_index_domain->shape(); + while (prev_iter != unique_index_domains.end() && + next_iter != unique_index_domains.end()) { + xla::ifrt::Index offset = next_iter->origin() - prev_iter->origin(); + for (int dim = 0; dim < bounded_box.dims().size(); ++dim) { + if (std::abs(offset.elements()[dim]) != bounded_box.dims()[dim] && + offset.elements()[dim] != 0) { + return absl::FailedPreconditionError(absl::StrCat( + "IndexDomains should not have gap or overlap, but got ", + prev_iter->DebugString(), " and ", next_iter->DebugString(), + " that have offset of ", offset.DebugString())); + } + } + prev_iter = next_iter; + next_iter++; + } + + // Verify the last `IndexDomain`'s upper end of the bound matches with the + // tensor shape. Together with the above check, this provides an approximation + // to the following two assumptions: + // 1. the union of all IndexDomain covers the entire global shape array with + // no gaps. + // 2. no two index_domain have any overlap. 
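// Editorial illustration (not part of the original patch): if a [4, 4] tensor
// is split into [2, 2] slices, the lexicographically last IndexDomain has
// origin [2, 2] and shape [2, 2], so the bound computed below is
// [2 + 2, 2 + 2] = [4, 4], which must match the global tensor shape exactly.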
+ std::vector bounded_shape; + const auto& last_index_domain = unique_index_domains.back(); + bounded_shape.reserve(last_index_domain.shape().dims().size()); + for (int d = 0; d < last_index_domain.shape().dims().size(); ++d) { + bounded_shape.push_back(last_index_domain.origin().elements()[d] + + last_index_domain.shape().dims()[d]); + } + + if (xla::ifrt::Shape(bounded_shape) != + xla::ifrt::Shape(tensor_shape.dim_sizes())) { + return absl::FailedPreconditionError(absl::StrCat( + "IndexDomain ", last_index_domain.DebugString(), + " does not overlap with tensor shape ", tensor_shape.DebugString())); + } + + return num_replicas; +} + +} // namespace + +StatusOr> MakeAssembledArrayFromHostBuffer( + xla::ifrt::Client& ifrt_client, const tensorflow::Tensor& input_tensor, + std::shared_ptr sharding, + const Eigen::ThreadPoolDevice& thread_pool_device) { + VLOG(2) << "Assembling arrays by sharding " << sharding->DebugString(); + + TF_ASSIGN_OR_RETURN(auto index_domains, + sharding->IndexDomains( + xla::ifrt::Shape(input_tensor.shape().dim_sizes()))); + + TF_ASSIGN_OR_RETURN(int index_domain_replicas, + VerifyIndexDomainsAndGetReplicas( + absl::MakeSpan(index_domains), input_tensor.shape())); + + const auto& first_index_domain = index_domains.begin(); + std::vector num_partitions_per_axis; + int total_num_partitions = 1; + num_partitions_per_axis.reserve(input_tensor.shape().dims()); + for (int dim = 0; dim < input_tensor.shape().dims(); ++dim) { + int target_size = first_index_domain->shape().dims()[dim]; + if (input_tensor.shape().dim_size(dim) % target_size != 0) { + return absl::FailedPreconditionError(absl::StrCat( + "Only support even sharding, but input tensor shape ", + input_tensor.shape().DebugString(), " not even splittable to ", + first_index_domain->shape().DebugString())); + } + int num_partitions = input_tensor.shape().dim_size(dim) / target_size; + total_num_partitions *= num_partitions; + num_partitions_per_axis.push_back(num_partitions); + } + + if (total_num_partitions > sharding->devices().size() || + sharding->devices().size() % total_num_partitions != 0) { + return absl::UnimplementedError(absl::StrCat( + "Number of devices ", sharding->devices().size(), + " not a multiple of number of partitions", total_num_partitions)); + } + + // Assume index domains are non-overlapping and each index domain appears + // exactly num_replicates times. This allows us to rely on + // lexicographical sorting to replicate slices in the correct order. + int num_replicas = sharding->devices().size() / total_num_partitions; + if (index_domain_replicas != num_replicas) { + return absl::FailedPreconditionError( + absl::StrCat("IndexDomain indicates ", index_domain_replicas, + " replicas, but got ", num_replicas, " replicas")); + } + + // Sorted the IndexDomain and devices from major to minor dimenson. For + // example, a two dimension IndexDomain will be ordered by [0, 0], [0, 1], [1, + // 0], [1, 1]. + // This is O(n*log(n)) vs looking for devices individually which is O(n^2). + struct IndexDomainDevice { + xla::ifrt::IndexDomain index_domain; + xla::ifrt::Device* device; + // The index of this `device`/`index_domain` in the + // sharding.devices/index_domains. 
+ int original_shard_index; + }; + std::vector index_domain_devices; + index_domain_devices.reserve(index_domains.size()); + for (int i = 0; i < index_domains.size(); ++i) { + index_domain_devices.push_back( + {index_domains[i], sharding->devices()[i], i}); + } + std::sort(index_domain_devices.begin(), index_domain_devices.end(), + [](const IndexDomainDevice& a, const IndexDomainDevice& b) { + return std::lexicographical_compare( + a.index_domain.origin().elements().begin(), + a.index_domain.origin().elements().end(), + b.index_domain.origin().elements().begin(), + b.index_domain.origin().elements().end()); + }); + // Now the devices is in order. + std::vector devices; + devices.reserve(index_domain_devices.size()); + std::vector original_device_indices; + original_device_indices.reserve(index_domain_devices.size()); + for (auto& [index_domain, device, original_device_index] : + index_domain_devices) { + devices.push_back(device); + original_device_indices.push_back(original_device_index); + VLOG(3) << "Device " << device->ToString(); + } + + TF_ASSIGN_OR_RETURN(auto arrays, + SplitAndCreateArraysFromHostBuffer( + ifrt_client, input_tensor, num_partitions_per_axis, + num_replicas, devices, thread_pool_device)); + + // Re-arranged arrays back to original device order + std::vector> rearranged_arrays; + rearranged_arrays.resize(arrays.size()); + for (int i = 0; i < arrays.size(); ++i) { + rearranged_arrays[original_device_indices[i]] = std::move(arrays[i]); + } + + return ifrt_client.AssembleArrayFromSingleDeviceArrays( + xla::ifrt::Shape(input_tensor.shape().dim_sizes()), sharding, + absl::MakeSpan(rearranged_arrays), + xla::ifrt::ArrayCopySemantics::kDonateInput); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils.h b/tensorflow/core/tfrt/ifrt/sharding_utils.h new file mode 100644 index 00000000000000..5e19d46443582c --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/sharding_utils.h @@ -0,0 +1,42 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_SHARDING_UTILS_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_SHARDING_UTILS_H_ + +#include + +#include "xla/executable_run_options.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/sharding.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/statusor.h" +#include "tsl/concurrency/ref_count.h" + +namespace tensorflow { +namespace ifrt_serving { + +// Sharded the given `data` by the `sharding` specification. +// It currently supports even sharding, replication and partial replication. 
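// Example usage (editorial sketch, not part of the original patch; assumes an
// `ifrt_client`, a populated `sharding`, and an Eigen `thread_pool_device`
// created elsewhere, e.g. from a tsl::thread::ThreadPool):
//
//   tensorflow::Tensor input(tensorflow::DT_INT32,
//                            tensorflow::TensorShape({4, 4}));
//   // ... fill `input` with host data ...
//   TF_ASSIGN_OR_RETURN(
//       tsl::RCReference<xla::ifrt::Array> sharded_array,
//       MakeAssembledArrayFromHostBuffer(ifrt_client, input, sharding,
//                                        thread_pool_device));
//   // `sharded_array` is one assembled IFRT array whose per-device shards
//   // follow `sharding`.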
+StatusOr> MakeAssembledArrayFromHostBuffer( + xla::ifrt::Client& ifrt_client, const tensorflow::Tensor& input_tensor, + std::shared_ptr sharding, + const Eigen::ThreadPoolDevice& thread_pool_device); +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_SHARDING_UTILS_H_ diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc new file mode 100644 index 00000000000000..57fd9ed7003bb7 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc @@ -0,0 +1,516 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/tfrt/ifrt/sharding_utils.h" + +#include +#include +#include +#include +#include + +#include +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "llvm/ADT/SmallVector.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/ir/sharding_param.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/test_util.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_matcher.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/platform/threadpool.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { + +using tsl::testing::StatusIs; + +struct HloShardingTestParam { + tensorflow::Tensor in_tensor; + std::vector expected_out_tensors; + std::vector device_indices; + xla::HloSharding sharding; +}; + +struct ShardingParamTestParam { + tensorflow::Tensor in_tensor; + std::vector expected_out_tensors; + std::vector device_indices; + + // Parameter to form ShardingParam + std::vector dim_shards; + llvm::SmallVector permutation; + llvm::SmallVector axis_sizes; +}; + +using ShardingParamTest = ::testing::TestWithParam; +using HloShardingTest = ::testing::TestWithParam; + +// Wrapper functions for building sharding specs for a given shape with a +// natural device order. 
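// Editorial note (not part of the original patch): with four devices in their
// natural order {0, 1, 2, 3},
//   Tile({2, 2})           assigns devices 0 and 1 to the first row of tiles
//                          and devices 2 and 3 to the second row,
//   PartialTile({2, 1, 2}) splits dim 0 in two and replicates each half on a
//                          pair of devices, and
//   Replicate()            gives every device a full copy of the tensor.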
+xla::HloSharding Tile(absl::Span dims) { + return xla::HloSharding::IotaTile(dims); +} +xla::HloSharding PartialTile(absl::Span dims) { + return xla::HloSharding::PartialTile(xla::TileAssignment(dims)); +} +xla::HloSharding Replicate() { return xla::HloSharding::Replicate(); } + +TEST_P(HloShardingTest, MakeAssembledArrayFromHostBuffer) { + constexpr int kMaxParallelism = 16; + auto thread_pool = std::make_unique( + tsl::Env::Default(), tsl::ThreadOptions(), "Resharding", kMaxParallelism); + + Eigen::ThreadPoolDevice device(thread_pool->AsEigenThreadPool(), + kMaxParallelism); + + auto input_tensor = GetParam().in_tensor; + + // Create contexts required for the compiler execution. + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + TF_ASSERT_OK_AND_ASSIGN(auto device_list, + xla::ifrt::test_util::GetDevices( + client.get(), GetParam().device_indices)); + + auto sharding = xla::ifrt::HloSharding::Create( + device_list, xla::ifrt::MemoryKind(), GetParam().sharding); + + TF_ASSERT_OK_AND_ASSIGN( + auto assembled_array, + MakeAssembledArrayFromHostBuffer(*client, input_tensor, + std::move(sharding), device)); + + TF_ASSERT_OK_AND_ASSIGN(auto disassembled_arrays, + assembled_array->DisassembleIntoSingleDeviceArrays( + xla::ifrt::ArrayCopySemantics::kAlwaysCopy)); + + ASSERT_EQ(disassembled_arrays.size(), GetParam().expected_out_tensors.size()); + + tensorflow::Tensor host_tensor(tensorflow::DT_INT32, + tensorflow::TensorShape({1, 2})); + + for (int i = 0; i < disassembled_arrays.size(); ++i) { + SCOPED_TRACE(absl::StrCat("Array ", i, " of ", disassembled_arrays.size())); + auto disassembled_array = disassembled_arrays[i]; + auto expected_out_tensor = GetParam().expected_out_tensors[i]; + ASSERT_EQ(disassembled_array->shape(), + xla::ifrt::Shape(expected_out_tensor.shape().dim_sizes())); + tensorflow::Tensor host_tensor(expected_out_tensor.dtype(), + expected_out_tensor.shape()); + TF_ASSERT_OK( + disassembled_array + ->CopyToHostBuffer(host_tensor.data(), /*byte_strides=*/{}, + xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + .Await()); + EXPECT_THAT(expected_out_tensor, tensorflow::test::TensorEq(host_tensor)); + } +} + +INSTANTIATE_TEST_SUITE_P( + HloShardingTests, HloShardingTest, + ::testing::ValuesIn( + { + // Full replication. 
+ { + .in_tensor = test::AsTensor({1}, TensorShape({})), + .expected_out_tensors = + { + test::AsTensor({1}, TensorShape({})), + test::AsTensor({1}, TensorShape({})), + }, + .device_indices = {0, 1}, + .sharding = Replicate(), + }, + { + .in_tensor = test::AsTensor({1, 2, 3}, + TensorShape({3, 1})), + .expected_out_tensors = + { + test::AsTensor({1, 2, 3}, TensorShape({3, 1})), + test::AsTensor({1, 2, 3}, TensorShape({3, 1})), + }, + .device_indices = {0, 1}, + .sharding = Replicate(), + }, + // 1-D sharding + { + .in_tensor = test::AsTensor({1, 2, 3, 4}, + TensorShape({4})), + .expected_out_tensors = + { + test::AsTensor({1, 2}, TensorShape({2})), + test::AsTensor({3, 4}, TensorShape({2})), + }, + .device_indices = {0, 1}, + .sharding = Tile({2}), + }, + { + .in_tensor = test::AsTensor({1, 2, 3, 4}, + TensorShape({2, 2})), + .expected_out_tensors = + { + test::AsTensor({1, 2}, TensorShape({1, 2})), + test::AsTensor({3, 4}, TensorShape({1, 2})), + }, + .device_indices = {0, 1}, + .sharding = Tile({2, 1}), + }, + { + .in_tensor = test::AsTensor({1, 2, 3, 4}, + TensorShape({1, 2, 2})), + .expected_out_tensors = + { + test::AsTensor({1, 3}, TensorShape({1, 2, 1})), + test::AsTensor({2, 4}, TensorShape({1, 2, 1})), + }, + .device_indices = {0, 1}, + .sharding = Tile({1, 1, 2}), + }, + { + .in_tensor = test::AsTensor({1, 2, 3, 4, 5, 6, 7, 8}, + TensorShape({4, 2})), + .expected_out_tensors = + { + test::AsTensor({1, 2}, TensorShape({1, 2})), + test::AsTensor({3, 4}, TensorShape({1, 2})), + test::AsTensor({5, 6}, TensorShape({1, 2})), + test::AsTensor({7, 8}, TensorShape({1, 2})), + }, + .device_indices = {0, 1, 2, 3}, + .sharding = Tile({4, 1}), + }, + { + .in_tensor = test::AsTensor({1, 2, 3, 4, 5, 6, 7, 8}, + TensorShape({4, 2})), + .expected_out_tensors = + { + test::AsTensor({1, 3, 5, 7}, + TensorShape({4, 1})), + test::AsTensor({2, 4, 6, 8}, + TensorShape({4, 1})), + }, + .device_indices = {0, 1}, + .sharding = Tile({1, 2}), + }, + // 2-D sharding + { + .in_tensor = test::AsTensor( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + TensorShape({4, 4})), + .expected_out_tensors = + { + test::AsTensor({1, 2, 5, 6}, + TensorShape({2, 2})), + test::AsTensor({3, 4, 7, 8}, + TensorShape({2, 2})), + test::AsTensor({9, 10, 13, 14}, + TensorShape({2, 2})), + test::AsTensor({11, 12, 15, 16}, + TensorShape({2, 2})), + }, + .device_indices = {0, 1, 2, 3}, + .sharding = Tile({2, 2}), + }, + { + .in_tensor = test::AsTensor( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + TensorShape({4, 1, 4})), + .expected_out_tensors = + { + test::AsTensor({1, 2, 5, 6}, + TensorShape({2, 1, 2})), + test::AsTensor({3, 4, 7, 8}, + TensorShape({2, 1, 2})), + test::AsTensor({9, 10, 13, 14}, + TensorShape({2, 1, 2})), + test::AsTensor({11, 12, 15, 16}, + TensorShape({2, 1, 2})), + }, + .device_indices = {0, 1, 2, 3}, + .sharding = Tile({2, 1, 2}), + }, + // Partial replication + { + .in_tensor = test::AsTensor({1, 2, 3, 4}, + TensorShape({2, 2})), + .expected_out_tensors = + { + test::AsTensor({1, 3}, TensorShape({2, 1})), + test::AsTensor({1, 3}, TensorShape({2, 1})), + test::AsTensor({2, 4}, TensorShape({2, 1})), + test::AsTensor({2, 4}, TensorShape({2, 1})), + }, + .device_indices = {0, 1, 2, 3}, + .sharding = PartialTile({1, 2, 2}), + }, + { + .in_tensor = test::AsTensor({1, 2, 3, 4}, + TensorShape({2, 2})), + .expected_out_tensors = + { + test::AsTensor({1, 2}, TensorShape({1, 2})), + test::AsTensor({1, 2}, TensorShape({1, 2})), + test::AsTensor({3, 4}, TensorShape({1, 2})), + test::AsTensor({3, 
4}, TensorShape({1, 2})), + }, + .device_indices = {0, 1, 2, 3}, + .sharding = PartialTile({2, 1, 2}), + }, + })); + +TEST_P(ShardingParamTest, MakeAssembledArrayFromHostBuffer) { + constexpr int kMaxParallelism = 16; + auto thread_pool = std::make_unique( + tsl::Env::Default(), tsl::ThreadOptions(), "Resharding", kMaxParallelism); + + Eigen::ThreadPoolDevice device(thread_pool->AsEigenThreadPool(), + kMaxParallelism); + + auto input_tensor = GetParam().in_tensor; + + // Create contexts required for the compiler execution. + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + TF_ASSERT_OK_AND_ASSIGN(auto device_list, + xla::ifrt::test_util::GetDevices( + client.get(), GetParam().device_indices)); + + xla::ifrt::ShardingParam sharding_param{ + GetParam().dim_shards, + xla::ifrt::ShardingParam::MinorToMajor(GetParam().permutation, + GetParam().axis_sizes)}; + + TF_ASSERT_OK_AND_ASSIGN( + auto sharding, xla::ifrt::ShardingParamSharding::Create( + sharding_param, device_list, xla::ifrt::MemoryKind())); + + TF_ASSERT_OK_AND_ASSIGN( + auto assembled_array, + MakeAssembledArrayFromHostBuffer(*client, input_tensor, + std::move(sharding), device)); + + TF_ASSERT_OK_AND_ASSIGN(auto disassembled_arrays, + assembled_array->DisassembleIntoSingleDeviceArrays( + xla::ifrt::ArrayCopySemantics::kAlwaysCopy)); + + ASSERT_EQ(disassembled_arrays.size(), GetParam().expected_out_tensors.size()); + + tensorflow::Tensor host_tensor(tensorflow::DT_INT32, + tensorflow::TensorShape({1, 2})); + + for (int i = 0; i < disassembled_arrays.size(); ++i) { + SCOPED_TRACE(absl::StrCat("Array ", i, " of ", disassembled_arrays.size())); + auto disassembled_array = disassembled_arrays[i]; + auto expected_out_tensor = GetParam().expected_out_tensors[i]; + ASSERT_EQ(disassembled_array->shape(), + xla::ifrt::Shape(expected_out_tensor.shape().dim_sizes())); + tensorflow::Tensor host_tensor(expected_out_tensor.dtype(), + expected_out_tensor.shape()); + TF_ASSERT_OK( + disassembled_array + ->CopyToHostBuffer(host_tensor.data(), /*byte_strides=*/{}, + xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + .Await()); + EXPECT_THAT(expected_out_tensor, tensorflow::test::TensorEq(host_tensor)); + } +} + +INSTANTIATE_TEST_SUITE_P( + ShardingParamTests, ShardingParamTest, + ::testing::ValuesIn( + { + { + .in_tensor = test::AsTensor({1, 2, 3, 4}, + TensorShape({2, 2})), + + .expected_out_tensors = + { + test::AsTensor({1, 2}, TensorShape({1, 2})), + test::AsTensor({3, 4}, TensorShape({1, 2})), + }, + .device_indices = {0, 1}, + .dim_shards = {2, 1}, + .permutation = {0, 1}, + .axis_sizes = {2, 1}, + }, + { + .in_tensor = test::AsTensor({1, 2, 3, 4}, + TensorShape({2, 2})), + .expected_out_tensors = + { + test::AsTensor({1, 3}, TensorShape({2, 1})), + test::AsTensor({2, 4}, TensorShape({2, 1})), + }, + .device_indices = {0, 1}, + .dim_shards = {1, 2}, + .permutation = {0, 1}, + .axis_sizes = {1, 2}, + }, + { + .in_tensor = test::AsTensor( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + TensorShape({4, 4})), + .expected_out_tensors = + { + test::AsTensor({1, 2, 5, 6}, + TensorShape({2, 2})), + test::AsTensor({3, 4, 7, 8}, + TensorShape({2, 2})), + test::AsTensor({9, 10, 13, 14}, + TensorShape({2, 2})), + test::AsTensor({11, 12, 15, 16}, + TensorShape({2, 2})), + }, + .device_indices = {0, 1, 2, 3}, + .dim_shards = {2, 2}, + .permutation = {0, 1}, + .axis_sizes = {2, 2}, + }, + { + .in_tensor = test::AsTensor( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + TensorShape({4, 4})), + 
.expected_out_tensors = + { + test::AsTensor({1, 2, 3, 4, 5, 6, 7, 8}, + TensorShape({2, 4})), + test::AsTensor({9, 10, 11, 12, 13, 14, 15, 16}, + TensorShape({2, 4})), + }, + .device_indices = {0, 1}, + .dim_shards = {2, 1}, + .permutation = {1, 0}, + .axis_sizes = {2, 1}, + }, + // Full replication + { + .in_tensor = test::AsTensor({1, 2, 3, 4}, + TensorShape({2, 2})), + .expected_out_tensors = + { + test::AsTensor({1, 2, 3, 4}, + TensorShape({2, 2})), + test::AsTensor({1, 2, 3, 4}, + TensorShape({2, 2})), + }, + .device_indices = {0, 1}, + .dim_shards = {1, 1}, + .permutation = {0}, + .axis_sizes = {2}, + }, + // Partial replication (aka replicate_on_last_tile_dim = true) + { + .in_tensor = test::AsTensor({1, 2, 3, 4}, + TensorShape({2, 2})), + .expected_out_tensors = + { + test::AsTensor({1, 3}, TensorShape({2, 1})), + test::AsTensor({1, 3}, TensorShape({2, 1})), + test::AsTensor({2, 4}, TensorShape({2, 1})), + test::AsTensor({2, 4}, TensorShape({2, 1})), + }, + .device_indices = {0, 1, 2, 3}, + .dim_shards = {1, 2}, + .permutation = {0, 1}, + .axis_sizes = {2, 2}, + }, + // Partial replication that shards along the first dimension. + { + .in_tensor = test::AsTensor({1, 2, 3, 4}, + TensorShape({2, 2})), + .expected_out_tensors = + { + test::AsTensor({1, 2}, TensorShape({1, 2})), + test::AsTensor({1, 2}, TensorShape({1, 2})), + test::AsTensor({3, 4}, TensorShape({1, 2})), + test::AsTensor({3, 4}, TensorShape({1, 2})), + }, + .device_indices = {0, 1, 2, 3}, + .dim_shards = {2, 1}, + .permutation = {0, 1}, + .axis_sizes = {2, 2}, + }, + // Partial replication with random device indices. + { + .in_tensor = test::AsTensor({1, 2, 3, 4}, + TensorShape({2, 2})), + .expected_out_tensors = + { + test::AsTensor({1, 3}, TensorShape({2, 1})), + test::AsTensor({1, 3}, TensorShape({2, 1})), + test::AsTensor({2, 4}, TensorShape({2, 1})), + test::AsTensor({2, 4}, TensorShape({2, 1})), + }, + .device_indices = {3, 1, 2, 0}, + .dim_shards = {1, 2}, + .permutation = {0, 1}, + .axis_sizes = {2, 2}, + }, + })); + +TEST(ShardingUtilsTest, MismatchRank) { + constexpr int kMaxParallelism = 16; + auto thread_pool = std::make_unique( + tsl::Env::Default(), tsl::ThreadOptions(), "Resharding", kMaxParallelism); + + Eigen::ThreadPoolDevice device(thread_pool->AsEigenThreadPool(), + kMaxParallelism); + + auto input_tensor = + test::AsTensor({1, 2, 3, 4}, TensorShape({2, 1, 2})); + + // Create contexts required for the compiler execution. 
+ TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + TF_ASSERT_OK_AND_ASSIGN( + auto device_list, xla::ifrt::test_util::GetDevices(client.get(), {0, 1})); + + xla::ifrt::ShardingParam sharding_param = { + /*dim_shards=*/{2, 1}, + xla::ifrt::ShardingParam::MinorToMajor(/*permutation=*/{0, 1}, + /*axis_sizes=*/{2, 1})}; + + TF_ASSERT_OK_AND_ASSIGN( + auto sharding, xla::ifrt::ShardingParamSharding::Create( + sharding_param, device_list, xla::ifrt::MemoryKind())); + + EXPECT_THAT(MakeAssembledArrayFromHostBuffer(*client, input_tensor, + std::move(sharding), device), + StatusIs(absl::StatusCode::kInvalidArgument, + "Expect equal rank of 3 but got 2")); +} + +} // namespace +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/testdata/BUILD b/tensorflow/core/tfrt/ifrt/testdata/BUILD new file mode 100644 index 00000000000000..948ce54ab983a7 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/testdata/BUILD @@ -0,0 +1,12 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/core/tfrt/ifrt:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "testdata", + srcs = glob( + ["*"], + ), +) diff --git a/tensorflow/core/tfrt/ifrt/testdata/executable.mlir b/tensorflow/core/tfrt/ifrt/testdata/executable.mlir new file mode 100644 index 00000000000000..95c558ddb7ae0b --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/testdata/executable.mlir @@ -0,0 +1,6 @@ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.MatMul"(%arg0, %arg1): (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + func.return %0 : tensor<*xi32> + } +} \ No newline at end of file diff --git a/tensorflow/core/tfrt/kernels/BUILD b/tensorflow/core/tfrt/kernels/BUILD index bf2768ec3ed419..390bef2009b9b0 100644 --- a/tensorflow/core/tfrt/kernels/BUILD +++ b/tensorflow/core/tfrt/kernels/BUILD @@ -16,6 +16,23 @@ package_group( ], ) +cc_library( + name = "ifrt_program_ops", + srcs = ["ifrt_program_ops.cc"], + hdrs = ["ifrt_program_ops.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core/tfrt/ifrt:ifrt_executable_registry", + "//tensorflow/core/tfrt/ifrt:ifrt_serving_executable", + "@com_google_absl//absl/base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + cc_library( name = "stream_ops", srcs = ["stream_ops.cc"], diff --git a/tensorflow/core/tfrt/kernels/ifrt_program_ops.cc b/tensorflow/core/tfrt/kernels/ifrt_program_ops.cc new file mode 100644 index 00000000000000..92ce3cad2c1e04 --- /dev/null +++ b/tensorflow/core/tfrt/kernels/ifrt_program_ops.cc @@ -0,0 +1,67 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/tfrt/kernels/ifrt_program_ops.h" + +#include +#include + +#include "absl/base/call_once.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h" + +namespace tensorflow { +namespace tfrt_stub { + +IfrtCallOp::IfrtCallOp(tensorflow::OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("program_id", &program_id_)); +} + +void IfrtCallOp::Compute(tensorflow::OpKernelContext* ctx) { + absl::call_once(init_once_, [&]() { + executable_ = tensorflow::ifrt_serving::ServingExecutableRegistry::Lookup( + program_id_); + }); + OP_REQUIRES(ctx, executable_ != nullptr, + absl::NotFoundError( + absl::StrCat("Unknown program id '", program_id_, "'"))); + + std::vector inputs; + inputs.reserve(ctx->num_inputs()); + for (int i = 0; i < ctx->num_inputs(); ++i) { + inputs.push_back(ctx->input(i)); + } + + absl::StatusOr> results = executable_->Execute(inputs); + OP_REQUIRES(ctx, results.ok(), results.status()); + + tensorflow::OpOutputList outputs(ctx, 0, results->size()); + for (int i = 0; i < results->size(); ++i) { + outputs.set(i, (*results)[i]); + } +} + +REGISTER_KERNEL_BUILDER(Name("IfrtCall").Device(tensorflow::DEVICE_CPU), + IfrtCallOp); + +} // namespace tfrt_stub +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/kernels/ifrt_program_ops.h b/tensorflow/core/tfrt/kernels/ifrt_program_ops.h new file mode 100644 index 00000000000000..578ccae70b8e4b --- /dev/null +++ b/tensorflow/core/tfrt/kernels/ifrt_program_ops.h @@ -0,0 +1,51 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_KERNELS_IFRT_PROGRAM_OPS_H_ +#define TENSORFLOW_CORE_TFRT_KERNELS_IFRT_PROGRAM_OPS_H_ + +#include + +#include + +#include "absl/base/call_once.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h" + +namespace tensorflow { +namespace tfrt_stub { + +// TensorFlow op that calls a Ifrt program registered in `ProgramRegistry`. +class IfrtCallOp : public tensorflow::OpKernel { + public: + explicit IfrtCallOp(tensorflow::OpKernelConstruction* ctx); + + IfrtCallOp(const IfrtCallOp& other) = delete; + IfrtCallOp& operator=(const IfrtCallOp& other) = delete; + + void Compute(tensorflow::OpKernelContext* ctx) override; + + private: + // Op attributes. + int64_t program_id_; + + // Ifrt program to be called. Cached after the first call. 
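  // Editorial note (descriptive, not part of the original patch): Compute()
  // can be invoked concurrently, so the ServingExecutableRegistry lookup runs
  // exactly once, guarded by `init_once_`, and the result is reused via
  // `executable_` on every later call.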
+ absl::once_flag init_once_; + std::shared_ptr executable_; +}; + +} // namespace tfrt_stub +} // namespace tensorflow +#endif // TENSORFLOW_CORE_TFRT_KERNELS_IFRT_PROGRAM_OPS_H_ diff --git a/tensorflow/core/tfrt/mlrt/bytecode/bytecode.h b/tensorflow/core/tfrt/mlrt/bytecode/bytecode.h index f6b8de5da15dcb..f82666f172a37d 100644 --- a/tensorflow/core/tfrt/mlrt/bytecode/bytecode.h +++ b/tensorflow/core/tfrt/mlrt/bytecode/bytecode.h @@ -109,6 +109,8 @@ class Buffer { size_t size() const { return buffer_.size(); } bool empty() const { return buffer_.empty(); } + void shrink_to_fit() { buffer_.shrink_to_fit(); } + private: static_assert(alignof(std::max_align_t) >= 8, "The bytecode buffer needs to be at least 8-byte aligned."); diff --git a/tensorflow/core/tfrt/mlrt/kernel/BUILD b/tensorflow/core/tfrt/mlrt/kernel/BUILD index 9dacee29030c8f..cca6cddbfb650a 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/BUILD +++ b/tensorflow/core/tfrt/mlrt/kernel/BUILD @@ -10,7 +10,6 @@ package( # copybara:uncomment "//learning/brain/tfrt:__subpackages__", # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", "//tensorflow/core/tfrt/graph_executor:__subpackages__", - "//tensorflow/core/tfrt/mlrt/application/tensorflow/tests:__subpackages__", "//tensorflow/core/tfrt/saved_model:__subpackages__", "//tensorflow/core/tfrt/tfrt_session:__subpackages__", ], diff --git a/tensorflow/core/tfrt/ops/BUILD b/tensorflow/core/tfrt/ops/BUILD index cf9beb34fc7231..d09db6a81dc3c2 100644 --- a/tensorflow/core/tfrt/ops/BUILD +++ b/tensorflow/core/tfrt/ops/BUILD @@ -5,6 +5,32 @@ package( default_visibility = ["//tensorflow/core/tfrt/__subpackages__"], ) +tf_gen_op_libs( + op_lib_names = ["ifrt_program_ops"], + sub_directory = "", + deps = ["//tensorflow/core:lib"], +) + +tf_gen_op_wrapper_cc( + name = "gen_ifrt_program_ops", + out_ops_file = "gen_ifrt_program_ops", + deps = [":ifrt_program_ops_op_lib"], +) + +cc_library( + name = "gen_ifrt_program_ops_cc", + srcs = ["gen_ifrt_program_ops.cc"], + hdrs = ["gen_ifrt_program_ops.h"], + deps = [ + ":ifrt_program_ops_op_lib", + "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + tf_gen_op_libs( op_lib_names = ["stream_ops"], sub_directory = "", diff --git a/tensorflow/core/tfrt/ops/ifrt_program_ops.cc b/tensorflow/core/tfrt/ops/ifrt_program_ops.cc new file mode 100644 index 00000000000000..ab8e14b2e41eac --- /dev/null +++ b/tensorflow/core/tfrt/ops/ifrt_program_ops.cc @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { +namespace tfrt_stub { + +REGISTER_OP("IfrtCall") + .Input("args: Tin") + .Output("results: Tout") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .Attr("program_id: int") + .SetIsStateful() + .SetShapeFn(tensorflow::shape_inference::UnknownShape) + .Doc(R"( +Calls an IFRT program identified by the given program id. + +This op looks up a `ServingExecutable` from `ServingExecutableRegistry` using +the program id, calls the executable with the op's inputs as arguments, and +returns its results as the op's outputs. + +Note that this op is not part of a stable interface. Users must not use this op +in their SavedModel and instead rely on Ifrt Serving's mechanism that +automatically inserts this op with graph rewrite. + +program_id: int64 id that can be used to look up compiled programs from + `ServingExecutableRegistry`. +)"); + +} // namespace tfrt_stub +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc index 17b1ba31fd044c..dc225b9ab2448d 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc @@ -320,9 +320,8 @@ StatusOr AotCompileToGpuPjRtLoadedExecutableWithDevice( int graph_def_version, const std::vector& args, bool has_ref_vars, bool may_alias_resource_update, XlaCompiler::CompilationResult** compilation_result) { - TF_ASSIGN_OR_RETURN(auto client, xla::GetStreamExecutorGpuClient( - true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSIGN_OR_RETURN(auto client, + xla::GetStreamExecutorGpuClient(xla::GpuClientOptions())); auto se_client = absl::WrapUnique( tensorflow::down_cast(client.release())); diff --git a/tensorflow/core/tfrt/saved_model/tests/BUILD b/tensorflow/core/tfrt/saved_model/tests/BUILD index 3495c1150ac420..12b9779fef43f1 100644 --- a/tensorflow/core/tfrt/saved_model/tests/BUILD +++ b/tensorflow/core/tfrt/saved_model/tests/BUILD @@ -619,6 +619,64 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "saved_model_ifrt_testlib", + testonly = 1, + srcs = ["saved_model_ifrt_test.cc"], + data = [ + "toy_v2/saved_model.pb", + "toy_v2/variables/variables.data-00000-of-00001", + "toy_v2/variables/variables.index", + ], + tags = ["no_oss"], + deps = [ + "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options", + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_backend_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core/platform:resource_loader", + "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink", + "//tensorflow/core/tfrt:ifrt_program_ops_op_lib", + "//tensorflow/core/tfrt/ifrt:ifrt_model_context", + "//tensorflow/core/tfrt/runtime", + "//tensorflow/core/tfrt/saved_model:saved_model_cpu", + "//tensorflow/core/tfrt/saved_model:saved_model_testutil", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + "@local_xla//xla/python/ifrt", + "@local_xla//xla/python/ifrt:test_util", + "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", + "@tf_runtime//:basic_kernels_alwayslink", + "@tf_runtime//:core_runtime_alwayslink", + "@tf_runtime//:test_kernels_alwayslink", + 
"@tf_runtime//backends/cpu:core_runtime_alwayslink", + "@tf_runtime//backends/cpu:tf_ops_alwayslink", + ], +) + +tf_cc_test( + name = "saved_model_ifrt_test", + srcs = [], + tags = ["no_oss"], + deps = [ + ":saved_model_ifrt_testlib", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "saved_model_ifrt_test_mlrt", + srcs = [], + args = ["--enable_mlrt=true"], + tags = ["no_oss"], + deps = [ + ":saved_model_ifrt_testlib", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "saved_model_test", srcs = [], diff --git a/tensorflow/core/tfrt/saved_model/tests/gen_saved_model.bzl b/tensorflow/core/tfrt/saved_model/tests/gen_saved_model.bzl index 0cb693e7b8763c..f3ed254c39689e 100644 --- a/tensorflow/core/tfrt/saved_model/tests/gen_saved_model.bzl +++ b/tensorflow/core/tfrt/saved_model/tests/gen_saved_model.bzl @@ -18,3 +18,18 @@ def gen_saved_model(model_name = "", script = "", **kwargs): tools = [script], **kwargs ) + +def gen_variableless_saved_model(model_name = "", script = "", **kwargs): + native.genrule( + name = "saved_model_gen_" + model_name, + srcs = [], + outs = [ + model_name + "/saved_model.pb", + ], + cmd = if_google( + "$(location " + script + ") --saved_model_path=$(RULEDIR)/" + model_name, + "touch $(OUTS)", # TODO(b/188517768): fix model gen. + ), + tools = [script], + **kwargs + ) diff --git a/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc b/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc new file mode 100644 index 00000000000000..6113741df2eec4 --- /dev/null +++ b/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc @@ -0,0 +1,93 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h" +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/test_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/resource_loader.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h" +#include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tensorflow/core/tfrt/saved_model/saved_model.h" +#include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace tfrt_stub { +namespace { + +TEST(SavedModelIfrt, Basic) { + std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( + "tensorflow/core/tfrt/saved_model/tests/toy_v2"); + + auto runtime = + tensorflow::tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4); + + // Create contexts required for the compiler execution. 
+ TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + + // Use IFRT compiler + runtime->AddCreateRuntimeResourceFn( + [&](tensorflow::tfrt_stub::ModelRuntimeContext& model_context) { + tensorflow::ifrt_serving::IfrtModelContext ifrt_model_context(client); + + model_context.resource_context() + .CreateResource( + "IfrtModelContext", std::move(ifrt_model_context)); + return absl::OkStatus(); + }); + tensorflow::ifrt_serving::IfrtBackendCompiler ifrt_compiler; + + auto options = DefaultSavedModelOptions(runtime.get()); + options.enable_lazy_loading = true; + options.lazy_loading_use_graph_executor = true; + options.graph_execution_options.compile_options.backend_compiler = + &ifrt_compiler; + + TF_ASSERT_OK_AND_ASSIGN( + auto saved_model, SavedModelImpl::LoadSavedModel(options, saved_model_dir, + /*tags=*/{"serve"})); + + // Set input 'x' to [[1, 1, 1]] + std::vector inputs; + inputs.push_back( + CreateTfTensor(/*shape=*/{1, 3}, /*data=*/{1, 1, 1})); + + tfrt::SavedModel::RunOptions run_options; + + std::vector outputs; + TF_ASSERT_OK( + saved_model->Run(run_options, "serving_default", inputs, &outputs)); + ASSERT_EQ(outputs.size(), 1); + + EXPECT_THAT(GetTfTensorData(outputs[0]), + ::testing::ElementsAreArray({6})); +} + +} // namespace +} // namespace tfrt_stub +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/utils/BUILD b/tensorflow/core/tfrt/utils/BUILD index f9879e84d4562d..7517e07928f87b 100644 --- a/tensorflow/core/tfrt/utils/BUILD +++ b/tensorflow/core/tfrt/utils/BUILD @@ -270,6 +270,7 @@ tf_cc_test( deps = [ ":fallback_tensor", "//tensorflow/core/common_runtime:dma_helper", + "//tensorflow/core/framework:tensor_shape", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/core/tfrt/utils/fallback_tensor.h b/tensorflow/core/tfrt/utils/fallback_tensor.h index 393e06e75f4a90..0856117d2b7a09 100644 --- a/tensorflow/core/tfrt/utils/fallback_tensor.h +++ b/tensorflow/core/tfrt/utils/fallback_tensor.h @@ -64,7 +64,7 @@ class FallbackTensor { FallbackTensor(const FallbackTensor& other) { *this = other; } FallbackTensor& operator=(const FallbackTensor& other) { tsl::profiler::TraceMe trace_me("FallbackTensor::Copy"); - if (!other.is_immutable()) { + if (!other.is_immutable() && other.buffer() != nullptr) { // Create a new TensorBuffer which contains a new atomic counter for each // result, to avoid downstream threads contending the original atomic // counter. @@ -72,7 +72,8 @@ class FallbackTensor { tensorflow::tfrt_stub::ImmutableTensor::Create(other.tensor()) .tensor()); } else { - // For immutable tensors, we just need to copy the pointer. + // For immutable tensors or empty tensors, we just need to copy the + // pointer as they don't incur atomic operations when they are referenced. tensor_ = other.tensor(); } is_immutable_ = true; diff --git a/tensorflow/core/tfrt/utils/fallback_tensor_test.cc b/tensorflow/core/tfrt/utils/fallback_tensor_test.cc index 9c54e8704158c3..1e3de50a38d9fc 100644 --- a/tensorflow/core/tfrt/utils/fallback_tensor_test.cc +++ b/tensorflow/core/tfrt/utils/fallback_tensor_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace tfrt_stub { @@ -158,6 +159,15 @@ TEST(FallbackTensorTest, FallbackTensorCopyRootBuffer) { tensorflow::DMAHelper::buffer(&tensor)); } +TEST(FallbackTensorTest, EmptyTensor) { + tensorflow::Tensor tensor(tensorflow::DT_FLOAT, + tensorflow::TensorShape({1, 0})); + + FallbackTensor fallback_tensor(tensor); + auto copy = fallback_tensor; + ASSERT_FALSE(copy.buffer()); +} + } // namespace } // namespace tfrt_stub } // namespace tensorflow diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 317060b2ca9408..eade2efb96f75c 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -1,19 +1,18 @@ # Description: Utilities for TPU Operations -load("@bazel_skylib//:bzl_library.bzl", "bzl_library") -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "if_libtpu", "if_windows", "tf_cc_test", ) +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/compiler/mlir/tf2xla:__subpackages__", - "//tensorflow/compiler/xrt:__subpackages__", "//tensorflow/core/tpu:__subpackages__", "//tensorflow/dtensor:__subpackages__", ], diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 212bfbcbdcb844..48b4711d37cf75 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -46,7 +46,7 @@ package_group( packages = [ "//tensorflow/compiler/mlir/quantization/...", "//tensorflow/compiler/mlir/tf2xla/...", - "//tensorflow/compiler/xrt/kernels/...", + "//tensorflow/core/tfrt/ifrt/...", "//tensorflow/core/tpu/...", "//tensorflow/dtensor/...", "//third_party/py/jax_tpu_embedding/...", @@ -60,13 +60,15 @@ tf_kernel_library( visibility = ["//visibility:public"], deps = [ ":cross_replica_ops", + ":global_iter_id_op", ":host_compute_ops", ":image_resize_ops", ":infeed_ops", ":outfeed_ops", ":replication_ops", ":sharding_util_ops", - ":sparse_core_ops", + ":sparse_core_preprocess_ops", + ":sparse_core_xla_ops", ":topk_ops", ":tpu_compile_op", ":tpu_configuration_ops", @@ -184,6 +186,7 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:statusor", @@ -198,7 +201,6 @@ cc_library( "@local_xla//xla/client:xla_computation", "@local_xla//xla/client/lib:slicing", "@local_xla//xla/stream_executor/tpu:c_api_decl", - "@local_xla//xla/stream_executor/tpu:status_helper", "@local_xla//xla/stream_executor/tpu:tpu_api", "@local_xla//xla/stream_executor/tpu:tpu_ops_c_api_hdrs", ], @@ -512,7 +514,6 @@ cc_library( ":tpu_program_group_interface", "//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xrt:xrt_proto_cc", "//tensorflow/core:lib", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "@local_xla//xla:xla_proto_cc", @@ -1342,6 +1343,7 @@ cc_library( name = "sharding_util_ops", srcs = ["sharding_util_ops.cc"], deps = [ + ":sharding_utils", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core/framework:op_requires", 
@@ -1349,6 +1351,7 @@ cc_library( "//tensorflow/core/platform:refcount", "//tensorflow/core/platform:status", "//tensorflow/core/platform:statusor", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1360,12 +1363,53 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "sharding_utils", + srcs = ["sharding_utils.cc"], + hdrs = ["sharding_utils.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:statusor", + ], +) + +tf_cc_test( + name = "sharding_utils_test", + srcs = ["sharding_utils_test.cc"], + deps = [ + ":sharding_utils", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:statusor", + ], +) + tf_kernel_library( name = "global_iter_id_op", srcs = ["global_iter_id.cc"], deps = [ "//tensorflow/core:framework", "//tensorflow/core/kernels:partitioned_function_ops", + "//tensorflow/core/tpu/ops:sparse_core_ops", ], ) @@ -1479,52 +1523,6 @@ tf_cc_test( ], ) -cc_library( - name = "sparse_core_ops", - visibility = ["//visibility:public"], - deps = [ - ":sparse_core_preprocess_ops", - "//tensorflow/compiler/jit:xla_device", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla:xla_op_registry", - "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:functional_ops_op_lib", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:nn_ops_op_lib", - "//tensorflow/core:no_op_op_lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:sendrecv_ops_op_lib", - "//tensorflow/core/kernels:ops_util_hdrs", - "//tensorflow/core/kernels:transpose_functor", - "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/core/tpu:tpu_configuration", - "//tensorflow/core/tpu/kernels:global_iter_id_op", - "//tensorflow/core/tpu/kernels:host_compute_ops", - "//tensorflow/core/tpu/kernels:image_resize_ops", - "//tensorflow/core/tpu/kernels:infeed_ops", - "//tensorflow/core/tpu/kernels:outfeed_ops", - "//tensorflow/core/tpu/kernels:replication_ops", - "//tensorflow/core/tpu/kernels:sharding_util_ops", - "//tensorflow/core/tpu/kernels:topk_ops", - "//tensorflow/core/tpu/kernels:tpu_compilation_cache_interface", - "//tensorflow/core/tpu/kernels:tpu_functional_ops", - "//tensorflow/core/tpu/kernels:tpu_handle_to_key_op", - "//tensorflow/core/tpu/kernels:tpu_op_consts", - "//tensorflow/core/tpu/kernels:transfer_ops", - "//tensorflow/dtensor/cc:dtensor_tpu_kernels", - "@com_google_absl//absl/strings", - "@local_xla//xla:util", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - 
"@local_xla//xla/stream_executor:multi_platform_manager", - ], -) - tf_proto_library( name = "sparse_core_layout_proto", srcs = ["sparse_core_layout.proto"], diff --git a/tensorflow/core/tpu/kernels/global_iter_id.cc b/tensorflow/core/tpu/kernels/global_iter_id.cc index 92a44d7c106a2b..11b80146f63153 100644 --- a/tensorflow/core/tpu/kernels/global_iter_id.cc +++ b/tensorflow/core/tpu/kernels/global_iter_id.cc @@ -29,7 +29,6 @@ class GlobalIterId : public OpKernel { ctx->set_output(0, Tensor(ctx->frame_iter().iter_id)); } }; -REGISTER_OP("GlobalIterId").Output("iter_id: int64").SetIsStateful(); REGISTER_KERNEL_BUILDER(Name("GlobalIterId").Device(DEVICE_CPU), GlobalIterId); } // anonymous namespace diff --git a/tensorflow/core/tpu/kernels/sharding_util_ops.cc b/tensorflow/core/tpu/kernels/sharding_util_ops.cc index 5513547b6bb67b..fe726011527165 100644 --- a/tensorflow/core/tpu/kernels/sharding_util_ops.cc +++ b/tensorflow/core/tpu/kernels/sharding_util_ops.cc @@ -15,11 +15,14 @@ limitations under the License. #include #include +#include +#include #include #include #define EIGEN_USE_THREADS +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -34,10 +37,12 @@ limitations under the License. #include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/tpu/kernels/sharding_utils.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/macros.h" @@ -129,454 +134,63 @@ Status CreateResourceInvalidDTypeError(const ResourceHandle& handle, DataTypeString(expected_dtype), ".")); } -// Converts flatten index to start indices (subscript scaled with slice shape) -// for determining where to start a slice in the input tensor. 
-template -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); -template <> -Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, int index); - -template -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, - const int index) { - return Eigen::DSizes(); -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[0] = index * slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[1] = (index % num_partitions[1]) * slice_shape[1]; - subscript[0] = (index / num_partitions[1]) * slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[2] = (index % num_partitions[2]) * slice_shape[2]; - subscript[1] = - ((index / num_partitions[2]) % num_partitions[1]) * slice_shape[1]; - subscript[0] = - (index / (num_partitions[2] * num_partitions[1])) * slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[3] = (index % num_partitions[3]) * slice_shape[3]; - subscript[2] = - ((index / num_partitions[3]) % num_partitions[2]) * slice_shape[2]; - subscript[1] = - ((index / (num_partitions[3] * num_partitions[2])) % num_partitions[1]) * - slice_shape[1]; - subscript[0] = - (index / (num_partitions[3] * num_partitions[2] * num_partitions[1])) * - slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[4] = (index % num_partitions[4]) * slice_shape[4]; - subscript[3] = - ((index / num_partitions[4]) % num_partitions[3]) * slice_shape[3]; - subscript[2] = - ((index / (num_partitions[4] * num_partitions[3])) % num_partitions[2]) * - slice_shape[2]; - subscript[1] = - ((index / (num_partitions[4] * num_partitions[3] * num_partitions[2])) % - num_partitions[1]) * - slice_shape[1]; - subscript[0] = (index / (num_partitions[4] * num_partitions[3] * - 
num_partitions[2] * num_partitions[1])) * - slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[5] = (index % num_partitions[5]) * slice_shape[5]; - subscript[4] = - ((index / num_partitions[5]) % num_partitions[4]) * slice_shape[4]; - subscript[3] = - ((index / (num_partitions[5] * num_partitions[4])) % num_partitions[3]) * - slice_shape[3]; - subscript[2] = - ((index / (num_partitions[5] * num_partitions[4] * num_partitions[3])) % - num_partitions[2]) * - slice_shape[2]; - subscript[1] = ((index / (num_partitions[5] * num_partitions[4] * - num_partitions[3] * num_partitions[2])) % - num_partitions[1]) * - slice_shape[1]; - subscript[0] = - (index / (num_partitions[5] * num_partitions[4] * num_partitions[3] * - num_partitions[2] * num_partitions[1])) * - slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[6] = (index % num_partitions[6]) * slice_shape[6]; - subscript[5] = - ((index / num_partitions[6]) % num_partitions[5]) * slice_shape[5]; - subscript[4] = - ((index / (num_partitions[6] * num_partitions[5])) % num_partitions[4]) * - slice_shape[4]; - subscript[3] = - ((index / (num_partitions[6] * num_partitions[5] * num_partitions[4])) % - num_partitions[3]) * - slice_shape[3]; - subscript[2] = ((index / (num_partitions[6] * num_partitions[5] * - num_partitions[4] * num_partitions[3])) % - num_partitions[2]) * - slice_shape[2]; - subscript[1] = - ((index / (num_partitions[6] * num_partitions[5] * num_partitions[4] * - num_partitions[3] * num_partitions[2])) % - num_partitions[1]) * - slice_shape[1]; - subscript[0] = - (index / (num_partitions[6] * num_partitions[5] * num_partitions[4] * - num_partitions[3] * num_partitions[2] * num_partitions[1])) * - slice_shape[0]; - return subscript; -} - -template <> -Eigen::DSizes GetSliceIndices( - absl::Span num_partitions, - const Eigen::DSizes& slice_shape, const int index) { - Eigen::DSizes subscript; - subscript[7] = (index % num_partitions[7]) * slice_shape[7]; - subscript[6] = - ((index / num_partitions[7]) % num_partitions[6]) * slice_shape[6]; - subscript[5] = - ((index / (num_partitions[7] * num_partitions[6])) % num_partitions[5]) * - slice_shape[5]; - subscript[4] = - ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5])) % - num_partitions[4]) * - slice_shape[4]; - subscript[3] = ((index / (num_partitions[7] * num_partitions[6] * - num_partitions[5] * num_partitions[4])) % - num_partitions[3]) * - slice_shape[3]; - subscript[2] = - ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5] * - num_partitions[4] * num_partitions[3])) % - num_partitions[2]) * - slice_shape[2]; - subscript[1] = - ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5] * - num_partitions[4] * num_partitions[3] * num_partitions[2])) % - num_partitions[1]) * - slice_shape[1]; - subscript[0] = - (index / (num_partitions[7] * num_partitions[6] * num_partitions[5] * - num_partitions[4] * num_partitions[3] * num_partitions[2] * - num_partitions[1])) * - slice_shape[0]; - return subscript; -} - constexpr absl::string_view kTensorName = "'input' tensor"; constexpr absl::string_view kResourceName = "'resource' variable tensor"; -template -Eigen::DSizes TF_ATTRIBUTE_NOINLINE -ShapeAsEigenDSizes(const TensorShape& 
shape); -template -Eigen::DSizes ShapeAsEigenDSizes( - const TensorShape& shape) { - return shape.AsEigenDSizes(); -} - -bool TF_ATTRIBUTE_NOINLINE -ValidateShapesForSlice(OpKernelContext* ctx, bool resource, const Tensor* input, - const std::vector& num_splits, - const std::vector& paddings); - -bool ValidateShapesForSlice(OpKernelContext* ctx, bool resource, - const Tensor* input, - const std::vector& num_splits, - const std::vector& paddings) { - const auto& ishape = input->shape(); - - Status s; - - absl::string_view input_name = resource ? kResourceName : kTensorName; - const int rank = ishape.dims(); - const auto& input_shape = ishape.dim_sizes(); - if (rank <= 0 || rank > 8) { - s = absl::InvalidArgumentError(absl::StrCat( - input_name, " must have rank in range (0, 8], but got ", rank, ".")); - } else if (rank != num_splits.size()) { - s = absl::InvalidArgumentError(absl::StrCat( - input_name, " rank must be the same as 'num_splits' length ", - num_splits.size(), ", but got rank ", rank, ".")); - } else { - for (int dim = 0; dim < rank; ++dim) { - const auto input_shape_dim = input_shape[dim]; - const auto paddings_dim = paddings[dim]; - const auto num_splits_dim = num_splits[dim]; - if ((input_shape_dim + paddings_dim) % num_splits_dim != 0) { - s = absl::InvalidArgumentError(absl::StrCat( - input_name, " shape dimension ", dim, " (", input_shape_dim, - ") with padding ", paddings_dim, - " must be evenly divisible by 'num_splits' ", num_splits_dim, ".")); - break; - } - } - } - if (!s.ok()) { - ctx->CtxFailure(__FILE__, __LINE__, s); - return false; - } - return true; -} - // Shared base class to save code space +template class XlaSplitNDShared : public OpKernel { public: explicit TF_ATTRIBUTE_NOINLINE XlaSplitNDShared(OpKernelConstruction* ctx) - : OpKernel(ctx), num_slices_(1), has_paddings_(false) { - GetAndValidateAttributes(/*split=*/true, ctx, num_splits_, num_slices_, - paddings_, has_paddings_); + : OpKernel(ctx) { + std::vector num_splits; + int num_slices = 1; + std::vector paddings; + bool has_paddings = false; + + GetAndValidateAttributes(/*split=*/true, ctx, num_splits, num_slices, + paddings, has_paddings); + + auto xla_nd_splitter = XlaNDSplitter::Create( + num_splits, num_slices, paddings, has_paddings); + OP_REQUIRES_OK(ctx, xla_nd_splitter.status()); + splitter_ = *std::move(xla_nd_splitter); } protected: - template - class SliceAndMaybePadState { - public: - int num_complete_pad_dims_; - int num_partial_pad_dims_; - TensorShape non_padded_slice_shape_; - Eigen::array, Rank> slice_paddings_; - Eigen::DSizes slice_indices_; - Eigen::DSizes output_slice_shape_dsizes_; - Eigen::DSizes non_padded_slice_shape_dsizes_; - - TF_ATTRIBUTE_NOINLINE SliceAndMaybePadState( - absl::Span num_splits, - const absl::Span input_shape, - const TensorShape& output_slice_shape, int slice_index) { - output_slice_shape_dsizes_ = ShapeAsEigenDSizes(output_slice_shape); - num_complete_pad_dims_ = 0; - num_partial_pad_dims_ = 0; - slice_indices_ = GetSliceIndices( - num_splits, output_slice_shape_dsizes_, slice_index); - - // Calculate paddings necessary for slice instead of padding input and - // slicing subsequently to reduce temporary memory allocation. - for (int dim = 0; dim < Rank; ++dim) { - const int64_t dim_size = input_shape[dim]; - const int64_t out_dim = output_slice_shape_dsizes_[dim]; - int64_t non_padded_dim = 0; - if (slice_indices_[dim] >= dim_size) { - // Complete padding. 
- slice_indices_[dim] = dim_size; - non_padded_dim = 0; - slice_paddings_[dim] = {0, out_dim}; - num_complete_pad_dims_++; - } else if (slice_indices_[dim] + out_dim > dim_size) { - // Partial padding. - non_padded_dim = dim_size - slice_indices_[dim]; - slice_paddings_[dim] = {0, out_dim - non_padded_dim}; - num_partial_pad_dims_++; - } else { - non_padded_dim = out_dim; - } - non_padded_slice_shape_.AddDim(non_padded_dim); - } - non_padded_slice_shape_dsizes_ = - ShapeAsEigenDSizes(non_padded_slice_shape_); - } - }; - static void TF_ATTRIBUTE_NOINLINE GetDtypeHelper(OpKernelConstruction* ctx, const char* attr_name, DataType* dtype_ptr) { OP_REQUIRES_OK(ctx, ctx->GetAttr(attr_name, dtype_ptr)); } - std::vector num_splits_; - int num_slices_; - std::vector paddings_; - bool has_paddings_; + std::optional> splitter_; }; template -class XlaSplitNDBaseOp : public XlaSplitNDShared { +class XlaSplitNDBaseOp : public XlaSplitNDShared { public: explicit XlaSplitNDBaseOp(OpKernelConstruction* ctx) - : XlaSplitNDShared(ctx) {} + : XlaSplitNDShared(ctx) {} protected: void ComputeInternal( bool resource, OpKernelContext* ctx, const std::function& assign_or_copy_value_fn, const Tensor* input) { - const int rank = input->shape().dims(); const auto& input_shape = input->shape().dim_sizes(); - if (!ValidateShapesForSlice(ctx, resource, input, num_splits_, paddings_)) { - return; - } - - TensorShape output_slice_shape; - for (int i = 0; i < rank; ++i) { - output_slice_shape.AddDim((input_shape[i] + paddings_[i]) / - ((num_slices_ == 1) ? 1 : num_splits_[i])); - } - if (num_slices_ == 1 && !has_paddings_) { - // Handle simple case first - OP_REQUIRES_OK(ctx, assign_or_copy_value_fn(*input)); - } else { - const Device& device = ctx->eigen_device(); - std::vector output_slices(num_slices_); - for (int i = 0; i < num_slices_; i++) { - OP_REQUIRES_OK(ctx, - ctx->allocate_output( - /*index=*/i, output_slice_shape, &output_slices[i])); - } - - if (rank == 1) { - SliceAndMaybePad<1>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 2) { - SliceAndMaybePad<2>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 3) { - SliceAndMaybePad<3>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 4) { - SliceAndMaybePad<4>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 5) { - SliceAndMaybePad<5>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 6) { - SliceAndMaybePad<6>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 7) { - SliceAndMaybePad<7>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } else if (rank == 8) { - SliceAndMaybePad<8>(ctx, device, input, input_shape, output_slice_shape, - output_slices); - } - return; - } - } - - private: - void TF_ATTRIBUTE_NOINLINE SetToConstant(Tensor* output_slice, - const Device& device) { - auto output_flat = output_slice->flat(); - output_flat.device(device) = output_flat.constant(T()); - } - - template - void TF_ATTRIBUTE_NOINLINE AssignFromInput( - Tensor* output_slice, const Device& device, const Tensor* input, - const Eigen::DSizes& slice_indices, - const Eigen::DSizes& output_slice_shape_dsizes) { - output_slice->tensor().device(device) = - input->tensor().slice(slice_indices, - output_slice_shape_dsizes); - } + absl::string_view input_name = resource ? 
kResourceName : kTensorName; + auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, + Tensor** tensor) { + return ctx->allocate_output( + /*index=*/i, output_slice_shape, tensor); + }; - template - void TF_ATTRIBUTE_NOINLINE SliceAndMaybePad( - OpKernelContext* ctx, const Device& device, const Tensor* input, - const absl::Span input_shape, - const TensorShape& output_slice_shape, - const std::vector& output_slices) { - const auto& input_tensor = input->tensor(); - // Slice shape with optional padding. - for (int i = 0; i < num_slices_; ++i) { - Tensor* output_slice = output_slices[i]; - SliceAndMaybePadState r(num_splits_, input_shape, - output_slice_shape, i); - if (r.num_complete_pad_dims_ == Rank || - (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0)) { - // Need to init padding - SetToConstant(output_slice, device); - } - if (r.num_complete_pad_dims_ == Rank) { - // Done - } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) { - output_slice->tensor() - .slice(Eigen::DSizes(), - r.non_padded_slice_shape_dsizes_) - .device(device) = input_tensor.slice( - r.slice_indices_, r.non_padded_slice_shape_dsizes_); - } else { - AssignFromInput(output_slice, device, input, r.slice_indices_, - r.output_slice_shape_dsizes_); - } - } + const Device& device = ctx->eigen_device(); + auto status = this->splitter_->Split( + input, input_name, assign_or_copy_value_fn, allocate_output_fn, device); + OP_REQUIRES_OK(ctx, status); } }; @@ -605,7 +219,7 @@ class ReadVariableXlaSplitNDOp : public XlaSplitNDBaseOp { explicit TF_ATTRIBUTE_NOINLINE ReadVariableXlaSplitNDOp( OpKernelConstruction* ctx) : XlaSplitNDBaseOp(ctx) { - XlaSplitNDShared::GetDtypeHelper(ctx, "T", &dtype_); + XlaSplitNDShared::GetDtypeHelper(ctx, "T", &dtype_); } void Compute(OpKernelContext* ctx) override { @@ -671,12 +285,18 @@ TF_CALL_uint4(REGISTER_READ_VARIABLE_XLA_SPLIT_ND); #undef REGISTER_READ_VARIABLE_XLA_SPLIT_ND // Shared base class to save code space +template class XlaConcatNDShared : public OpKernel { public: explicit TF_ATTRIBUTE_NOINLINE XlaConcatNDShared(OpKernelConstruction* ctx) : OpKernel(ctx), num_slices_(1), has_paddings_(false) { GetAndValidateAttributes(/*split=*/false, ctx, num_concats_, num_slices_, paddings_, has_paddings_); + + auto xla_nd_concatenator = XlaNDConcatenator::Create( + num_concats_, num_slices_, paddings_, has_paddings_); + OP_REQUIRES_OK(ctx, xla_nd_concatenator.status()); + concatenator_ = *std::move(xla_nd_concatenator); } protected: @@ -714,132 +334,31 @@ class XlaConcatNDShared : public OpKernel { return absl::OkStatus(); } - void ApplyAssignOrCopyShared( - OpKernelContext* ctx, - const std::function& assign_or_copy_value_fn, - const Tensor& input) { - OP_REQUIRES_OK(ctx, assign_or_copy_value_fn(input)); - } - - template - class MaybeUnpadAndAssignState { - public: - int num_complete_pad_dims_; - int num_partial_pad_dims_; - TensorShape non_padded_slice_shape_; - Eigen::DSizes slice_shape_dsizes_; - Eigen::array, Rank> slice_paddings_; - Eigen::DSizes slice_indices_; - Eigen::DSizes output_slice_shape_dsizes_; - Eigen::DSizes non_padded_slice_shape_dsizes_; - - TF_ATTRIBUTE_NOINLINE MaybeUnpadAndAssignState( - absl::Span num_concats, const Tensor& input0, - Tensor* output, int slice_index) { - slice_shape_dsizes_ = input0.shape().AsEigenDSizes(); - slice_indices_ = - GetSliceIndices(num_concats, slice_shape_dsizes_, slice_index); - num_complete_pad_dims_ = 0; - num_partial_pad_dims_ = 0; - // Calculate paddings necessary to strip from slice. 
- for (int dim = 0; dim < Rank; ++dim) { - const int64_t dim_size = output->shape().dim_size(dim); - int64_t non_padded_dim = 0; - if (slice_indices_[dim] >= dim_size) { - // Complete padding. - slice_indices_[dim] = dim_size; - non_padded_dim = 0; - num_complete_pad_dims_++; - } else if (slice_indices_[dim] + slice_shape_dsizes_[dim] > dim_size) { - // Partial padding. - non_padded_dim = dim_size - slice_indices_[dim]; - num_partial_pad_dims_++; - } else { - non_padded_dim = slice_shape_dsizes_[dim]; - } - non_padded_slice_shape_.AddDim(non_padded_dim); - } - non_padded_slice_shape_dsizes_ = - non_padded_slice_shape_.AsEigenDSizes(); - } - }; std::vector num_concats_; int num_slices_; std::vector paddings_; bool has_paddings_; + std::optional> concatenator_; }; template -class XlaConcatNDBaseOp : public XlaConcatNDShared { +class XlaConcatNDBaseOp : public XlaConcatNDShared { public: explicit TF_ATTRIBUTE_NOINLINE XlaConcatNDBaseOp(OpKernelConstruction* ctx) - : XlaConcatNDShared(ctx) {} + : XlaConcatNDShared(ctx) {} protected: void ComputeInternal( bool resource, OpKernelContext* ctx, const OpInputList& inputs, const std::function& assign_or_copy_value_fn, const std::function()>& get_output_fn) { - const int rank = inputs[0].shape().dims(); - - OP_REQUIRES(ctx, rank > 0 && rank <= 8, - absl::InvalidArgumentError(absl::StrCat( - "'inputs' tensors must have rank in range (0, 8], but got ", - rank, "."))); - - if (num_slices_ == 1 && !has_paddings_) { - // Simple case - ApplyAssignOrCopyShared(ctx, assign_or_copy_value_fn, inputs[0]); - return; - } - const Device& device = ctx->eigen_device(); - auto status_or_output = get_output_fn(); - OP_REQUIRES_OK(ctx, status_or_output.status()); - Tensor* output = std::move(status_or_output).value(); - - if (rank == 1) { - MaybeUnpadAndAssign<1>(ctx, device, inputs, output); - } else if (rank == 2) { - MaybeUnpadAndAssign<2>(ctx, device, inputs, output); - } else if (rank == 3) { - MaybeUnpadAndAssign<3>(ctx, device, inputs, output); - } else if (rank == 4) { - MaybeUnpadAndAssign<4>(ctx, device, inputs, output); - } else if (rank == 5) { - MaybeUnpadAndAssign<5>(ctx, device, inputs, output); - } else if (rank == 6) { - MaybeUnpadAndAssign<6>(ctx, device, inputs, output); - } else if (rank == 7) { - MaybeUnpadAndAssign<7>(ctx, device, inputs, output); - } else if (rank == 8) { - MaybeUnpadAndAssign<8>(ctx, device, inputs, output); - } - } - - private: - template - void TF_ATTRIBUTE_NOINLINE MaybeUnpadAndAssign(OpKernelContext* ctx, - const Device& device, - const OpInputList& inputs, - Tensor* output) { - for (int i = 0; i < num_slices_; ++i) { - MaybeUnpadAndAssignState r(num_concats_, inputs[0], output, i); - if (r.num_complete_pad_dims_ == Rank) { - continue; - } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) { - output->tensor() - .slice(r.slice_indices_, r.non_padded_slice_shape_dsizes_) - .device(device) = inputs[i].tensor().slice( - Eigen::DSizes(), - r.non_padded_slice_shape_dsizes_); - } else { - output->tensor() - .slice(r.slice_indices_, r.slice_shape_dsizes_) - .device(device) = inputs[i].tensor(); - } - } + std::vector input_tensors(inputs.begin(), inputs.end()); + auto status = this->concatenator_->ComputeInternal( + absl::MakeSpan(input_tensors), assign_or_copy_value_fn, get_output_fn, + device); + OP_REQUIRES_OK(ctx, status); } }; diff --git a/tensorflow/core/tpu/kernels/sharding_utils.cc b/tensorflow/core/tpu/kernels/sharding_utils.cc new file mode 100644 index 00000000000000..0f4b9620b347f3 --- /dev/null +++ 
b/tensorflow/core/tpu/kernels/sharding_utils.cc @@ -0,0 +1,237 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/tpu/kernels/sharding_utils.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" // IWYU pragma: keep +#include "tsl/platform/macros.h" + +namespace tensorflow { +namespace sharding_internal { +absl::Status ValidateShapesForSlice(absl::string_view input_name, + const Tensor* input, + const std::vector& num_splits, + const std::vector& paddings) { + const auto& ishape = input->shape(); + + Status s; + + const int rank = ishape.dims(); + const auto& input_shape = ishape.dim_sizes(); + if (rank <= 0 || rank > 8) { + s = absl::InvalidArgumentError(absl::StrCat( + input_name, " must have rank in range (0, 8], but got ", rank, ".")); + } else if (rank != num_splits.size()) { + s = absl::InvalidArgumentError(absl::StrCat( + input_name, " rank must be the same as 'num_splits' length ", + num_splits.size(), ", but got rank ", rank, ".")); + } else { + for (int dim = 0; dim < rank; ++dim) { + const auto input_shape_dim = input_shape[dim]; + const auto paddings_dim = paddings[dim]; + const auto num_splits_dim = num_splits[dim]; + if ((input_shape_dim + paddings_dim) % num_splits_dim != 0) { + s = absl::InvalidArgumentError(absl::StrCat( + input_name, " shape dimension ", dim, " (", input_shape_dim, + ") with padding ", paddings_dim, + " must be evenly divisible by 'num_splits' ", num_splits_dim, ".")); + break; + } + } + } + return s; +} + +} // namespace sharding_internal + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[0] = index * slice_shape[0]; + return subscript; +} + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[1] = (index % num_partitions[1]) * slice_shape[1]; + subscript[0] = (index / num_partitions[1]) * slice_shape[0]; + return subscript; +} + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[2] = (index % num_partitions[2]) * slice_shape[2]; + subscript[1] = + ((index / num_partitions[2]) % num_partitions[1]) * slice_shape[1]; + subscript[0] = + (index / (num_partitions[2] * num_partitions[1])) * slice_shape[0]; + return subscript; +} 
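For reference, the GetSliceIndices specializations above (and the higher-rank ones that follow) all implement the same row-major decomposition of a flat slice index into per-dimension start offsets. A minimal standalone sketch of that arithmetic, with illustrative names that are not part of the TensorFlow sources:

#include <cstdint>
#include <vector>

// Decompose a flat slice index into per-dimension start offsets
// (each offset is the partition subscript scaled by the slice extent).
std::vector<int64_t> SliceStartIndices(const std::vector<int64_t>& num_partitions,
                                       const std::vector<int64_t>& slice_shape,
                                       int64_t index) {
  std::vector<int64_t> start(num_partitions.size(), 0);
  int64_t stride = 1;  // product of partition counts over trailing dimensions
  for (int dim = static_cast<int>(num_partitions.size()) - 1; dim >= 0; --dim) {
    start[dim] = ((index / stride) % num_partitions[dim]) * slice_shape[dim];
    stride *= num_partitions[dim];
  }
  return start;
}

// Example: num_partitions = {2, 2}, slice_shape = {2, 2}, index = 3
// gives start offsets {2, 2}, matching the rank-2 specialization above.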
+ +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[3] = (index % num_partitions[3]) * slice_shape[3]; + subscript[2] = + ((index / num_partitions[3]) % num_partitions[2]) * slice_shape[2]; + subscript[1] = + ((index / (num_partitions[3] * num_partitions[2])) % num_partitions[1]) * + slice_shape[1]; + subscript[0] = + (index / (num_partitions[3] * num_partitions[2] * num_partitions[1])) * + slice_shape[0]; + return subscript; +} + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[4] = (index % num_partitions[4]) * slice_shape[4]; + subscript[3] = + ((index / num_partitions[4]) % num_partitions[3]) * slice_shape[3]; + subscript[2] = + ((index / (num_partitions[4] * num_partitions[3])) % num_partitions[2]) * + slice_shape[2]; + subscript[1] = + ((index / (num_partitions[4] * num_partitions[3] * num_partitions[2])) % + num_partitions[1]) * + slice_shape[1]; + subscript[0] = (index / (num_partitions[4] * num_partitions[3] * + num_partitions[2] * num_partitions[1])) * + slice_shape[0]; + return subscript; +} + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[5] = (index % num_partitions[5]) * slice_shape[5]; + subscript[4] = + ((index / num_partitions[5]) % num_partitions[4]) * slice_shape[4]; + subscript[3] = + ((index / (num_partitions[5] * num_partitions[4])) % num_partitions[3]) * + slice_shape[3]; + subscript[2] = + ((index / (num_partitions[5] * num_partitions[4] * num_partitions[3])) % + num_partitions[2]) * + slice_shape[2]; + subscript[1] = ((index / (num_partitions[5] * num_partitions[4] * + num_partitions[3] * num_partitions[2])) % + num_partitions[1]) * + slice_shape[1]; + subscript[0] = + (index / (num_partitions[5] * num_partitions[4] * num_partitions[3] * + num_partitions[2] * num_partitions[1])) * + slice_shape[0]; + return subscript; +} + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[6] = (index % num_partitions[6]) * slice_shape[6]; + subscript[5] = + ((index / num_partitions[6]) % num_partitions[5]) * slice_shape[5]; + subscript[4] = + ((index / (num_partitions[6] * num_partitions[5])) % num_partitions[4]) * + slice_shape[4]; + subscript[3] = + ((index / (num_partitions[6] * num_partitions[5] * num_partitions[4])) % + num_partitions[3]) * + slice_shape[3]; + subscript[2] = ((index / (num_partitions[6] * num_partitions[5] * + num_partitions[4] * num_partitions[3])) % + num_partitions[2]) * + slice_shape[2]; + subscript[1] = + ((index / (num_partitions[6] * num_partitions[5] * num_partitions[4] * + num_partitions[3] * num_partitions[2])) % + num_partitions[1]) * + slice_shape[1]; + subscript[0] = + (index / (num_partitions[6] * num_partitions[5] * num_partitions[4] * + num_partitions[3] * num_partitions[2] * num_partitions[1])) * + slice_shape[0]; + return subscript; +} + +template <> +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, const int index) { + Eigen::DSizes subscript; + subscript[7] = (index % num_partitions[7]) * slice_shape[7]; + subscript[6] = + ((index / num_partitions[7]) % num_partitions[6]) * slice_shape[6]; + subscript[5] = + ((index / 
(num_partitions[7] * num_partitions[6])) % num_partitions[5]) * + slice_shape[5]; + subscript[4] = + ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5])) % + num_partitions[4]) * + slice_shape[4]; + subscript[3] = ((index / (num_partitions[7] * num_partitions[6] * + num_partitions[5] * num_partitions[4])) % + num_partitions[3]) * + slice_shape[3]; + subscript[2] = + ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5] * + num_partitions[4] * num_partitions[3])) % + num_partitions[2]) * + slice_shape[2]; + subscript[1] = + ((index / (num_partitions[7] * num_partitions[6] * num_partitions[5] * + num_partitions[4] * num_partitions[3] * num_partitions[2])) % + num_partitions[1]) * + slice_shape[1]; + subscript[0] = + (index / (num_partitions[7] * num_partitions[6] * num_partitions[5] * + num_partitions[4] * num_partitions[3] * num_partitions[2] * + num_partitions[1])) * + slice_shape[0]; + return subscript; +} + +} // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/sharding_utils.h b/tensorflow/core/tpu/kernels/sharding_utils.h new file mode 100644 index 00000000000000..429e327462ad74 --- /dev/null +++ b/tensorflow/core/tpu/kernels/sharding_utils.h @@ -0,0 +1,456 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_SHARDING_UTILS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_SHARDING_UTILS_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/macros.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace sharding_internal { +absl::Status ValidateShapesForSlice(absl::string_view input_name, + const Tensor* input, + const std::vector& num_splits, + const std::vector& paddings); +template +Eigen::DSizes TF_ATTRIBUTE_NOINLINE +ShapeAsEigenDSizes(const TensorShape& shape); +template +Eigen::DSizes ShapeAsEigenDSizes( + const TensorShape& shape) { + return shape.AsEigenDSizes(); +} + +} // namespace sharding_internal + +// Converts flatten index to start indices (subscript scaled with slice shape) +// for determining where to start a slice in the input tensor. 
+template +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); + +template +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, + const int index) { + return Eigen::DSizes(); +} + +// Shared base class to save code space +template +class XlaNDSplitter { + public: + static absl::StatusOr> Create( + const std::vector& num_splits, int num_slices, + const std::vector& paddings, bool has_paddings) { + if (num_splits.size() != paddings.size()) { + return absl::InvalidArgumentError( + absl::StrCat("num_splits size ", num_splits.size(), + " mismatch with paddings size ", paddings.size(), ".")); + } + + int splits_cnt = 1; + for (auto split : num_splits) { + splits_cnt *= split; + } + + if (num_slices != splits_cnt) { + return absl::InvalidArgumentError(absl::StrCat( + "Expect num_slices ", splits_cnt, " but got ", num_slices)); + } + + return XlaNDSplitter(num_splits, num_slices, paddings, + has_paddings); + } + + // Split the given input. + // + // The splitted outputs are stored into tensors allocated by + // `allocate_output_fn`. In the simple case of pass through (no split and no + // padding), the output is stored through the fast path by + // `assign_or_copy_value_fn`. + absl::Status Split( + const Tensor* input, absl::string_view input_name, + const std::function& assign_or_copy_value_fn, + const std::function& allocate_output_fn, + const Device& device) { + if (num_splits_.size() != paddings_.size()) { + return absl::InvalidArgumentError( + absl::StrCat("num_splits size ", num_splits_.size(), + " mismatch with paddings size ", paddings_.size(), ".")); + } + + const int rank = input->shape().dims(); + const auto& input_shape = input->shape().dim_sizes(); + + TF_RETURN_IF_ERROR(sharding_internal::ValidateShapesForSlice( + input_name, input, num_splits_, paddings_)); + + TensorShape output_slice_shape; + for (int i = 0; i < rank; ++i) { + output_slice_shape.AddDim((input_shape[i] + paddings_[i]) / + ((num_slices_ == 1) ? 
1 : num_splits_[i])); + } + if (num_slices_ == 1 && !has_paddings_) { + // Handle simple case first + TF_RETURN_IF_ERROR(assign_or_copy_value_fn(*input)); + } else { + std::vector output_slices(num_slices_); + for (int i = 0; i < num_slices_; i++) { + TF_RETURN_IF_ERROR(allocate_output_fn( + /*index=*/i, output_slice_shape, &output_slices[i])); + } + + if (rank == 1) { + SliceAndMaybePad<1>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 2) { + SliceAndMaybePad<2>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 3) { + SliceAndMaybePad<3>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 4) { + SliceAndMaybePad<4>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 5) { + SliceAndMaybePad<5>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 6) { + SliceAndMaybePad<6>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 7) { + SliceAndMaybePad<7>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 8) { + SliceAndMaybePad<8>(device, input, input_shape, output_slice_shape, + output_slices); + } + } + return absl::OkStatus(); + } + + private: + template + class SliceAndMaybePadState { + public: + int num_complete_pad_dims_; + int num_partial_pad_dims_; + TensorShape non_padded_slice_shape_; + Eigen::array, Rank> slice_paddings_; + Eigen::DSizes slice_indices_; + Eigen::DSizes output_slice_shape_dsizes_; + Eigen::DSizes non_padded_slice_shape_dsizes_; + + TF_ATTRIBUTE_NOINLINE SliceAndMaybePadState( + absl::Span num_splits, + const absl::Span input_shape, + const TensorShape& output_slice_shape, int slice_index) { + output_slice_shape_dsizes_ = + sharding_internal::ShapeAsEigenDSizes(output_slice_shape); + num_complete_pad_dims_ = 0; + num_partial_pad_dims_ = 0; + slice_indices_ = GetSliceIndices( + num_splits, output_slice_shape_dsizes_, slice_index); + + // Calculate paddings necessary for slice instead of padding input and + // slicing subsequently to reduce temporary memory allocation. + for (int dim = 0; dim < Rank; ++dim) { + const int64_t dim_size = input_shape[dim]; + const int64_t out_dim = output_slice_shape_dsizes_[dim]; + int64_t non_padded_dim = 0; + if (slice_indices_[dim] >= dim_size) { + // Complete padding. + slice_indices_[dim] = dim_size; + non_padded_dim = 0; + slice_paddings_[dim] = {0, out_dim}; + num_complete_pad_dims_++; + } else if (slice_indices_[dim] + out_dim > dim_size) { + // Partial padding. 
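+          // The slice starts inside the input but extends past its end along
+          // this dimension, so only the in-bounds prefix (non_padded_dim
+          // elements) is copied and the rest of the output slice is padded.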
+ non_padded_dim = dim_size - slice_indices_[dim]; + slice_paddings_[dim] = {0, out_dim - non_padded_dim}; + num_partial_pad_dims_++; + } else { + non_padded_dim = out_dim; + } + non_padded_slice_shape_.AddDim(non_padded_dim); + } + non_padded_slice_shape_dsizes_ = + sharding_internal::ShapeAsEigenDSizes(non_padded_slice_shape_); + } + }; + + std::vector num_splits_; + int num_slices_; + std::vector paddings_; + bool has_paddings_; + + explicit XlaNDSplitter(const std::vector& num_splits, int num_slices, + const std::vector& paddings, + bool has_paddings) + : num_splits_(num_splits), + num_slices_(num_slices), + paddings_(paddings), + has_paddings_(has_paddings) {} + + void TF_ATTRIBUTE_NOINLINE SetToConstant(Tensor* output_slice, + const Device& device) { + auto output_flat = output_slice->flat(); + output_flat.device(device) = output_flat.constant(T()); + } + + template + void TF_ATTRIBUTE_NOINLINE AssignFromInput( + Tensor* output_slice, const Device& device, const Tensor* input, + const Eigen::DSizes& slice_indices, + const Eigen::DSizes& output_slice_shape_dsizes) { + output_slice->tensor().device(device) = + input->tensor().slice(slice_indices, + output_slice_shape_dsizes); + } + + template + void TF_ATTRIBUTE_NOINLINE + SliceAndMaybePad(const Device& device, const Tensor* input, + const absl::Span input_shape, + const TensorShape& output_slice_shape, + const std::vector& output_slices) { + const auto& input_tensor = input->tensor(); + // Slice shape with optional padding. + for (int i = 0; i < num_slices_; ++i) { + Tensor* output_slice = output_slices[i]; + SliceAndMaybePadState r(num_splits_, input_shape, + output_slice_shape, i); + if (r.num_complete_pad_dims_ == Rank || + (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0)) { + // Need to init padding + SetToConstant(output_slice, device); + } + if (r.num_complete_pad_dims_ == Rank) { + // Done + } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) { + output_slice->tensor() + .slice(Eigen::DSizes(), + r.non_padded_slice_shape_dsizes_) + .device(device) = input_tensor.slice( + r.slice_indices_, r.non_padded_slice_shape_dsizes_); + } else { + AssignFromInput(output_slice, device, input, r.slice_indices_, + r.output_slice_shape_dsizes_); + } + } + } +}; + +// Shared base class to save code space +template +class XlaNDConcatenator { + public: + static absl::StatusOr> Create( + const std::vector& num_concats, int num_slices, + const std::vector& paddings, bool has_paddings) { + if (num_concats.size() != paddings.size()) { + return absl::InvalidArgumentError( + absl::StrCat("num_concats size ", num_concats.size(), + " mismatch with paddings size ", paddings.size(), ".")); + } + + int concats_cnt = 1; + for (auto concat : num_concats) { + concats_cnt *= concat; + } + + if (num_slices != concats_cnt) { + return absl::InvalidArgumentError(absl::StrCat( + "Expect num_slices ", concats_cnt, " but got ", num_slices)); + } + + return XlaNDConcatenator(num_concats, num_slices, paddings, + has_paddings); + } + absl::Status ComputeInternal( + absl::Span inputs, + const std::function& assign_or_copy_value_fn, + const std::function()>& get_output_fn, + const Device& device) { + const int rank = inputs[0].shape().dims(); + + if (rank < 1 || rank > 8) { + return absl::InvalidArgumentError(absl::StrCat( + "'inputs' tensors must have rank in range (0, 8], but got ", rank, + ".")); + } + + if (num_slices_ == 1 && !has_paddings_) { + // Simple case + return assign_or_copy_value_fn(inputs[0]); + } + + 
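+    // General case: materialize the full output tensor once, then copy each
+    // input slice into place at its start offsets, stripping padding as needed.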
TF_ASSIGN_OR_RETURN(Tensor * output, get_output_fn()); + + if (rank == 1) { + MaybeUnpadAndAssign<1>(device, inputs, output); + } else if (rank == 2) { + MaybeUnpadAndAssign<2>(device, inputs, output); + } else if (rank == 3) { + MaybeUnpadAndAssign<3>(device, inputs, output); + } else if (rank == 4) { + MaybeUnpadAndAssign<4>(device, inputs, output); + } else if (rank == 5) { + MaybeUnpadAndAssign<5>(device, inputs, output); + } else if (rank == 6) { + MaybeUnpadAndAssign<6>(device, inputs, output); + } else if (rank == 7) { + MaybeUnpadAndAssign<7>(device, inputs, output); + } else if (rank == 8) { + MaybeUnpadAndAssign<8>(device, inputs, output); + } + return absl::OkStatus(); + } + + private: + template + class MaybeUnpadAndAssignState { + public: + int num_complete_pad_dims_; + int num_partial_pad_dims_; + TensorShape non_padded_slice_shape_; + Eigen::DSizes slice_shape_dsizes_; + Eigen::array, Rank> slice_paddings_; + Eigen::DSizes slice_indices_; + Eigen::DSizes output_slice_shape_dsizes_; + Eigen::DSizes non_padded_slice_shape_dsizes_; + + TF_ATTRIBUTE_NOINLINE MaybeUnpadAndAssignState( + absl::Span num_concats, const Tensor& input0, + Tensor* output, int slice_index) { + slice_shape_dsizes_ = input0.shape().AsEigenDSizes(); + slice_indices_ = + GetSliceIndices(num_concats, slice_shape_dsizes_, slice_index); + num_complete_pad_dims_ = 0; + num_partial_pad_dims_ = 0; + // Calculate paddings necessary to strip from slice. + for (int dim = 0; dim < Rank; ++dim) { + const int64_t dim_size = output->shape().dim_size(dim); + int64_t non_padded_dim = 0; + if (slice_indices_[dim] >= dim_size) { + // Complete padding. + slice_indices_[dim] = dim_size; + non_padded_dim = 0; + num_complete_pad_dims_++; + } else if (slice_indices_[dim] + slice_shape_dsizes_[dim] > dim_size) { + // Partial padding. + non_padded_dim = dim_size - slice_indices_[dim]; + num_partial_pad_dims_++; + } else { + non_padded_dim = slice_shape_dsizes_[dim]; + } + non_padded_slice_shape_.AddDim(non_padded_dim); + } + non_padded_slice_shape_dsizes_ = + non_padded_slice_shape_.AsEigenDSizes(); + } + }; + + std::vector num_concats_; + int num_slices_; + std::vector paddings_; + bool has_paddings_; + + explicit TF_ATTRIBUTE_NOINLINE XlaNDConcatenator( + const std::vector& num_concats, int num_slices, + const std::vector& paddings, bool has_paddings) + : num_concats_(num_concats), + num_slices_(num_slices), + paddings_(paddings), + has_paddings_(has_paddings) {} + + template + void TF_ATTRIBUTE_NOINLINE MaybeUnpadAndAssign(const Device& device, + absl::Span inputs, + Tensor* output) { + for (int i = 0; i < num_slices_; ++i) { + MaybeUnpadAndAssignState r(num_concats_, inputs[0], output, i); + if (r.num_complete_pad_dims_ == Rank) { + continue; + } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) { + output->tensor() + .slice(r.slice_indices_, r.non_padded_slice_shape_dsizes_) + .device(device) = inputs[i].tensor().slice( + Eigen::DSizes(), + r.non_padded_slice_shape_dsizes_); + } else { + output->tensor() + .slice(r.slice_indices_, r.slice_shape_dsizes_) + .device(device) = inputs[i].tensor(); + } + } + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_SHARDING_UTILS_H_ diff --git a/tensorflow/core/tpu/kernels/sharding_utils_test.cc b/tensorflow/core/tpu/kernels/sharding_utils_test.cc new file mode 100644 index 00000000000000..cd583df8a57bef --- /dev/null +++ b/tensorflow/core/tpu/kernels/sharding_utils_test.cc @@ -0,0 +1,456 @@ +/* Copyright 2023 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/tpu/kernels/sharding_utils.h" + +#include +#include +#include + +#include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/platform/threadpool.h" + +namespace tensorflow { +namespace { +Eigen::ThreadPoolDevice CreateThreadPoolDevice() { + constexpr int kMaxParallelism = 16; + auto thread_pool = std::make_unique( + tsl::Env::Default(), tsl::ThreadOptions(), "Resharding", kMaxParallelism); + + Eigen::ThreadPoolDevice device(thread_pool->AsEigenThreadPool(), + kMaxParallelism); + return device; +} + +TEST(XlaNDSplitterTest, NoSplits) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 2, 2}); + const std::vector num_splits = {1, 1, 1}; + const std::vector paddings(num_splits.size(), 0); + const int num_outputs = 1; + auto input_tensor = + test::AsTensor({0, 1, 2, 3, 4, 5, 6, 7}, input_shape); + + std::vector output_tensors; + output_tensors.resize(num_outputs); + auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, + Tensor** tensor) { + if (i < 0 || i >= output_tensors.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Index ", i, " out of range [0, ", output_tensors.size(), "]")); + } + output_tensors[i] = Tensor(tensorflow::DT_INT32, output_slice_shape); + *tensor = &output_tensors[i]; + return absl::OkStatus(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors[0] = input; + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto splitter, (XlaNDSplitter::Create( + num_splits, num_outputs, paddings, + /*has_paddings=*/false))); + TF_ASSERT_OK(splitter.Split(&input_tensor, "test", assign_or_copy_value_fn, + allocate_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), 1); + test::ExpectTensorEqual( + output_tensors[0], test::AsTensor({0, 1, 2, 3, 4, 5, 6, 7}, + TensorShape({2, 2, 2}))); +} + +TEST(XlaNDSplitterTest, NoSplitsWithPadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 1, 1}); + const std::vector num_splits = {1, 1, 1}; + const std::vector paddings = {0, 1, 1}; + const int num_outputs = 1; + auto input_tensor = test::AsTensor({0, 1}, input_shape); + + std::vector output_tensors; + output_tensors.resize(num_outputs); + auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, + Tensor** tensor) { + if (i < 0 || i >= output_tensors.size()) { + return 
absl::InvalidArgumentError(absl::StrCat( + "Index ", i, " out of range [0, ", output_tensors.size(), "]")); + } + output_tensors[i] = Tensor(tensorflow::DT_INT32, output_slice_shape); + *tensor = &output_tensors[i]; + return absl::OkStatus(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors[0] = input; + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto splitter, (XlaNDSplitter::Create( + num_splits, num_outputs, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(splitter.Split(&input_tensor, "test", assign_or_copy_value_fn, + allocate_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), 1); + std::vector expected_values(3 * 3 * 3); + test::ExpectTensorEqual( + output_tensors[0], test::AsTensor({0, 0, 0, 0, 1, 0, 0, 0}, + TensorShape({2, 2, 2}))); +} + +TEST(XlaNDSplitterTest, SplitNoPadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({4, 4}); + const std::vector num_splits = {2, 2}; + const std::vector paddings(num_splits.size(), 0); + const int num_outputs = 4; + auto input_tensor = test::AsTensor( + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, input_shape); + + std::vector output_tensors; + output_tensors.resize(num_outputs); + auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, + Tensor** tensor) { + if (i < 0 || i >= output_tensors.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Index ", i, " out of range [0, ", output_tensors.size(), "]")); + } + output_tensors[i] = Tensor(tensorflow::DT_INT32, output_slice_shape); + *tensor = &output_tensors[i]; + return absl::OkStatus(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors[0] = input; + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto splitter, (XlaNDSplitter::Create( + num_splits, num_outputs, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(splitter.Split(&input_tensor, "test", assign_or_copy_value_fn, + allocate_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), num_outputs); + test::ExpectTensorEqual( + output_tensors[0], + test::AsTensor({0, 1, 4, 5}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[1], + test::AsTensor({2, 3, 6, 7}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[2], + test::AsTensor({8, 9, 12, 13}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[3], + test::AsTensor({10, 11, 14, 15}, TensorShape({2, 2}))); +} + +TEST(XlaNDSplitterTest, SplitPartialPadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({3, 3}); + const std::vector num_splits = {2, 2}; + const std::vector paddings = {1, 1}; + const int num_outputs = 4; + auto input_tensor = + test::AsTensor({0, 1, 2, 3, 4, 5, 6, 7, 8}, input_shape); + + std::vector output_tensors; + output_tensors.resize(num_outputs); + auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, + Tensor** tensor) { + if (i < 0 || i >= output_tensors.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Index ", i, " out of range [0, ", output_tensors.size(), "]")); + } + output_tensors[i] = Tensor(tensorflow::DT_INT32, output_slice_shape); + *tensor = &output_tensors[i]; + return absl::OkStatus(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors[0] = input; + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto splitter, (XlaNDSplitter::Create( + num_splits, num_outputs, paddings, + 
/*has_paddings=*/true))); + + TF_ASSERT_OK(splitter.Split(&input_tensor, "test", assign_or_copy_value_fn, + allocate_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), num_outputs); + test::ExpectTensorEqual( + output_tensors[0], + test::AsTensor({0, 1, 3, 4}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[1], + test::AsTensor({2, 0, 5, 0}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[2], + test::AsTensor({6, 7, 0, 0}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[3], + test::AsTensor({8, 0, 0, 0}, TensorShape({2, 2}))); +} + +TEST(XlaNDSplitterTest, SplitCompletePadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 1}); + const std::vector num_splits = {2, 2}; + const std::vector paddings = {2, 3}; + const int num_outputs = 4; + auto input_tensor = test::AsTensor({0, 1}, input_shape); + + std::vector output_tensors; + output_tensors.resize(num_outputs); + auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, + Tensor** tensor) { + if (i < 0 || i >= output_tensors.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Index ", i, " out of range [0, ", output_tensors.size(), "]")); + } + output_tensors[i] = Tensor(tensorflow::DT_INT32, output_slice_shape); + *tensor = &output_tensors[i]; + return absl::OkStatus(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors[0] = input; + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto splitter, (XlaNDSplitter::Create( + num_splits, num_outputs, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(splitter.Split(&input_tensor, "test", assign_or_copy_value_fn, + allocate_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), num_outputs); + test::ExpectTensorEqual( + output_tensors[0], + test::AsTensor({0, 0, 1, 0}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[1], + test::AsTensor({0, 0, 0, 0}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[2], + test::AsTensor({0, 0, 0, 0}, TensorShape({2, 2}))); + test::ExpectTensorEqual( + output_tensors[3], + test::AsTensor({0, 0, 0, 0}, TensorShape({2, 2}))); +} + +TEST(XlaNDConcatenatorTest, NoConcats) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 2, 2}); + const TensorShape output_shape({2, 2, 2}); + const std::vector num_concats = {1, 1, 1}; + const std::vector paddings(num_concats.size(), 0); + int num_slices = 1; + auto tensor0 = test::AsTensor({0, 1, 2, 3, 4, 5, 6, 7}, input_shape); + std::vector input_tensors; + input_tensors.push_back(tensor0); + + std::vector output_tensors; + output_tensors.reserve(1); + auto get_output_fn = [&]() { + output_tensors.push_back(Tensor(tensorflow::DT_INT32, output_shape)); + return &output_tensors.back(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors.push_back(input); + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto concatenator, + (XlaNDConcatenator::Create( + num_concats, num_slices, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(concatenator.ComputeInternal(absl::MakeSpan(input_tensors), + assign_or_copy_value_fn, + get_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), 1); + test::ExpectTensorEqual( + output_tensors[0], test::AsTensor({0, 1, 2, 3, 4, 5, 6, 7}, + TensorShape({2, 2, 2}))); +} + +TEST(XlaNDConcatenatorTest, ConcatNoPadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 
2}); + const TensorShape output_shape({4, 4}); + const std::vector num_concats = {2, 2}; + const std::vector paddings(num_concats.size(), 0); + int num_slices = 4; + auto tensor0 = test::AsTensor({0, 1, 2, 3}, input_shape); + auto tensor1 = test::AsTensor({4, 5, 6, 7}, input_shape); + auto tensor2 = test::AsTensor({8, 9, 10, 11}, input_shape); + auto tensor3 = test::AsTensor({12, 13, 14, 15}, input_shape); + std::vector input_tensors; + input_tensors.push_back(tensor0); + input_tensors.push_back(tensor1); + input_tensors.push_back(tensor2); + input_tensors.push_back(tensor3); + + std::vector output_tensors; + output_tensors.reserve(1); + auto get_output_fn = [&]() { + output_tensors.push_back(Tensor(tensorflow::DT_INT32, output_shape)); + return &output_tensors.back(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors.push_back(input); + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto concatenator, + (XlaNDConcatenator::Create( + num_concats, num_slices, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(concatenator.ComputeInternal(absl::MakeSpan(input_tensors), + assign_or_copy_value_fn, + get_output_fn, device)); + ASSERT_EQ(output_tensors.size(), 1); + test::ExpectTensorEqual( + output_tensors[0], test::AsTensor({0, 1, 4, 5, 2, 3, 6, 7, 8, 9, + 12, 13, 10, 11, 14, 15}, + TensorShape({4, 4}))); +} + +TEST(XlaNDConcatenatorTest, ConcatPartialPadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 2}); + const TensorShape output_shape({3, 3}); + const std::vector num_concats = {2, 2}; + const std::vector paddings = {1, 1}; + int num_slices = 4; + auto tensor0 = test::AsTensor({0, 1, 2, 3}, input_shape); + auto tensor1 = test::AsTensor({4, 5, 6, 7}, input_shape); + auto tensor2 = test::AsTensor({8, 9, 10, 11}, input_shape); + auto tensor3 = test::AsTensor({12, 13, 14, 15}, input_shape); + std::vector input_tensors; + input_tensors.push_back(tensor0); + input_tensors.push_back(tensor1); + input_tensors.push_back(tensor2); + input_tensors.push_back(tensor3); + + std::vector output_tensors; + output_tensors.reserve(1); + auto get_output_fn = [&]() { + output_tensors.push_back(Tensor(tensorflow::DT_INT32, output_shape)); + return &output_tensors.back(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors.push_back(input); + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto concatenator, + (XlaNDConcatenator::Create( + num_concats, num_slices, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(concatenator.ComputeInternal(absl::MakeSpan(input_tensors), + assign_or_copy_value_fn, + get_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), 1); + test::ExpectTensorEqual( + output_tensors[0], test::AsTensor({0, 1, 4, 2, 3, 6, 8, 9, 12}, + TensorShape({3, 3}))); +} + +TEST(XlaNDConcatenatorTest, ConcatCompletePadding) { + auto device = CreateThreadPoolDevice(); + + const TensorShape input_shape({2, 2}); + const TensorShape output_shape({2, 2}); + const std::vector num_concats = {2, 2}; + const std::vector paddings = {2, 2}; + int num_slices = 4; + auto tensor0 = test::AsTensor({0, 1, 2, 3}, input_shape); + auto tensor1 = test::AsTensor({4, 5, 6, 7}, input_shape); + auto tensor2 = test::AsTensor({8, 9, 10, 11}, input_shape); + auto tensor3 = test::AsTensor({12, 13, 14, 15}, input_shape); + std::vector input_tensors; + input_tensors.push_back(tensor0); + input_tensors.push_back(tensor1); + input_tensors.push_back(tensor2); + 
input_tensors.push_back(tensor3); + + std::vector output_tensors; + output_tensors.reserve(1); + auto get_output_fn = [&]() { + output_tensors.push_back(Tensor(tensorflow::DT_INT32, output_shape)); + return &output_tensors.back(); + }; + auto assign_or_copy_value_fn = [&](const Tensor& input) -> Status { + output_tensors.push_back(input); + return absl::OkStatus(); + }; + + TF_ASSERT_OK_AND_ASSIGN( + auto concatenator, + (XlaNDConcatenator::Create( + num_concats, num_slices, paddings, + /*has_paddings=*/true))); + + TF_ASSERT_OK(concatenator.ComputeInternal(absl::MakeSpan(input_tensors), + assign_or_copy_value_fn, + get_output_fn, device)); + + ASSERT_EQ(output_tensors.size(), 1); + test::ExpectTensorEqual( + output_tensors[0], + test::AsTensor({0, 1, 2, 3}, TensorShape({2, 2}))); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc index 54feed5d5fffe2..53a0c70779534d 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc @@ -357,6 +357,8 @@ void GetMinibatchesInCsrWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) { const int64* splits_tensor_ptr = splits->flat().data(); const int32* id_counts_tensor_ptr = id_counts->flat().data(); + const int32_t total_id_count = row_ids->NumElements(); + const int num_physical_replica = num_replica_ * num_sc_per_chip_; size_t xla_pad_size = stream_executor::tpu::OpsApiFn() @@ -405,6 +407,12 @@ void GetMinibatchesInCsrWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) { const int32 max_ids_per_chip = max_ids_per_chip_per_sample_ * sample_count_; + OP_REQUIRES( + ctx, max_ids_per_chip % xla_pad_size == 0, + absl::InvalidArgumentError(absl::StrCat( + "The max_ids_per_chip is set to be ", max_ids_per_chip, + " which is not divisible by the xla_pad_size ", xla_pad_size, " ."))); + const int32 padded_row_pointers_size_per_sc = xla::RoundUpTo(num_physical_replica, xla_pad_size); @@ -435,6 +443,11 @@ void GetMinibatchesInCsrWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) { sorted_token_ids_tensor->flat().data(); float* sorted_gains_tensor_ptr = sorted_gains_tensor->flat().data(); + // This packed id count is used to track how many ids we have packed into + // the output tensor and based on this we would know how many ids that we + // dropped. 
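+  // Once the packing loop finishes, total_id_count - packed_id_count is the
+  // number of ids that could not be packed (i.e. were dropped); that value is
+  // what the warning emitted after the loop reports.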
+ int32_t packed_id_count = 0; + int32 global_index = 0; int32 row_pointers_index = 0; for (int sc_id = 0; sc_id < num_sc_per_chip_; ++sc_id) { @@ -453,14 +466,41 @@ void GetMinibatchesInCsrWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) { const int token_id_start_pos = *(id_counts_tensor_ptr + start_division_pos); - std::copy_n(col_ids_tensor_ptr + token_id_start_pos, token_id_count, - sorted_token_ids_tensor_ptr + global_index); - std::copy_n(row_ids_tensor_ptr + token_id_start_pos, token_id_count, - sorted_sample_ids_tensor_ptr + global_index); - std::copy_n(gains_tensor_ptr + token_id_start_pos, token_id_count, - sorted_gains_tensor_ptr + global_index); - - global_index += token_id_count; + if (global_index + token_id_count > max_ids_per_chip) { + if (allow_id_dropping_for_minibatching_) { + const int32_t copy_id_count = + std::min(max_ids_per_chip - global_index, token_id_count); + std::copy_n(col_ids_tensor_ptr + token_id_start_pos, copy_id_count, + sorted_token_ids_tensor_ptr + global_index); + std::copy_n(row_ids_tensor_ptr + token_id_start_pos, copy_id_count, + sorted_sample_ids_tensor_ptr + global_index); + std::copy_n(gains_tensor_ptr + token_id_start_pos, copy_id_count, + sorted_gains_tensor_ptr + global_index); + packed_id_count += copy_id_count; + global_index = max_ids_per_chip; + } else { + const int32_t remain_id_count = total_id_count - packed_id_count; + ctx->CtxFailure(absl::InvalidArgumentError(absl::StrCat( + "The max_ids_per_chip is set to be ", max_ids_per_chip, + " which is not going to fit all ids. The remaining id count " + "is ", + remain_id_count, + " . Please consider setting the " + "sparse_core_allow_id_dropping_for_minibatching to be " + "true. "))); + return; + } + } else { + std::copy_n(col_ids_tensor_ptr + token_id_start_pos, token_id_count, + sorted_token_ids_tensor_ptr + global_index); + std::copy_n(row_ids_tensor_ptr + token_id_start_pos, token_id_count, + sorted_sample_ids_tensor_ptr + global_index); + std::copy_n(gains_tensor_ptr + token_id_start_pos, token_id_count, + sorted_gains_tensor_ptr + global_index); + + global_index += token_id_count; + packed_id_count += token_id_count; + } *(row_pointers_tensor_ptr + row_pointers_index) = global_index; int32 num_ids_to_pad_per_replica = @@ -484,13 +524,16 @@ void GetMinibatchesInCsrWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) { } } - int32 ids_unpadded_size = global_index; + int32_t ids_unpadded_size = global_index; - OP_REQUIRES(ctx, ids_unpadded_size <= max_ids_per_chip, - absl::InvalidArgumentError(absl::StrCat( - "Got ", ids_unpadded_size, - " ids after padding but the max_ids_per_chip is set to be ", - max_ids_per_chip, " which is smaller."))); + if (packed_id_count < total_id_count) { + const int32_t dropped_id_count = total_id_count - packed_id_count; + LOG(WARNING) << "Dropping " << dropped_id_count + << " ids so that the produced CsrWrappedCooTensor can be fit " + "in static bound of " + << max_ids_per_chip + << " . 
This could potentially impact the model quality."; + } int32 row_pointers_unpadded_size = total_num_minibatch * padded_row_pointers_size_per_sc; @@ -923,7 +966,8 @@ void GetMinibatchSplitsWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) { table_name_); CalculateHeadroom(this_max_ids, this_max_uniques, program_key, - max_ids_per_partition, max_unique_ids_per_partition); + max_ids_per_partition, max_unique_ids_per_partition, + dropped_id_count); Tensor* splits_tensor; OP_REQUIRES_OK( diff --git a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h index b61367d2cb0796..f2d35b3fa76cd6 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h +++ b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h @@ -73,6 +73,8 @@ class GetMinibatchesInCsrWithPhysicalReplicaOp : public OpKernel { std::string table_name_; std::unique_ptr sparse_core_ops_stats_handler_; + bool allow_id_dropping_for_minibatching_ = false; + private: int num_replica_ = 1; int max_minibatches_per_sc_ = 1; @@ -96,7 +98,8 @@ class GetMinibatchSplitsWithPhysicalReplicaOp : public OpKernel { virtual void CalculateHeadroom(int32 this_max_ids, int32 this_max_uniques, tstring program_key, int64_t max_ids_per_partition, - int64_t max_unique_ids_per_partition) {} + int64_t max_unique_ids_per_partition, + int32_t dropped_id_count) {} virtual inline int32_t CalculateBucketIdWithHashing(int32_t col_id, int32_t num_buckets) { // TODO(pineapplejuice233): Add a proper hashing function here. diff --git a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc index 81d56802f1c77b..3d5e8642b2e58e 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc @@ -33,7 +33,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/tpu/c_api_decl.h" -#include "xla/stream_executor/tpu/status_helper.h" #include "xla/stream_executor/tpu/tpu_api.h" #include "xla/stream_executor/tpu/tpu_ops_c_api.h" #include "xla/xla_data.pb.h" @@ -41,12 +40,21 @@ limitations under the License. #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/tpu/kernels/sparse_core_ops_utils.h" #include "tsl/platform/macros.h" +typedef tensorflow::monitoring::Gauge TFGaugeMetric; +static TFGaugeMetric* max_ids_per_partition_gauge_ = TFGaugeMetric::New( + "/tensorflow/tpu/embedding/maximum_ids_per_partition", + "Max ids_per_partition limit for each table", "device", "table"); +static TFGaugeMetric* max_unique_ids_per_partition_gauge_ = TFGaugeMetric::New( + "/tensorflow/tpu/embedding/maximum_unique_ids_per_partition", + "Max unique_ids_per_partition limit for each table", "device", "table"); + namespace tensorflow { namespace { @@ -216,6 +224,7 @@ class XlaSparseDenseMatmulWithCsrInputOp : public XlaOpKernel { quantization_config_high_ = quant_clipping_float; } } + device_name_ = ctx->device()->name(); // Check for incomplete quantization config. 
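The new id-dropping path in GetMinibatchesInCsrWithPhysicalReplicaOp above boils down to this: ids are appended bucket by bucket into a buffer with a hard capacity of max_ids_per_chip; once the buffer would overflow, the op either truncates (when allow_id_dropping_for_minibatching_ is set, logging how many ids were lost) or fails with the InvalidArgumentError shown. A schematic Python version of that control flow, assuming buckets is a list of per-(sparse core, replica) id lists; the names are illustrative, not the kernel's:

def pack_ids(buckets, max_ids_per_chip, allow_id_dropping):
    # Pack per-bucket ids into a buffer of at most max_ids_per_chip entries.
    packed = []
    dropped = 0
    for ids in buckets:
        space_left = max_ids_per_chip - len(packed)
        if len(ids) > space_left:
            if not allow_id_dropping:
                remaining = sum(len(b) for b in buckets) - len(packed)
                raise ValueError(
                    f"max_ids_per_chip={max_ids_per_chip} cannot fit all ids; "
                    f"{remaining} ids remain")
            packed.extend(ids[:space_left])  # keep what fits, drop the rest
            dropped += len(ids) - space_left
        else:
            packed.extend(ids)
    if dropped:
        print(f"Dropped {dropped} ids to fit the static bound of {max_ids_per_chip}")
    return packed, dropped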
OP_REQUIRES(ctx, quantization_config_low_.has_value() == @@ -248,10 +257,17 @@ class XlaSparseDenseMatmulWithCsrInputOp : public XlaOpKernel { ctx, GetMaxIdsAndUniquesExternal( "", table_name_, per_sparse_core_batch_size, feature_width, &max_ids_per_partition, &max_unique_ids_per_partition)); - VLOG(3) << "XlaSparseDenseMatmulWithCsrInputOp: " - << "table_name = '" << table_name_ - << "', max_ids = " << max_ids_per_partition - << ", max_uniques = " << max_unique_ids_per_partition; + // Log max_ids and max_uniques for offline analysis. We do this here since + // these values are fixed at TPU compile time and remain fixed during + // training. + max_ids_per_partition_gauge_->GetCell(device_name_, table_name_) + ->Set(max_ids_per_partition); + max_unique_ids_per_partition_gauge_->GetCell(device_name_, table_name_) + ->Set(max_unique_ids_per_partition); + LOG(INFO) << "Lowering XlaSparseDenseMatmulWithCsrInputOp to HLO: " + << "table_name = '" << table_name_ + << "', max_ids = " << max_ids_per_partition + << ", max_uniques = " << max_unique_ids_per_partition; OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape( "num_minibatches_per_physical_sparse_core")), @@ -321,6 +337,7 @@ class XlaSparseDenseMatmulWithCsrInputOp : public XlaOpKernel { std::optional quantization_config_low_; std::optional quantization_config_high_; std::optional quantization_config_num_buckets_; + std::string device_name_; std::string table_name_; XlaSparseDenseMatmulWithCsrInputOp( @@ -410,10 +427,10 @@ class XlaSparseDenseMatmulGradWithCsrInputBase : public XlaOpKernel { ctx, GetMaxIdsAndUniquesExternal( "", table_name_, per_sparse_core_batch_size, feature_width, &max_ids_per_partition, &max_unique_ids_per_partition)); - VLOG(3) << "XlaSparseDenseMatmulWithCsrInputOp: " - << "table_name = '" << table_name_ - << "', max_ids = " << max_ids_per_partition - << ", max_uniques = " << max_unique_ids_per_partition; + LOG(INFO) << "Lowering XlaSparseDenseMatmulGradWithCsrInputOp to HLO: " + << "table_name = '" << table_name_ + << "', max_ids = " << max_ids_per_partition + << ", max_uniques = " << max_unique_ids_per_partition; xla::XlaComputation optimizer = build_optimizer_computation(feature_width); diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h index d098abe6e1ae08..5cb7e5a5d55511 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h @@ -62,7 +62,7 @@ using GuaranteedConsts = std::variant, // List of parameters for lowering function library definition to HLO IR. struct FunctionToHloArgs { const NameAttrList* const function; - const FunctionLibraryDefinition* const flib_def; + const FunctionLibraryDefinition* flib_def; int graph_def_version; GuaranteedConsts guaranteed_constants; }; diff --git a/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc b/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc index 52b83ae6be78b2..97fe019201e4cb 100644 --- a/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc +++ b/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc @@ -25,7 +25,9 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/client/xla_builder.h" +#include "xla/layout_util.h" #include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/tpu/c_api_conversions.h" #include "xla/stream_executor/tpu/c_api_decl.h" @@ -252,7 +254,11 @@ class SendTPUEmbeddingGradientsOp : public XlaOpKernel { auto builder = ctx->builder(); gradient_shapes.reserve(gradients.size()); for (xla::XlaOp op : gradients) { - gradient_shapes.push_back(builder->GetShape(op).value()); + // Gradient layout information is added by XLA, so we can just create + // default layout information. + xla::Shape gradient_shape = builder->GetShape(op).value(); + xla::LayoutUtil::SetToDefaultLayout(&gradient_shape); + gradient_shapes.push_back(gradient_shape); } std::vector learning_rates; diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.cc b/tensorflow/core/tpu/kernels/tpu_program_group.cc index 0b5d4444ef0c33..77f7f3361083ea 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group.cc +++ b/tensorflow/core/tpu/kernels/tpu_program_group.cc @@ -301,42 +301,6 @@ Status TpuProgramGroup::CompileAndBuild( return status.status(); } -/*static*/ -Status TpuProgramGroup::CompileAndBuild( - const xrt::XLAComputation& xrt_computation_proto, - const XLA_TpuMeshState* mesh_state, - TpuProgramGroupInterface* tpu_program_group_interface) { - se_tpu::SerializedProto serialized_compilation_request = - se_tpu::SerializeProto(xrt_computation_proto); - auto cleanup = gtl::MakeCleanup([serialized_compilation_request] { - se_tpu::SerializedProto_Free(serialized_compilation_request); - }); - size_t count = 0; - XLA_TpuProgram** xla_tpu_programs = nullptr; - StatusHelper status; - stream_executor::tpu::OpsApiFn()->TpuCompile_XrtCompileAndBuildFn( - serialized_compilation_request, mesh_state, &xla_tpu_programs, &count, - status.c_status); - if (!status.ok()) { - VLOG(1) << "Run CompileAndBuild failed."; - return status.status(); - } - - // SPMD could return 1 result for all partitions. - int num_cores_per_replica = - xrt_computation_proto.config().num_cores_per_replica() - ? xrt_computation_proto.config().num_cores_per_replica() - : 1; - TF_RET_CHECK(count == 1 || count == num_cores_per_replica); - VLOG(1) << "Initialize TpuProgramGroup."; - TpuProgramGroup* tpu_program_group = - tensorflow::down_cast(tpu_program_group_interface); - tpu_program_group->Initialize( - absl::MakeConstSpan(&xla_tpu_programs[0], count)); - stream_executor::tpu::OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs); - return status.status(); -} - std::vector TpuProgramGroup::tpu_programs( TpuProgramShardingType sharding_type) const { std::vector tpu_programs; diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.h b/tensorflow/core/tpu/kernels/tpu_program_group.h index 3d164c09725666..6859b0facd038c 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group.h +++ b/tensorflow/core/tpu/kernels/tpu_program_group.h @@ -27,7 +27,6 @@ limitations under the License. 
#include "xla/service/hlo.pb.h" #include "xla/stream_executor/tpu/tpu_ops_c_api.h" #include "xla/stream_executor/tpu/tpu_platform_interface.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" @@ -96,11 +95,6 @@ class TpuProgramGroup : public TpuProgramGroupInterface { const XLA_TpuMeshState* mesh_state, TpuProgramGroupInterface* tpu_program_group_interface); - // Compiles HLO IR and returns TPU programs ready for execution. - static Status CompileAndBuild( - const xrt::XLAComputation& xrt_computation_proto, - const XLA_TpuMeshState* mesh_state, - TpuProgramGroupInterface* tpu_program_group_interface); // Initializes `TpuProgramGroup` object with `xla_tpu_programs`. void Initialize(absl::Span xla_tpu_programs); diff --git a/tensorflow/core/tpu/ops/sparse_core_ops.cc b/tensorflow/core/tpu/ops/sparse_core_ops.cc index f9b9d64339e572..e770c1814399a2 100644 --- a/tensorflow/core/tpu/ops/sparse_core_ops.cc +++ b/tensorflow/core/tpu/ops/sparse_core_ops.cc @@ -322,4 +322,12 @@ REGISTER_OP("XlaSparseCoreFtrl") return OkStatus(); }); +REGISTER_OP("GlobalIterId") + .Output("iter_id: int64") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) -> Status { + c->set_output(0, c->Scalar()); + return OkStatus(); + }); + } // namespace tensorflow diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index e01ba6ff6a792e..d435ac73f29780 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -37,7 +37,6 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = default_package_visibility, features = [ - "-layering_check", "-parse_headers", ], licenses = ["notice"], @@ -453,6 +452,7 @@ tf_mkl_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core/framework:bounds_check", "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/base", ], ) @@ -726,6 +726,7 @@ tf_kernel_library( srcs = ["cuda_solvers.cc"], hdrs = ["gpu_solvers.h"], compatible_with = [], + features = ["-layering_check"], # @local_config_cuda//cuda:cusolver_static, //third_party/eigen3:blas, # and //third_party/libf2c all contain various parts of BLAS, LAPACK, # and f2c helper functions in global namespace. Tell the compiler to @@ -773,6 +774,7 @@ tf_kernel_library( "gpu_solvers.h", ], compatible_with = [], + features = ["-layering_check"], deps = [ ":cuda_solvers", "//tensorflow/core:framework", @@ -800,28 +802,6 @@ cc_library( ], ) -# For a more maintainable build this target should not exist and the headers -# should be split into the existing cc_library targets, but this change was -# automatically done so that we can remove long standing issues and complexity -# in the build system. It's up to the OWNERS of this package to get rid of it or -# not. The use of the textual_hdrs attribute is discouraged, use hdrs instead. -# Here it is used to avoid header parsing errors in packages where the feature -# parse_headers was enabled since loose headers were not being parsed. See -# go/loose-lsc-one-target-approach for more details. -cc_library( - name = "loose_headers", - tags = ["avoid_dep"], - textual_hdrs = [ - "cuda_sparse.h", - "gpu_solvers.h", - ], - visibility = [ - "//tensorflow/core/kernels:__pkg__", - "//tensorflow/core/kernels/linalg:__pkg__", - "//tensorflow/core/kernels/sparse:__pkg__", - ], -) - # Tests. 
tf_cc_test( name = "overflow_test", @@ -852,6 +832,7 @@ tf_cuda_only_cc_test( srcs = [ "gpu_kernel_helper_test.cu.cc", ], + features = ["-layering_check"], tags = [ "no_cuda_asan", # TODO(b/171342366): re-enable. ], @@ -890,6 +871,7 @@ tf_cc_tests( "tensor_slice_writer_test.cc", "work_sharder_test.cc", ], + features = ["-layering_check"], linkopts = select({ "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], diff --git a/tensorflow/core/util/autotune_maps/BUILD b/tensorflow/core/util/autotune_maps/BUILD index 94ccd691946c78..f4f13211ab2f8e 100644 --- a/tensorflow/core/util/autotune_maps/BUILD +++ b/tensorflow/core/util/autotune_maps/BUILD @@ -107,7 +107,10 @@ tf_proto_library( "//tensorflow/core/util/autotune_maps:conv_parameters_proto", "@local_tsl//tsl/protobuf:dnn_proto", ], - visibility = ["//waymo/ml/deploy/system/autotuning:__subpackages__"], + visibility = [ + "//waymo/ml/deploy/benchmark:__subpackages__", + "//waymo/ml/deploy/system/autotuning:__subpackages__", + ], ) # copybara:uncomment_begin(google-only) diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc index 221f347c22bea2..d76d2fce3d0b03 100644 --- a/tensorflow/core/util/command_line_flags_test.cc +++ b/tensorflow/core/util/command_line_flags_test.cc @@ -43,6 +43,7 @@ TEST(CommandLineFlagsTest, BasicUsage) { bool some_switch_set_directly = false; bool some_switch_set_via_hook = true; bool some_switch_set_capitalized = false; + bool some_switch_set_by_number = false; string some_name_set_directly = "something_a"; string some_name_set_via_hook = "something_b"; float some_float_set_directly = -23.23f; @@ -55,6 +56,7 @@ TEST(CommandLineFlagsTest, BasicUsage) { "--some_switch_set_directly", "--some_switch_set_via_hook=false", "--some_switch_set_capitalized=True", + "--some_switch_set_by_number=1", "--some_name_set_directly=somethingelse", "--some_name_set_via_hook=anythingelse", "--some_float_set_directly=42.0", @@ -93,6 +95,8 @@ TEST(CommandLineFlagsTest, BasicUsage) { some_switch_set_via_hook, "some switch set via hook"), Flag("some_switch_set_capitalized", &some_switch_set_capitalized, "some switch set capitalized"), + Flag("some_switch_set_by_number", &some_switch_set_by_number, + "some switch set by number"), Flag("some_name_set_directly", &some_name_set_directly, "some name set directly"), Flag( @@ -121,6 +125,7 @@ TEST(CommandLineFlagsTest, BasicUsage) { EXPECT_EQ(true, some_switch_set_directly); EXPECT_EQ(false, some_switch_set_via_hook); EXPECT_EQ(true, some_switch_set_capitalized); + EXPECT_EQ(true, some_switch_set_by_number); EXPECT_EQ("somethingelse", some_name_set_directly); EXPECT_EQ("anythingelse", some_name_set_via_hook); EXPECT_NEAR(42.0f, some_float_set_directly, 1e-5f); diff --git a/tensorflow/dtensor/cc/dtensor_device.cc b/tensorflow/dtensor/cc/dtensor_device.cc index 81d03d14ace24e..b9ddf25fde6caf 100644 --- a/tensorflow/dtensor/cc/dtensor_device.cc +++ b/tensorflow/dtensor/cc/dtensor_device.cc @@ -2384,6 +2384,9 @@ void DTensorDevice::Execute(const TFE_Op* original_op, int* num_outputs, absl::flat_hash_set input_meshes; std::vector single_device_input_indices; + VLOG(4) << "DTensorOperation: " << dtensor_operation.name + << " num_inputs are " << num_inputs; + typed_inputs.resize(num_inputs); for (int j = 0; j < num_inputs; ++j) { TFE_TensorHandle* input = inputs[j]; @@ -2392,6 +2395,8 @@ void DTensorDevice::Execute(const TFE_Op* original_op, int* num_outputs, if (name_ != input_device) { 
single_device_input_indices.push_back(j); typed_inputs[j] = nullptr; + VLOG(5) << "Input " << j << ": " + << tensorflow::unwrap(input)->DebugString(); continue; } // Handle input which is on DTensor device already. @@ -2404,10 +2409,15 @@ void DTensorDevice::Execute(const TFE_Op* original_op, int* num_outputs, input_meshes.insert(t->layout().mesh()); } typed_inputs[j] = t; + VLOG(5) << "Input " << j << ": " << typed_inputs[j]->DebugString(); } const std::optional mesh = ChooseBroadcastingMesh(input_meshes, dtypes); + VLOG(4) << "Execution DTensorOperation: " << dtensor_operation.name + << " with broadcast mesh " + << (mesh.has_value() ? mesh->ToString() : "no broadcast mesh"); + // TODO(feyu): This short circuit only allows running unsupported op // via DTensorDevice in eager mode. for tf.function and its graph, we will // need to build single device mesh placement rules in mesh propagation. diff --git a/tensorflow/dtensor/mlir/device_mesh_cluster_coarsening.cc b/tensorflow/dtensor/mlir/device_mesh_cluster_coarsening.cc index d39a71fad4d626..c1484beb5c2e3d 100644 --- a/tensorflow/dtensor/mlir/device_mesh_cluster_coarsening.cc +++ b/tensorflow/dtensor/mlir/device_mesh_cluster_coarsening.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project @@ -165,7 +166,7 @@ GetMergedMeshClusterResults(mlir::tf_device::ClusterOp current_cluster, // Updates the users of `merging_cluster` so that they use values // from `merged_cluster` instead. void ReplaceOperandUsagesWithMergedClusterOutputs( - const llvm::SmallVectorImpl& values_to_replace, + mlir::ValueRange values_to_replace, mlir::tf_device::ClusterOp merged_cluster) { for (auto result : llvm::zip(values_to_replace, merged_cluster.getResults())) { diff --git a/tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h b/tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h index fce2c014dc03b3..f54bdd248d3685 100644 --- a/tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h +++ b/tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h @@ -37,6 +37,8 @@ class MeshAttr using Base::Base; using Mesh = tensorflow::dtensor::Mesh; + static constexpr StringLiteral name = "dtensor.mesh"; + // Constructor of attribute static MeshAttr get(MLIRContext* context, const Mesh& mesh); @@ -52,6 +54,8 @@ class LayoutAttr : public Attribute::AttrBase& merged_layouts) { - for (const auto& merged_layout : merged_layouts) { + llvm::DenseMap& merged_layouts) { + for (auto& merged_layout : merged_layouts) { // merged_layout is a pair of mlir::Value and Layout. // If there is only one user of the Value and that user is a DTensorLayout // op, then we can skip creating the op as the layout is already there. Note diff --git a/tensorflow/dtensor/python/BUILD b/tensorflow/dtensor/python/BUILD index 4090792e1e97a8..bf0c9564a97bb8 100644 --- a/tensorflow/dtensor/python/BUILD +++ b/tensorflow/dtensor/python/BUILD @@ -1,7 +1,7 @@ # DTensor Python API and libraries. 
-load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:pytype.default.bzl", "pytype_strict_library") +load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") default_visibility = [ @@ -100,6 +100,25 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "d_random", + srcs = ["d_random.py"], + srcs_version = "PY3", + deps = [ + ":api", + ":layout", + "//tensorflow/python/eager:context", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor_shape", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:math_ops_gen", + "//tensorflow/python/ops:shape_util", + "//tensorflow/python/ops:stateless_random_ops_gen", + ], +) + pytype_strict_library( name = "d_variable", srcs = ["d_variable.py"], diff --git a/tensorflow/dtensor/python/accelerator_util.py b/tensorflow/dtensor/python/accelerator_util.py index e32d54b36c1834..b1e96c169de4e1 100644 --- a/tensorflow/dtensor/python/accelerator_util.py +++ b/tensorflow/dtensor/python/accelerator_util.py @@ -121,6 +121,7 @@ def initialize_accelerator_system( enable_coordination_service: Optional[bool] = True, num_logical_cpu_devices: Optional[int] = None, experimental_reset_context: Optional[bool] = False, + experimental_enable_megcore: Optional[bool] = False, ) -> str: """Initializes accelerators and communication fabrics for DTensor. @@ -170,6 +171,7 @@ def initialize_accelerator_system( as an escape hatch, if there is no clear way to refactor your code to call initialize_accelerator_system() before calling TensorFlow APIs that initialize the context. + experimental_enable_megcore: Optionally enable megcore in backend. Returns: device_type: the type of accelerator that was initialized. @@ -258,7 +260,7 @@ def initialize_accelerator_system( ) if device_type == "TPU" and not config.backend_is_pw(): - tpu_util.initialize_tpu_system() + tpu_util.initialize_tpu_system(use_megacore=experimental_enable_megcore) _INITIALIZED_ACCELERATOR_SYSTEM_TYPE = device_type diff --git a/tensorflow/dtensor/python/d_random.py b/tensorflow/dtensor/python/d_random.py new file mode 100644 index 00000000000000..1697f3598151c7 --- /dev/null +++ b/tensorflow/dtensor/python/d_random.py @@ -0,0 +1,331 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""DTensor helpers for random generators.""" + +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_stateless_random_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import shape_util + +# ------------------------------------------------------------------------------ +# stateless rngs +# ------------------------------------------------------------------------------ + + +# TODO(b/171746536): switch all rng ops to official versions once supported. +def _old_tf_random_stateless_normal( + shape, + seed, + mean=0.0, + stddev=1.0, + dtype=dtypes.float32, + name=None, + layout=None, +): + """DTensor stateless normal implementation that takes an layout.""" + with ops.name_scope( + name, "stateless_random_normal", [shape, seed, mean, stddev] + ) as name: + seed = ops.convert_to_tensor(seed, dtype=dtypes.int32, name="seed") + shape = shape_util.shape_tensor(shape) + mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean") + stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") + rnd = api.call_with_layout( + gen_stateless_random_ops.stateless_random_normal, + layout, + shape, + seed, + dtype, + ) + result = math_ops.add(rnd * stddev, mean, name=name) + shape_util.maybe_set_static_shape(result, shape) + return result + + +def _old_tf_random_stateless_uniform( + shape, + seed, + minval=0, + maxval=None, + dtype=dtypes.float32, + name=None, + layout=None, +): + """DTensor stateless uniform implementation that takes an layout.""" + dtype = dtypes.as_dtype(dtype) + accepted_dtypes = ( + dtypes.float16, + dtypes.bfloat16, + dtypes.float32, + dtypes.float64, + dtypes.int32, + dtypes.int64, + dtypes.uint32, + dtypes.uint64, + ) + if dtype not in accepted_dtypes: + raise ValueError( + f"Argument `dtype` got invalid value {dtype}. Accepted dtypes are " + f"{accepted_dtypes}." + ) + if dtype.is_integer: + if (minval is None) != (maxval is None): + raise ValueError( + f"For integer `dtype` argument {dtype}, argument `minval` and " + f"`maxval` must be both None or not None. Got `minval`={minval} and " + f"`maxval`={maxval}." + ) + if minval is not None and dtype in (dtypes.uint32, dtypes.uint64): + raise ValueError( + f"Argument `dtype` got invalid value {dtype} when argument `minval` " + "is not None. Please don't use unsigned integers in this case." 
+ ) + + shape = shape_util.shape_tensor(shape) + with ops.name_scope( + name, "stateless_random_uniform", [shape, seed, minval, maxval] + ) as name: + seed = ops.convert_to_tensor(seed, dtype_hint=dtypes.int32, name="seed") + + if dtype.is_integer and minval is None and maxval is None: + result = api.call_with_layout( + gen_stateless_random_ops.stateless_random_uniform_full_int, + layout, + shape, + seed=seed, + dtype=dtype, + name=name, + ) + else: + if not dtype.is_integer and maxval is None: + maxval = 1 + val_range = ops.convert_to_tensor( + maxval - minval, dtype=dtype, name="range" + ) + minval = ops.convert_to_tensor(minval, dtype=dtype, name="min") + if dtype.is_integer: + result = api.call_with_layout( + gen_stateless_random_ops.stateless_random_uniform_int, + layout, + shape, + seed=seed, + minval=minval, + maxval=maxval, + ) + else: + rnd = api.call_with_layout( + gen_stateless_random_ops.stateless_random_uniform, + layout, + shape, + seed=seed, + dtype=dtype, + ) + result = math_ops.add(rnd * val_range, minval, name=name) + shape_util.maybe_set_static_shape(result, shape) + return result + + +def _old_tf_stateless_truncated_normal( + shape, + seed, + mean=0.0, + stddev=1.0, + dtype=dtypes.float32, + name=None, + layout=None, +): + """DTensor stateless truncated normal implementation that takes an layout.""" + with ops.name_scope( + name, "stateless_truncated_normal", [shape, seed, mean, stddev] + ) as name: + seed = ops.convert_to_tensor(seed, dtype=dtypes.int32, name="seed") + shape = shape_util.shape_tensor(shape) + mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean") + stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") + rnd = api.call_with_layout( + gen_stateless_random_ops.stateless_truncated_normal, + layout, + shape, + seed, + dtype, + ) + result = math_ops.add(rnd * stddev, mean, name=name) + shape_util.maybe_set_static_shape(result, shape) + return result + + +def stateless_random_normal( + shape, + seed, + mean=0.0, + stddev=1.0, + dtype=dtypes.float32, + name=None, + layout=None, +): + """DTensor stateless RNG.""" + if not context.executing_eagerly(): + layout = None + + return _old_tf_random_stateless_normal( + shape, + seed=seed, + mean=mean, + stddev=stddev, + dtype=dtype, + name=name, + layout=layout, + ) + + +def stateless_random_uniform( + shape, + seed, + minval=0, + maxval=None, + dtype=dtypes.float32, + name=None, + layout=None, +): + """DTensor stateless random uniform.""" + if not context.executing_eagerly(): + layout = None + + return _old_tf_random_stateless_uniform( + shape, + seed=seed, + minval=minval, + maxval=maxval, + dtype=dtype, + name=name, + layout=layout, + ) + + +def stateless_truncated_normal( + shape, + seed, + mean=0.0, + stddev=1.0, + dtype=dtypes.float32, + name=None, + layout=None, +): + """DTensor stateless RNG.""" + if not context.executing_eagerly(): + layout = None + + return _old_tf_stateless_truncated_normal( + shape, + seed=seed, + mean=mean, + stddev=stddev, + dtype=dtype, + name=name, + layout=layout, + ) + + +def stateless_split(seed, num=2, mesh=None): + seed = ops.convert_to_tensor(seed) + layout = None + if mesh: + layout = layout_lib.Layout.replicated(mesh, rank=2) + return stateless_random_uniform( + shape=[num, 2], + seed=seed, + dtype=seed.dtype, + minval=None, + maxval=None, + layout=layout, + ) + + +# ------------------------------------------------------------------------------ +# stateless dropout. 
+# ------------------------------------------------------------------------------ + + +def _get_noise_shape(x, noise_shape): + """Noisve shape util copied from tf nn_ops.""" + # If noise_shape is none return immediately. + if noise_shape is None: + return array_ops.shape(x) + + try: + # Best effort to figure out the intended shape. + # If not possible, let the op to handle it. + # In eager mode exception will show up. + noise_shape_ = tensor_shape.as_shape(noise_shape) + except (TypeError, ValueError): + return noise_shape + + if x.shape.dims is not None and len(x.shape.dims) == len(noise_shape_.dims): + new_dims = [] + for i, dim in enumerate(x.shape.dims): + if noise_shape_.dims[i].value is None and dim.value is not None: + new_dims.append(dim.value) + else: + new_dims.append(noise_shape_.dims[i].value) + return tensor_shape.TensorShape(new_dims) + + return noise_shape + + +# TODO(b/171213877, b/169909066): Fix layout prop in function case for the rng +# Op used. The layout prop should be able to propagate the layout from input +# tensor `x` to the tf.mul and then back propagate the layout to the +# `random_tensor`. +def dropout(x, rate, noise_shape=None, seed=None, name=None): + """DTensor replacement for dropout.""" + if not isinstance(rate, float): + raise ValueError("rate should be float for dropout.") + if seed is None: + raise ValueError("seed must be specified for DTensor dropout. Got: None") + + with ops.name_scope(name, "dropout", [x]): + x_dtype = x.dtype + keep_prob = 1 - rate + scale = 1 / keep_prob + scale = ops.convert_to_tensor(scale, dtype=x_dtype) + ret = gen_math_ops.mul(x, scale) + + noise_shape = _get_noise_shape(x, noise_shape) + # stateless_random_uniform requires a shape [2] seed. + seed = [seed, 0] + + if context.executing_eagerly(): + layout = api.fetch_layout(x) + else: + layout = None + random_tensor = _old_tf_random_stateless_uniform( + noise_shape, seed=seed, minval=0, maxval=1, dtype=x_dtype, layout=layout + ) + keep_mask = random_tensor >= rate + ret = gen_math_ops.mul(ret, gen_math_ops.cast(keep_mask, x_dtype)) + if not context.executing_eagerly(): + ret.set_shape(x.get_shape()) + return ret + + +# TODO(b/195413777): error out for stateful dropout. diff --git a/tensorflow/dtensor/python/tests/BUILD b/tensorflow/dtensor/python/tests/BUILD index af1cd7b54c5a74..0a7d9f90345e30 100644 --- a/tensorflow/dtensor/python/tests/BUILD +++ b/tensorflow/dtensor/python/tests/BUILD @@ -8,6 +8,7 @@ load( "PATHWAYS", "PATHWAYS_V3_DONUT_BACKEND", "TPU_V3_DONUT_BACKEND", + "TPU_V4_DONUT_BACKEND", "dtensor_test", ) @@ -69,6 +70,32 @@ pytype_strict_library( ], ) +py_strict_test( + name = "api_test", + srcs = [ + "api_test.py", + ], + python_version = "PY3", + deps = [ + ":test_util", + "//tensorflow/dtensor/python:api", + "//tensorflow/dtensor/python:d_random", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + # TODO(b/301286466): Investigate why python annotation type mismatch is not catptured by the type # strict BUILD rules. 
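The dropout helper added in d_random.py above is standard inverted dropout: the input is pre-scaled by 1/(1 - rate), then positions where a stateless uniform draw falls below rate are zeroed, so the expected value of the output equals the input; the seed is required so that every device in the mesh draws the same mask. A plain numpy sketch of the same arithmetic (illustrative only; it ignores layouts, DTensor packing, and noise_shape broadcasting):

import numpy as np

def dropout_sketch(x, rate, seed):
    rng = np.random.default_rng(seed)  # stand-in for the stateless uniform op
    keep_prob = 1.0 - rate
    scaled = x / keep_prob             # pre-scale so the expected output equals x
    keep_mask = rng.uniform(size=x.shape) >= rate
    return scaled * keep_mask

dropout_sketch(np.ones((2, 2)), rate=0.5, seed=0)  # roughly half the entries become 2.0, the rest 0.0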
@@ -88,6 +115,64 @@ dtensor_test( ], ) +dtensor_test( + name = "batchparallel_spmd_test", + srcs = ["batchparallel_spmd_test.py"], + additional_backends = [TPU_V4_DONUT_BACKEND], + main = "batchparallel_spmd_test.py", + shard_count = { + "cpu": 4, + "gpu": 4, + "tpu": 4, + TPU_V4_DONUT_BACKEND: 8, + }, + deps = [ + ":test_util", + "//tensorflow/dtensor/python:api", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_gen", + "//tensorflow/python/ops:image_ops_gen", + "//tensorflow/python/ops:linalg_ops_gen", + "//tensorflow/python/ops:nn_impl", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +dtensor_test( + name = "cache_test", + srcs = ["cache_test.py"], + main = "cache_test.py", + tags = [ + "nomultivm", + ], + deps = [ + ":test_util", + "//tensorflow/dtensor/python:api", + "//tensorflow/dtensor/python:d_variable", + "//tensorflow/dtensor/python:layout", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:combinations", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:stateless_random_ops_gen", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + ], +) + dtensor_test( name = "config_test", srcs = ["config_test.py"], @@ -162,6 +247,35 @@ dtensor_test( ], ) +dtensor_test( + name = "conv_test", + srcs = [ + "conv_test.py", + ], + additional_backends = [TPU_V3_DONUT_BACKEND], + # All tests require 8 TPUs. + disable = ["tpu"], + shard_count = { + "cpu": 4, + "gpu": 4, + TPU_V3_DONUT_BACKEND: 4, + }, + deps = [ + ":test_util", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:special_math_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + dtensor_test( name = "device_test", srcs = ["device_test.py"], @@ -224,6 +338,52 @@ py_strict_test( ], ) +py_strict_test( + name = "multi_client_input_util_test", + timeout = "long", + srcs = ["multi_client_input_util_test.py"], + env = { + "TF2_BEHAVIOR": "1", + }, + shard_count = 8, + tags = [ + # ThreadSanitizer does not support starting new threads after multi-threaded fork. + "notsan", + "no_oss", # Fails on OSS. 
+ "nosan", # b/195537906 + ], + deps = [ + ":multi_client_test_util", + ":test_util", + "//tensorflow/core:protos_all_py", + "//tensorflow/dtensor/python:accelerator_util", + "//tensorflow/dtensor/python:api", + "//tensorflow/dtensor/python:config", + "//tensorflow/dtensor/python:input_util", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:mesh_util", + "//tensorflow/python/data/experimental/service:server_lib", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:readers", + "//tensorflow/python/eager:context", + "//tensorflow/python/framework:config", + "//tensorflow/python/framework:device_spec", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/lib/io:tf_record", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:check_ops", + "//tensorflow/python/ops:io_ops", + "//tensorflow/python/ops:parsing_config", + "//tensorflow/python/ops:parsing_ops", + "//tensorflow/python/ops:parsing_ops_gen", + "//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/logging", + "@absl_py//absl/testing:parameterized", + ], +) + dtensor_test( name = "layout_test", srcs = ["layout_test.py"], @@ -678,6 +838,81 @@ dtensor_test( ], ) +dtensor_test( + name = "rng_test", + size = "medium", + srcs = ["rng_test.py"], + additional_backends = [TPU_V3_DONUT_BACKEND], + # Requires at least 8 TPUs to run the tests. + disable = ["tpu"], + disable_tfrt = [ + "gpu", + TPU_V3_DONUT_BACKEND, + ], + main = "rng_test.py", + shard_count = { + "cpu": 20, + "tpu": 10, + "gpu": 30, + TPU_V3_DONUT_BACKEND: 20, + }, + deps = [ + ":test_util", + "//tensorflow/dtensor/python:api", + "//tensorflow/dtensor/python:d_variable", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/python/distribute:tpu_strategy", + "//tensorflow/python/distribute/cluster_resolver/tpu:tpu_cluster_resolver_py", + "//tensorflow/python/eager:remote", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:bitwise_ops_gen", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:stateful_random_ops_gen", + "//tensorflow/python/ops:stateless_random_ops_v2_gen", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/tpu:device_assignment", + "@absl_py//absl/testing:parameterized", + ], +) + +dtensor_test( + name = "save_restore_v2_test", + srcs = ["save_restore_v2_test.py"], + additional_backends = [ + TPU_V3_DONUT_BACKEND, + TPU_V4_DONUT_BACKEND, + ], + main = "save_restore_v2_test.py", + shard_count = { + "cpu": 8, + "gpu": 8, + TPU_V3_DONUT_BACKEND: 8, + }, + deps = [ + ":test_util", + "//tensorflow/dtensor/python:api", + "//tensorflow/dtensor/python:d_variable", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/python/checkpoint", + "//tensorflow/python/checkpoint:checkpoint_management", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/module", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + dtensor_test( name = 
"variable_test", srcs = ["variable_test.py"], @@ -707,3 +942,97 @@ dtensor_test( "//third_party/py/numpy", ], ) + +dtensor_test( + name = "mnist_test", + size = "large", + srcs = ["mnist_test.py"], + shard_count = { + "tpu": 2, + }, + tags = ["nosan"], # Non-opt builds has slow XLA compilation. + deps = [ + ":test_util", + "//tensorflow/dtensor/python:api", + "//tensorflow/dtensor/python:d_variable", + "//tensorflow/dtensor/python:input_util", + "//tensorflow/dtensor/python:layout", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +dtensor_test( + name = "numerics_test", + srcs = ["numerics_test.py"], + additional_backends = [TPU_V3_DONUT_BACKEND], + disable = ALL_BACKENDS, + enable = [ + "tpu", + ], + deps = [ + ":test_util", + "//tensorflow/dtensor/python:accelerator_util", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +dtensor_test( + name = "sparse_test", + srcs = ["sparse_test.py"], + main = "sparse_test.py", + shard_count = { + "cpu": 4, + }, + deps = [ + ":test_util", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +dtensor_test( + name = "tpu_device_assignment_test", + srcs = ["tpu_device_assignment_test.py"], + disable = ALL_BACKENDS, + enable = [ + "tpu", + ], + deps = [ + ":test_util", + "//tensorflow/dtensor/python:layout", + "//tensorflow/dtensor/python:numpy_util", + "//tensorflow/dtensor/python:tpu_util", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/platform:client_testlib", + ], +) diff --git a/tensorflow/dtensor/python/tests/api_test.py b/tensorflow/dtensor/python/tests/api_test.py new file mode 100644 index 00000000000000..7231086651439f --- /dev/null +++ b/tensorflow/dtensor/python/tests/api_test.py @@ -0,0 +1,305 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for the internal DTensor Python API.""" + +from absl.testing import parameterized +import numpy as np + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import d_random +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import stateless_random_ops +from tensorflow.python.platform import test + +Layout = layout_lib.Layout +Mesh = layout_lib.Mesh +_MESH_DIM_X = 'x' +_MESH_DIM_Y = 'y' + + +class APITest(test_util.DTensorBaseTest): + + def setUp(self): + super(APITest, self).setUp() + global_ids = test_util.create_device_ids_array((2, 2)) + local_device_ids = np.ravel(global_ids).tolist() + mesh_dict = { + 'CPU': Mesh( + [_MESH_DIM_X, _MESH_DIM_Y], + global_ids, + local_device_ids, + test_util.create_device_list((2, 2), 'CPU'), + ) + } + self.mesh = self.configTestMesh(mesh_dict) + self.layouts_1d = [ + Layout.replicated(self.mesh, rank=1), + Layout.batch_sharded(self.mesh, _MESH_DIM_X, rank=1), + Layout.batch_sharded(self.mesh, _MESH_DIM_Y, rank=1), + ] + self.layouts_2d = [ + Layout.replicated(self.mesh, rank=2), + Layout.batch_sharded(self.mesh, _MESH_DIM_X, rank=2), + Layout.inner_sharded(self.mesh, _MESH_DIM_X, rank=2), + Layout([_MESH_DIM_X, _MESH_DIM_Y], self.mesh), + ] + + def testV2API(self): + layout = Layout.replicated(self.mesh, rank=1) + zero_tensor = array_ops.zeros([10], layout=layout) + zero_like_tensor = array_ops.zeros_like_v2(zero_tensor, layout=layout) + self.assertAllEqual(zero_like_tensor.numpy(), zero_tensor.numpy()) + + ones_tensor = array_ops.ones([10], layout=layout) + ones_like_tensor = array_ops.ones_like_v2(zero_tensor, layout=layout) + self.assertAllEqual(ones_like_tensor.numpy(), ones_tensor.numpy()) + + def testStatelessRandom(self): + # test dtype default float32 random + result = stateless_random_ops.stateless_random_uniform( + [10], + seed=constant_op.constant([0, 0], dtype=dtypes.int64), + minval=0.0, + maxval=10.0, + ) + self.assertEqual([10], result.shape) + + # test dtype default int32 minval maxval are both None + result = stateless_random_ops.stateless_random_uniform( + [10], + seed=constant_op.constant([1, 2], dtype=dtypes.int64), + dtype=dtypes.int32, + minval=None, + maxval=None, + ) + self.assertEqual([10], result.shape) + + # test maxval is None or not given + result = stateless_random_ops.stateless_random_uniform( + [10], + seed=constant_op.constant([1, 2], dtype=dtypes.int64), + maxval=12, + dtype=dtypes.int32, + ) + self.assertEqual([10], result.shape) + self.assertAllInRange(result, 0, 12) + + def testStatelessRandomNormal(self): + # test dtype default float32 random + result = stateless_random_ops.stateless_random_normal( + [10], seed=constant_op.constant([0, 0], dtype=dtypes.int32) + ) + self.assertEqual([10], result.shape) + + # test dtype double + result = 
stateless_random_ops.stateless_random_normal( + [10], + seed=constant_op.constant([1, 2], dtype=dtypes.int32), + dtype=dtypes.double, + ) + self.assertEqual([10], result.shape) + + # test mean and stddev + result = stateless_random_ops.stateless_random_normal( + [10], + seed=constant_op.constant([1, 2], dtype=dtypes.int32), + mean=0, + stddev=0, + ) + self.assertEqual([10], result.shape) + self.assertAllInRange(result, 0, 0) + + # test dtensor version of each, check layouts + layout = Layout.replicated(self.mesh, rank=1) + + # test dtype default float 32 random + result = d_random.stateless_random_normal( + [10], + seed=constant_op.constant([0, 0], dtype=dtypes.int32), + layout=layout, + ) + self.assertEqual([10], result.shape) + self.assertEqual(layout, api.fetch_layout(result)) + + # test dtype double + result = d_random.stateless_random_normal( + [10], + seed=constant_op.constant([1, 2], dtype=dtypes.int32), + dtype=dtypes.double, + layout=layout, + ) + self.assertEqual([10], result.shape) + self.assertEqual(layout, api.fetch_layout(result)) + + # test mean and stddev + result = d_random.stateless_random_normal( + [10], + seed=constant_op.constant([1, 2], dtype=dtypes.int32), + mean=0, + stddev=0, + layout=layout, + ) + self.assertEqual([10], result.shape) + self.assertAllInRange(result, 0, 0) + self.assertEqual(layout, api.fetch_layout(result)) + + @parameterized.named_parameters(*set( + test_util.product((('_labels_unsharded', 0), ('_labels_batch', 1), + ('_labels_inner', 2), ('_labels_both', 3)), + (('_logits_unsharded', 0), ('_logits_batch', 1), + ('_logits_inner', 2), ('_logits_both', 3))))) + def testSoftmaxCrossentropyWithLogits(self, labels_layout, logits_layout): + expected_layout = Layout.replicated(self.mesh, rank=1) + if (labels_layout == 1 or labels_layout == 3 or logits_layout == 1 or + logits_layout == 3): + expected_layout = Layout.inner_sharded(self.mesh, _MESH_DIM_X, rank=1) + + labels_layout = self.layouts_2d[labels_layout] + logits_layout = self.layouts_2d[logits_layout] + labels_numpy = np.random.uniform(size=[6, 4]) + logits_numpy = np.random.uniform(size=[6, 4]) + labels = constant_op.constant(labels_numpy, dtype=dtypes.float32) + logits = constant_op.constant(logits_numpy, dtype=dtypes.float32) + + # Should we test against the built in version or the patched version? 
+ expected = nn_ops.softmax_cross_entropy_with_logits_v2( + labels=labels, logits=logits + ) + + labels = numpy_util.pack_numpy(labels, labels_layout) + logits = numpy_util.pack_numpy(logits, logits_layout) + dtensor_result = nn_ops.softmax_cross_entropy_with_logits_v2( + labels=labels, logits=logits + ) + self.assertDTensorEqual(expected, expected_layout, dtensor_result) + + @parameterized.named_parameters(*set( + test_util.product((('_labels_unsharded', 0), ('_labels_batch_x', 1), + ('_labels_batch_y', 2)), + (('_logits_unsharded', 0), ('_logits_batch', 1), + ('_logits_inner', 2), ('_logits_both', 3))))) + def testSparseSoftmaxCrossentropyWithLogits(self, labels_layout, + logits_layout): + expected_layout = Layout.replicated(self.mesh, rank=1) + if labels_layout == 1 or logits_layout == 1 or logits_layout == 3: + expected_layout = Layout.inner_sharded(self.mesh, _MESH_DIM_X, rank=1) + elif labels_layout == 2: + expected_layout = Layout.inner_sharded(self.mesh, _MESH_DIM_Y, rank=1) + + labels_layout = self.layouts_1d[labels_layout] + logits_layout = self.layouts_2d[logits_layout] + labels_numpy = np.random.randint(size=[6], low=0, high=4) + logits_numpy = np.random.uniform(size=[6, 4]) + labels = constant_op.constant(labels_numpy, dtype=dtypes.int64) + logits = constant_op.constant(logits_numpy, dtype=dtypes.float32) + + # Should we test against the built in version or the patched version? + expected = nn_ops.sparse_softmax_cross_entropy_with_logits_v2( + labels=labels, logits=logits + ) + + labels = numpy_util.pack_numpy(labels, labels_layout) + logits = numpy_util.pack_numpy(logits, logits_layout) + dtensor_result = nn_ops.sparse_softmax_cross_entropy_with_logits_v2( + labels=labels, logits=logits + ) + self.assertDTensorEqual(expected, expected_layout, dtensor_result) + + def test_dropout_raises_on_none_seed(self): + with api.default_mesh(self.mesh): + with self.assertRaisesRegex(ValueError, 'seed must be specified'): + _ = d_random.dropout( + array_ops.ones([2, 2], dtype=dtypes.float32), rate=0.5, seed=None + ) + + def test_default_mesh(self): + + @polymorphic_function.function + def func(a): + return a + 3.0 + + with api.default_mesh(self.mesh): + a = array_ops.zeros(shape=()) + result = func(a) + + self.assertEqual(result, 3.0) + self.assertEqual(api.fetch_layout(result).mesh, self.mesh) + self.assertTrue(api.fetch_layout(result).is_fully_replicated()) + self.assertEqual(result.device, api.device_name()) + + # Also make sure it works as wrapper + @api.default_mesh(self.mesh) + def func2(): + b = array_ops.ones(shape=()) + return func(b) + + result = func2() + self.assertEqual(result, 4.0) + self.assertEqual(api.fetch_layout(result).mesh, self.mesh) + self.assertTrue(api.fetch_layout(result).is_fully_replicated()) + self.assertEqual(result.device, api.device_name()) + + with self.assertRaisesRegex(ValueError, 'Expect `mesh` to be `Mesh`'): + with api.default_mesh(None): + pass + + def test_default_mesh_with_constant(self): + + @polymorphic_function.function + def func(): + return constant_op.constant([3, 4]) + + with api.default_mesh(self.mesh): + result = func() + + self.assertAllEqual(result, [3, 4]) + self.assertEqual(api.fetch_layout(result).mesh, self.mesh) + self.assertTrue(api.fetch_layout(result).is_fully_replicated()) + self.assertEqual(result.device, api.device_name()) + + def test_error_no_default_mesh(self): + with self.assertRaisesRegex( + errors_impl.InvalidArgumentError, + 'No default mesh has been registered to DTensor', + ): + with ops.device_v2(api.device_name()): + 
_ = constant_op.constant(3.0) + + def test_get_default_mesh(self): + self.assertIsNone(api.get_default_mesh()) + with api.default_mesh(self.mesh): + self.assertEqual(api.get_default_mesh(), self.mesh) + + with api.default_mesh(self.mesh.host_mesh()): + self.assertEqual(api.get_default_mesh(), self.mesh.host_mesh()) + + self.assertEqual(api.get_default_mesh(), self.mesh) + + self.assertIsNone(api.get_default_mesh()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/dtensor/python/tests/batchparallel_spmd_test.py b/tensorflow/dtensor/python/tests/batchparallel_spmd_test.py new file mode 100644 index 00000000000000..b6cbbd0459b8e6 --- /dev/null +++ b/tensorflow/dtensor/python/tests/batchparallel_spmd_test.py @@ -0,0 +1,660 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for batchparallel_spmd.""" + +import itertools +from absl.testing import parameterized +import numpy as np + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.dtensor.python.tests import test_util_ops +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_image_ops +from tensorflow.python.ops import gen_linalg_ops +from tensorflow.python.ops import nn_impl +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test +# pylint: enable=g-direct-tensorflow-import + +Layout = layout_lib.Layout +Mesh = layout_lib.Mesh + + +class DTensorBatchParallelSPMDTest(test_util.DTensorBaseTest): + + def setUp(self): + super(DTensorBatchParallelSPMDTest, self).setUp() + + self.skipForDeviceType(['TPU'], + 'all tests require 8 TPU cores.', + unless_device_count_equals_to=8) + # Builds a 8x2 mesh. + self._mesh_dim_b = 'b' + self._mesh_dim_x = 'x' + self._dims = [self._mesh_dim_b, self._mesh_dim_x] + + global_ids = test_util.create_device_ids_array((4, 2)) + local_ids = np.ravel(global_ids).tolist() + mesh_dict = { + device: Mesh( + self._dims, + global_ids, + local_ids, + test_util.create_device_list((4, 2), device), + ) + for device in ('CPU', 'GPU', 'TPU') + } + self.mesh = self.configTestMesh(mesh_dict) + context.ensure_initialized() + + # Creates a bunch of common layouts used by tests later. 
+ # 4-d + self.replicated_layout_4d = Layout.replicated(self.mesh, rank=4) + self.batch_layout_4d = Layout.batch_sharded( + self.mesh, self._mesh_dim_b, rank=4) + + # 5-d + self.replicated_layout_5d = Layout.replicated(self.mesh, rank=5) + self.batch_layout_5d = Layout.batch_sharded( + self.mesh, self._mesh_dim_b, rank=5) + + @parameterized.named_parameters(('NoBatchDim', 0), ('SingleBatchDim', 1), + ('TwoBatchDim', 2)) + def testCholesky(self, num_batch_dim): + # Input needs to be symmetric and positive definite. + x = constant_op.constant( + [[1, 1, 1, 1], [1, 5, 5, 5], [1, 5, 14, 14], [1, 5, 14, 17]], + dtype=dtypes.float32, + ) + for _ in range(num_batch_dim): + x = array_ops.expand_dims_v2(x, 0) + s = [4] + [1 for _ in range(array_ops.rank(x) - 1)] + x = gen_array_ops.tile(x, s) + + expected_result = gen_linalg_ops.cholesky(x) + + if num_batch_dim == 0: + layout_spec = [] + elif num_batch_dim == 1: + layout_spec = [self._mesh_dim_b] + elif num_batch_dim == 2: + layout_spec = [self._mesh_dim_b, self._mesh_dim_x] + layout = Layout(layout_spec + ['unsharded'] * 2, self.mesh) + + x = numpy_util.pack_numpy(x, layout) + got = gen_linalg_ops.cholesky(input=x) + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters( + test_util.product( + [('NoBatchDim', 0), ('SingleBatchDim', 1), ('TwoBatchDim', 2)], + test_util_ops.FFT_OPS, + ) + ) + def testFFT(self, num_batch_dim, fft_op, num_nonbatch_dim): + shape = [4 for i in range(num_batch_dim + num_nonbatch_dim)] + np.random.seed(123) + x = constant_op.constant( + np.random.normal(0.0, 1.0, np.prod(shape)).reshape(shape), + dtype=dtypes.complex64, + ) + expected_result = fft_op(input=x) + + if num_batch_dim == 0: + layout_spec = [] + elif num_batch_dim == 1: + layout_spec = [self._mesh_dim_b] + elif num_batch_dim == 2: + layout_spec = [self._mesh_dim_b, self._mesh_dim_x] + layout = Layout(layout_spec + ['unsharded'] * num_nonbatch_dim, self.mesh) + + x = numpy_util.pack_numpy(x, layout) + got = fft_op(input=x) + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters( + test_util.product( + [('NoBatchDim', 0), ('SingleBatchDim', 1), ('TwoBatchDim', 2)], + test_util_ops.RFFT_OPS, + ) + ) + def testRFFT(self, num_batch_dim, rfft_op, num_nonbatch_dim, dtype): + self.skipForDeviceType(['GPU'], 'RFFT has numerical issues on GPU') + shape = [4 for i in range(num_batch_dim + num_nonbatch_dim)] + np.random.seed(123) + x = constant_op.constant( + np.random.normal(0.0, 1.0, np.prod(shape)).reshape(shape), dtype=dtype + ) + expected_result = rfft_op(input=x, fft_length=[2] * num_nonbatch_dim) + + if num_batch_dim == 0: + layout_spec = [] + elif num_batch_dim == 1: + layout_spec = [self._mesh_dim_b] + elif num_batch_dim == 2: + layout_spec = [self._mesh_dim_b, self._mesh_dim_x] + layout = Layout(layout_spec + ['unsharded'] * num_nonbatch_dim, self.mesh) + + x = numpy_util.pack_numpy(x, layout) + got = rfft_op(input=x, fft_length=[2] * num_nonbatch_dim) + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters( + test_util.product( + [('Replicated', 'replicated'), ('Sharded', 'batch')], + [ + ( + 'SamePadding', + 'SAME', + ), + ( + 'ValidPadding', + 'VALID', + ), + ], + test_util_ops.BATCH_PARALLEL_2D_WINDOW_OPS, + ) + ) + def test2DWindowOp(self, layout_spec, padding, op): + np.random.seed(123) + row_window_size = 3 + col_window_size = 4 + window_size = [1, row_window_size, col_window_size, 1] + stride_size = [1, row_window_size - 1, col_window_size - 1, 1] + 
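+    # Note (descriptive comment, assumption about intent): the spatial dims
+    # below are derived from the window sizes so the strided windows tile the
+    # input exactly (5 x 7 output positions under VALID padding).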
+ num_rows = (row_window_size - 1) * 5 + 1 + num_cols = (col_window_size - 1) * 7 + 1 + x_in = np.random.normal(0.0, 1.0, 8 * num_rows * num_cols * 3).reshape( + [8, num_rows, num_cols, 3]) + + inputs = constant_op.constant(x_in, dtype=dtypes.float32) + expected_result = op(inputs, window_size, stride_size, padding) + + if layout_spec == 'replicated': + layout = self.replicated_layout_4d + else: + layout = self.batch_layout_4d + + x = numpy_util.pack_numpy(inputs, layout) + got = op(x, window_size, stride_size, padding) + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters( + test_util.product( + [('Replicated', 'replicated'), ('BatchSharded', 'batch')], + [ + ( + 'SamePadding', + 'SAME', + ), + ( + 'ValidPadding', + 'VALID', + ), + ], + test_util_ops.BATCH_PARALLEL_3D_WINDOW_OPS, + ) + ) + def test3DWindowOp(self, layout_spec, padding, op): + np.random.seed(123) + dep_window_size = 2 + row_window_size = 3 + col_window_size = 4 + window_size = [1, dep_window_size, row_window_size, col_window_size, 1] + stride_size = [ + 1, dep_window_size - 1, row_window_size - 1, col_window_size - 1, 1 + ] + + num_deps = 3 + num_rows = (row_window_size - 1) * 5 + 1 + num_cols = (col_window_size - 1) * 7 + 1 + x_in = np.random.normal(0.0, 1.0, 8 * num_deps * num_rows * num_cols * + 3).reshape([8, num_deps, num_rows, num_cols, 3]) + + inputs = constant_op.constant(x_in, dtype=dtypes.float32) + expected_result = op(inputs, window_size, stride_size, padding) + + if layout_spec == 'replicated': + layout = self.replicated_layout_5d + else: + layout = self.batch_layout_5d + + x = numpy_util.pack_numpy(inputs, layout) + + got = op(x, window_size, stride_size, padding) + + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters(test_util_ops.PADDINGS) + def testDepthwiseConv2dNative(self, padding): + np.random.seed(123) + x_in = np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]) + + kernel_in = np.array([ + [[[2, 0.1]], [[3, 0.2]]], + [[[0, 0.3]], [[1, 0.4]]], + ]) + + inputs = constant_op.constant(x_in, dtype=dtypes.float32) + kernel = constant_op.constant(kernel_in, dtype=dtypes.float32) + expected_result = nn_impl.depthwise_conv2d_v2( + inputs, kernel, strides=[1, 1, 1, 1], padding=padding + ) + + layout = self.batch_layout_4d + + x = numpy_util.pack_numpy(inputs, layout) + kernel = numpy_util.pack_numpy(kernel, self.replicated_layout_4d) + got = nn_impl.depthwise_conv2d_v2( + x, kernel, strides=[1, 1, 1, 1], padding=padding + ) + + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters(('Sharded', 'sharded'), + ('Replicated', 'replicated')) + def testResizeBilinear(self, shard_spec): + np.random.seed(123) + images = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]), + dtype=dtypes.float32, + ) + + expected_result = gen_image_ops.resize_bilinear( + images=images, + size=[3, 3], + align_corners=False, + half_pixel_centers=False, + name=None, + ) + + if shard_spec == 'sharded': + layout = self.batch_layout_4d + else: + layout = self.replicated_layout_4d + images = numpy_util.pack_numpy(images, layout) + + got = gen_image_ops.resize_bilinear( + images=images, + size=[3, 3], + align_corners=False, + half_pixel_centers=False, + name=None, + ) + + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters(('Sharded', 'sharded'), + ('Replicated', 'replicated')) + def testResizeNearestNeighbor(self, shard_spec): + np.random.seed(123) + 
images = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]), + dtype=dtypes.float32, + ) + + expected_result = gen_image_ops.resize_nearest_neighbor( + images=images, + size=[3, 3], + align_corners=False, + half_pixel_centers=False, + name=None, + ) + + if shard_spec == 'sharded': + layout = self.batch_layout_4d + else: + layout = self.replicated_layout_4d + images = numpy_util.pack_numpy(images, layout) + + got = gen_image_ops.resize_nearest_neighbor( + images=images, + size=[3, 3], + align_corners=False, + half_pixel_centers=False, + name=None, + ) + + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters(('Sharded', 'sharded'), + ('Replicated', 'replicated')) + def testAdjustContrastv2(self, shard_spec): + np.random.seed(123) + images = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9 * 3).reshape([8, 9, 9, 3]), + dtype=dtypes.float32, + ) + + expected_result = gen_image_ops.adjust_contrastv2( + images=images, contrast_factor=0.5 + ) + + if shard_spec == 'sharded': + layout = self.batch_layout_4d + else: + layout = self.replicated_layout_4d + images = numpy_util.pack_numpy(images, layout) + + got = gen_image_ops.adjust_contrastv2(images=images, contrast_factor=0.5) + + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.named_parameters(('Sharded', 'sharded'), + ('Replicated', 'replicated')) + def testAdjustSaturation(self, shard_spec): + np.random.seed(123) + images = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9 * 3).reshape([8, 9, 9, 3]), + dtype=dtypes.float32, + ) + + expected_result = gen_image_ops.adjust_saturation(images=images, scale=0.5) + + if shard_spec == 'sharded': + layout = self.batch_layout_4d + else: + layout = self.replicated_layout_4d + images = numpy_util.pack_numpy(images, layout) + + got = gen_image_ops.adjust_saturation(images=images, scale=0.5) + + self.assertDTensorEqual(expected_result, layout, got) + + @parameterized.parameters( + itertools.permutations(['sharded', 'replicated'], 2)) + def testResizeBilinearGradBatchSharded(self, spec1, spec2): + np.random.seed(123) + images = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]), + dtype=dtypes.float32, + ) + grads = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]), + dtype=dtypes.float32, + ) + expected_result = gen_image_ops.resize_bilinear_grad( + grads=grads, + original_image=images, + align_corners=False, + half_pixel_centers=False, + name=None, + ) + + specs = [spec1, spec2] + layouts = [ + self.batch_layout_4d if spec == 'sharded' else self.replicated_layout_4d + for spec in specs + ] + + # Test images is replicated, grads is batch sharded + images = numpy_util.pack_numpy(images, layouts[0]) + grads = numpy_util.pack_numpy(grads, layouts[1]) + + got = gen_image_ops.resize_bilinear_grad( + grads=grads, + original_image=images, + align_corners=False, + half_pixel_centers=False, + name=None, + ) + self.assertDTensorEqual(expected_result, self.batch_layout_4d, got) + + def testResizeBilinearGradReplicated(self): + np.random.seed(123) + images = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]), + dtype=dtypes.float32, + ) + grads = constant_op.constant( + np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]), + dtype=dtypes.float32, + ) + expected_result = gen_image_ops.resize_bilinear_grad( + grads=grads, + original_image=images, + align_corners=False, + half_pixel_centers=False, + 
+        name=None,
+    )
+
+    images = numpy_util.pack_numpy(images, self.replicated_layout_4d)
+    grads = numpy_util.pack_numpy(grads, self.replicated_layout_4d)
+
+    got = gen_image_ops.resize_bilinear_grad(
+        grads=grads,
+        original_image=images,
+        align_corners=False,
+        half_pixel_centers=False,
+        name=None,
+    )
+    self.assertDTensorEqual(expected_result, self.replicated_layout_4d, got)
+
+  @parameterized.named_parameters(
+      test_util.product([('Replicated', 'replicated'), ('Sharded', 'batch')], [(
+          'SamePadding',
+          'SAME',
+      ), (
+          'ValidPadding',
+          'VALID',
+      )]))
+  def testMaxPool3DGrad(self, shard_spec, padding):
+    np.random.seed(123)
+    dep_window_size = 2
+    row_window_size = 3
+    col_window_size = 4
+    window_size = [1, dep_window_size, row_window_size, col_window_size, 1]
+    stride_size = [
+        1, dep_window_size - 1, row_window_size - 1, col_window_size - 1, 1
+    ]
+
+    num_deps = 3
+    num_rows = (row_window_size - 1) * 5 + 1
+    num_cols = (col_window_size - 1) * 7 + 1
+    x_in = np.random.normal(0.0, 1.0, 8 * num_deps * num_rows * num_cols *
+                            3).reshape([8, num_deps, num_rows, num_cols, 3])
+    inputs = constant_op.constant(x_in, dtype=dtypes.float32)
+
+    with backprop.GradientTape() as tape:
+      tape.watch([inputs])
+      expected_result = nn_ops.max_pool3d(
+          inputs, window_size, stride_size, padding
+      )
+    expected_grad = tape.gradient(expected_result, [inputs])
+    layout = (
+        self.batch_layout_5d
+        if shard_spec == 'batch'
+        else self.replicated_layout_5d
+    )
+
+    inputs = numpy_util.pack_numpy(inputs, layout)
+
+    with ops.device_v2(api.device_name()):
+      with backprop.GradientTape() as tape:
+        tape.watch([inputs])
+        dtensor_result = nn_ops.max_pool3d(
+            inputs, window_size, stride_size, padding
+        )
+      dtensor_grad = tape.gradient(dtensor_result, [inputs])
+
+    self.assertDTensorEqual(expected_grad[0], layout, dtensor_grad[0])
+
+  @parameterized.named_parameters(
+      test_util.product([('Replicated', 'replicated'), ('Sharded', 'batch')], [(
+          'SamePadding',
+          'SAME',
+      ), (
+          'ValidPadding',
+          'VALID',
+      )]))
+  def testMaxPool3DGradGrad(self, shard_spec, padding):
+    np.random.seed(123)
+    dep_window_size = 2
+    row_window_size = 3
+    col_window_size = 4
+    window_size = [1, dep_window_size, row_window_size, col_window_size, 1]
+    stride_size = [
+        1, dep_window_size - 1, row_window_size - 1, col_window_size - 1, 1
+    ]
+
+    num_deps = 3
+    num_rows = (row_window_size - 1) * 5 + 1
+    num_cols = (col_window_size - 1) * 7 + 1
+    x_in = np.random.normal(0.0, 1.0, 8 * num_deps * num_rows * num_cols *
+                            3).reshape([8, num_deps, num_rows, num_cols, 3])
+    inputs = constant_op.constant(x_in, dtype=dtypes.float32)
+
+    with backprop.GradientTape() as outer_tape:
+      with backprop.GradientTape() as inner_tape:
+        outer_tape.watch([inputs])
+        inner_tape.watch([inputs])
+        expected_result = nn_ops.max_pool3d(
+            inputs, window_size, stride_size, padding
+        )
+      expected_first_grad = inner_tape.gradient(expected_result, [inputs])
+    expected_second_grad = outer_tape.gradient(expected_first_grad, [inputs])
+
+    if shard_spec == 'batch':
+      layout = self.batch_layout_5d
+    else:
+      layout = self.replicated_layout_5d
+
+    inputs = numpy_util.pack_numpy(inputs, layout)
+
+    @polymorphic_function.function()
+    def compute_gradients(inputs):
+      with backprop.GradientTape() as outer_tape:
+        with backprop.GradientTape() as inner_tape:
+          outer_tape.watch([inputs])
+          inner_tape.watch([inputs])
+          dtensor_result = nn_ops.max_pool3d(
+              inputs, window_size, stride_size, padding
+          )
+        dtensor_first_grad = inner_tape.gradient(dtensor_result, [inputs])
+      dtensor_second_grad = outer_tape.gradient(dtensor_first_grad[0], [inputs])
+      return dtensor_first_grad, dtensor_second_grad
+
+    dtensor_first_grad, dtensor_second_grad = compute_gradients(inputs)
+
+    self.assertDTensorEqual(expected_first_grad[0], layout,
+                            dtensor_first_grad[0])
+    self.assertDTensorEqual(expected_second_grad[0], layout,
+                            dtensor_second_grad[0])
+
+  @parameterized.named_parameters(
+      test_util.product([('Replicated', 'replicated'), ('Sharded', 'batch')], [(
+          'SamePadding',
+          'SAME',
+      ), (
+          'ValidPadding',
+          'VALID',
+      )]))
+  def testMaxPoolGradGrad(self, shard_spec, padding):
+    np.random.seed(123)
+    row_window_size = 3
+    col_window_size = 4
+    window_size = [1, row_window_size, col_window_size, 1]
+    stride_size = [1, row_window_size - 1, col_window_size - 1, 1]
+
+    num_rows = (row_window_size - 1) * 5 + 1
+    num_cols = (col_window_size - 1) * 7 + 1
+    x_in = np.random.normal(0.0, 1.0, 8 * num_rows * num_cols * 3).reshape(
+        [8, num_rows, num_cols, 3])
+    inputs = constant_op.constant(x_in, dtype=dtypes.float32)
+
+    with backprop.GradientTape() as outer_tape:
+      with backprop.GradientTape() as inner_tape:
+        outer_tape.watch([inputs])
+        inner_tape.watch([inputs])
+        expected_result = nn_ops.max_pool_v2(
+            inputs, window_size, stride_size, padding
+        )
+      expected_first_grad = inner_tape.gradient(expected_result, [inputs])
+    expected_second_grad = outer_tape.gradient(expected_first_grad, [inputs])
+
+    if shard_spec == 'batch':
+      layout = self.batch_layout_4d
+    else:
+      layout = self.replicated_layout_4d
+    inputs = numpy_util.pack_numpy(inputs, layout)
+
+    @polymorphic_function.function()
+    def compute_gradients(inputs):
+      with backprop.GradientTape() as outer_tape:
+        with backprop.GradientTape() as inner_tape:
+          outer_tape.watch([inputs])
+          inner_tape.watch([inputs])
+          dtensor_result = nn_ops.max_pool_v2(
+              inputs, window_size, stride_size, padding
+          )
+        dtensor_first_grad = inner_tape.gradient(dtensor_result, [inputs])
+      dtensor_second_grad = outer_tape.gradient(dtensor_first_grad[0], [inputs])
+      return dtensor_first_grad, dtensor_second_grad
+
+    dtensor_first_grad, dtensor_second_grad = compute_gradients(inputs)
+
+    self.assertDTensorEqual(expected_first_grad[0], layout,
+                            dtensor_first_grad[0])
+    self.assertDTensorEqual(expected_second_grad[0], layout,
+                            dtensor_second_grad[0])
+
+  @parameterized.named_parameters(('Sharded', 'sharded'),
+                                  ('Replicated', 'replicated'))
+  def testResizeNearestNeighborGrad(self, shard_spec):
+    np.random.seed(123)
+    grads = constant_op.constant(
+        np.random.normal(0.0, 1.0, 8 * 9 * 9).reshape([8, 9, 9, 1]),
+        dtype=dtypes.float32,
+    )
+    expected_result = gen_image_ops.resize_nearest_neighbor_grad(
+        grads=grads,
+        size=[3, 3],
+        align_corners=False,
+        half_pixel_centers=False,
+        name=None,
+    )
+
+    if shard_spec == 'sharded':
+      layout = self.batch_layout_4d
+    else:
+      layout = self.replicated_layout_4d
+
+    grads = numpy_util.pack_numpy(grads, layout)
+
+    got = gen_image_ops.resize_nearest_neighbor_grad(
+        grads=grads,
+        size=[3, 3],
+        align_corners=False,
+        half_pixel_centers=False,
+        name=None,
+    )
+
+    self.assertDTensorEqual(expected_result, layout, got)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/dtensor/python/tests/cache_test.py b/tensorflow/dtensor/python/tests/cache_test.py
new file mode 100644
index 00000000000000..d56dacb8c8f605
--- /dev/null
+++ b/tensorflow/dtensor/python/tests/cache_test.py
@@ -0,0 +1,330 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests DTensor device cache for compiled function computation.""" + +import gc +import numpy as np + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import d_variable +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import combinations +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_stateless_random_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + +# Convenient constants to use for tests. +_BATCH_DIM = "batch" +_MESH_DIM_X = "x" + +# Shorter notation. +Layout = layout_lib.Layout +Mesh = layout_lib.Mesh + + +def diff_dicts(dict1, dict2): + keys = set(dict1.keys()) | set(dict2.keys()) + return {key: dict1.get(key, 0) - dict2.get(key, 0) for key in keys} + + +class DTensorDeviceCacheTest(test_util.DTensorBaseTest): + + def setUp(self): + super(DTensorDeviceCacheTest, self).setUp() + device_ids = test_util.create_device_ids_array((2,)) + local_device_ids = np.ravel(device_ids).tolist() + mesh_dict = { + device: Mesh( + [_BATCH_DIM], + device_ids, + local_device_ids, + test_util.create_device_list((2,), device), + ) + for device in ("CPU", "GPU", "TPU") + } + self.mesh = self.configTestMesh(mesh_dict) + + def testBasic(self): + + @polymorphic_function.function + def func0(a): + return a + 1 + + @polymorphic_function.function + def func1(a): + return a + 2 + + c0 = api.copy_to_mesh( + constant_op.constant(1.0), Layout.replicated(self.mesh, rank=0) + ) + c1 = api.copy_to_mesh( + constant_op.constant([2.0, 3.0]), Layout.replicated(self.mesh, rank=1) + ) + c2 = api.copy_to_mesh( + constant_op.constant([4.0]), Layout.replicated(self.mesh, rank=1) + ) + c3 = api.copy_to_mesh( + constant_op.constant(1, dtype=dtypes.int32), + Layout.replicated(self.mesh, rank=0), + ) + + # c0 and c1 have different layouts. c1 and c2 have different shapes. + # c0 and c3 have different dtypes. + self.assertAllEqual(func0(c0), 2.0) + self.assertAllEqual(func0(c1), [3.0, 4.0]) + self.assertAllEqual(func0(c2), [5.0]) + self.assertAllEqual(func0(c3), 2) + + # func0 and func1 have different names. 
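+    # A function with a different name should get its own cache entry rather
+    # than reusing func0's compiled computation.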
+ self.assertAllEqual(func1(c0), 3.0) + + def testFunctionInputConstantFoldingCacheHits(self): + + @polymorphic_function.function + def add(a, b): + return a + b + + c0 = api.copy_to_mesh( + constant_op.constant(17.0), Layout.replicated(self.mesh, rank=0) + ) + c1 = api.copy_to_mesh( + constant_op.constant(21.0), Layout.replicated(self.mesh, rank=0) + ) + + stats1 = api._dtensor_device()._get_stats() + self.assertAllEqual(add(c0, c1), 38.0) + self.assertAllEqual(add(c0, c1), 38.0) + + # First call should miss and second should hit. + stats2 = api._dtensor_device()._get_stats() + diff = {key: stats2[key] - stats1[key] for key in stats1.keys()} + self.assertEqual(diff["function_manager.miss"], 1) + self.assertEqual(diff["function_manager.hit"], 1) + + def testFunctionInputConstantFoldingCacheMiss(self): + + @polymorphic_function.function + def add(a, b): + return a + b + + c0 = api.copy_to_mesh( + constant_op.constant(17.0), Layout.replicated(self.mesh, rank=0) + ) + c1 = api.copy_to_mesh( + constant_op.constant(21.0), Layout.replicated(self.mesh, rank=0) + ) + c2 = api.copy_to_mesh( + constant_op.constant(0.0), Layout.replicated(self.mesh, rank=0) + ) + + stats1 = api._dtensor_device()._get_stats() + # First call should log a cache miss. + self.assertAllEqual(add(c0, c1), 38.0) + + # Second call should also log a cache miss since second constant changed. + self.assertAllEqual(add(c0, c2), 17.0) + + # Third call should not log a cache miss since the same input as the prev. + self.assertAllEqual(add(c0, c2), 17.0) + + # Fourth call should log a cache miss since first input changed. + self.assertAllEqual(add(c1, c2), 21.0) + + stats2 = api._dtensor_device()._get_stats() + diff = {key: stats2[key] - stats1[key] for key in stats1.keys()} + self.assertEqual(diff["function_manager.miss"], 3) + self.assertEqual(diff["function_manager.hit"], 1) + + def testCacheWithRNG(self): + with api._dtensor_device()._default_layout( + Layout.replicated(self.mesh, rank=1)): + v0 = gen_stateless_random_ops.stateless_random_normal( + shape=[1], seed=[1, 2] + ) + + with api._dtensor_device()._default_layout( + Layout.replicated(self.mesh, rank=1)): + v1 = gen_stateless_random_ops.stateless_random_normal( + shape=[1], seed=[1, 2] + ) + v2 = gen_stateless_random_ops.stateless_random_normal( + shape=[2], seed=[1, 2] + ) + v3 = gen_stateless_random_ops.stateless_random_normal( + shape=[1], seed=[3, 4] + ) + + # v0 and v1 have same layouts. + self.assertAllEqual(v0, v1) + api.check_layout(v0, Layout.replicated(self.mesh, rank=1)) + api.check_layout(v1, Layout.replicated(self.mesh, rank=1)) + # v1 and v2 have different shapes. + self.assertNotEqual(v1.shape, v2.shape) + # v1 and v3 have different seeds. 
+ self.assertNotEqual(v1.numpy(), v3.numpy()) + + def testCacheWithVariable(self): + c0 = api.copy_to_mesh( + constant_op.constant(1.0), Layout.replicated(self.mesh, rank=0) + ) + c1 = api.copy_to_mesh( + constant_op.constant([2.0, 3.0]), Layout.replicated(self.mesh, rank=1) + ) + a = constant_op.constant([4.0]) + b = constant_op.constant([5.0]) + c2 = api.pack( + [a, b], layout=Layout.batch_sharded(self.mesh, _BATCH_DIM, rank=1) + ) + + v0 = d_variable.DVariable(c0) + v1 = d_variable.DVariable(c1) + v2 = d_variable.DVariable(c2) + + self.assertAllEqual(v0.read_value(), 1.0) + self.assertAllEqual(v1.read_value(), [2.0, 3.0]) + unpacked_tensor = api.unpack(v2.read_value()) + self.assertAllClose([4.0], unpacked_tensor[0]) + self.assertAllClose([5.0], unpacked_tensor[1]) + + @combinations.generate( + combinations.combine(size=[16, 40], same_value=[True, False]) + ) + def testManyFunctions(self, size, same_value): + r = range(100) + + values = [np.reshape(r[i : i + size], (4, size // 4)) for i in range(10)] + c_layout = Layout.replicated(self.mesh, rank=2) + values = [constant_op.constant(v, dtype=dtypes.float32) for v in values] + c0 = [api.copy_to_mesh(v, c_layout) for v in values] + + c0 = [c0[0 if same_value else i] for i in range(10)] + e0 = [values[0 if same_value else i] for i in range(10)] + stats1 = api._dtensor_device()._get_stats() + + for i in range(10): + # Use a special to ensure no conflicts with otherwise used names. + @polymorphic_function.function + def fn_31415926(c): + return math_ops.reduce_sum(c) + + self.assertAllEqual(fn_31415926(c0[i]).numpy(), np.sum(e0[i])) + + del fn_31415926 + gc.collect() + + stats2 = api._dtensor_device()._get_stats() + diff = diff_dicts(stats2, stats1) + self.assertEqual(diff["function_manager.size"], 0) + self.assertEqual(diff["kernel_cache.size"], 0) + self.assertEqual(diff["device_cache.size"], 0) + + @combinations.generate( + combinations.combine(size=[16, 40], same_value=[True, False]) + ) + def testManyEagerOps(self, size, same_value): + if self.mesh.device_type() != "TPU": + # For the CPU/GPU mesh, we have a shortcut that doesn't go through the + # MLIR, but run the eager op locally and broadcast to all the devices. + expected_cache_diff = 0 + expected_kernel_cache = 0 + expected_device_cache = 0 + expected_eager_pure_hit = 10 + else: + # TODO(b/287529295): Remove this branch after the TPU issue is fixed. + expected_device_cache = 0 + expected_eager_pure_hit = 0 + if same_value: + expected_cache_diff = 1 + expected_kernel_cache = 2 + else: + if size >= 20: + expected_cache_diff = 1 + expected_kernel_cache = 2 + else: + expected_cache_diff = 2 + expected_kernel_cache = 4 + + r = range(100) + c_layout = Layout.replicated(self.mesh, rank=2) + values = [np.reshape(r[i : i + size], (4, size // 4)) for i in range(10)] + values = [constant_op.constant(v, dtype=dtypes.float32) for v in values] + c0 = [api.copy_to_mesh(v, c_layout) for v in values] + + c0 = [c0[0 if same_value else i] for i in range(10)] + e0 = [values[0 if same_value else i] for i in range(10)] + + stats1 = api._dtensor_device()._get_stats() + + for i in range(10): + self.assertAllEqual(array_ops.identity(c0[i]).numpy(), e0[i]) + + gc.collect() + + stats2 = api._dtensor_device()._get_stats() + diff = diff_dicts(stats2, stats1) + + if same_value: + self.assertEqual(diff["function_manager.size"], expected_cache_diff) + self.assertEqual( + diff["eager_pure_optimization.hit"], expected_eager_pure_hit + ) + # TFRT doesn't use eager cache. 
+ if not test_util.is_tfrt_enabled(): + self.assertEqual(diff["kernel_cache.size"], expected_kernel_cache) + self.assertEqual(diff["device_cache.size"], expected_device_cache) + else: + # FIXME(feyu): Update these when the leaks are fixed. + if size >= 20: + self.assertEqual(diff["function_manager.size"], expected_cache_diff) + self.assertEqual( + diff["eager_pure_optimization.hit"], expected_eager_pure_hit + ) + # TFRT doesn't use eager cache. + if not test_util.is_tfrt_enabled(): + self.assertEqual(diff["kernel_cache.size"], expected_kernel_cache) + self.assertEqual(diff["device_cache.size"], expected_device_cache) + else: + self.assertEqual(diff["function_manager.size"], expected_cache_diff) + self.assertEqual( + diff["eager_pure_optimization.hit"], expected_eager_pure_hit + ) + # TFRT doesn't use eager cache. + if not test_util.is_tfrt_enabled(): + self.assertEqual(diff["kernel_cache.size"], expected_kernel_cache) + self.assertEqual(diff["device_cache.size"], expected_device_cache) + + def testManyEagerOpsVaryInput(self): + c_layout = Layout.replicated(self.mesh, rank=10) + + c0 = constant_op.constant( + [[[[[[[[[[0, 1, 2, 3], [4, 5, 6, 7]]]]]]]]]], dtype=dtypes.float32 + ) + e0 = c0.numpy() + c0 = api.copy_to_mesh(c0, c_layout) + + for ax in range(10): + self.assertAllEqual( + math_ops.reduce_sum(c0, axis=ax).numpy(), np.sum(e0, axis=ax) + ) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/dtensor/python/tests/conv_test.py b/tensorflow/dtensor/python/tests/conv_test.py new file mode 100644 index 00000000000000..25cab09e1096ac --- /dev/null +++ b/tensorflow/dtensor/python/tests/conv_test.py @@ -0,0 +1,350 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for executing ops needed to implement image model.""" + +from absl.testing import parameterized +import numpy as np + +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.eager import backprop +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import special_math_ops +from tensorflow.python.platform import test + + +UNSHARDED = layout_lib.UNSHARDED +Mesh = layout_lib.Mesh +Layout = layout_lib.Layout + +BATCH_DIM = 'batch' +DEPTH_DIM = 'depth' +HEIGHT_DIM = 'height' +WIDTH_DIM = 'width' +BATCH_SIZE = 4 +DEPTH = 8 +HEIGHT = 12 +WIDTH = 12 +CHANNEL_IN = 1 +CHANNEL_OUT = 3 + + +class ConvOpTest(test_util.DTensorBaseTest): + + def setUp(self): + super().setUp() + + global_ids = test_util.create_device_ids_array((2, 2, 2)) + local_ids = np.ravel(global_ids).tolist() + mesh_dict = {} + for device in ('CPU', 'GPU', 'TPU'): + mesh_dict[device] = Mesh( + [BATCH_DIM, HEIGHT_DIM, WIDTH_DIM], + global_ids, + local_ids, + test_util.create_device_list((2, 2, 2), device), + ) + + self.mesh = self.configTestMesh(mesh_dict) + + self.replicated_2d = Layout.replicated(self.mesh, 2) + self.batch_sharded_2d = Layout.batch_sharded(self.mesh, BATCH_DIM, 2) + + @parameterized.named_parameters( + test_util.product( + *[ + [ + ( + 'Conv2D', + nn_ops.conv2d_v2, + (BATCH_SIZE, HEIGHT, WIDTH, CHANNEL_IN), + (2, 2, CHANNEL_IN, CHANNEL_OUT), + 'bhwc,xy->by', + [1, 2, 1, 1], + ), + ( + 'Conv3D', + nn_ops.conv3d_v2, + (BATCH_SIZE, DEPTH, HEIGHT, WIDTH, CHANNEL_IN), + (2, 2, 2, CHANNEL_IN, CHANNEL_OUT), + 'bdhwc,xy->by', + [1, 1, 2, 1, 1], + ), + ], + [ + ('Eager', True), + ('Graph', False), + ], + [ + ('ReplicatedInput', 'replicated'), + ('BatchShardedInput', 'batch_sharded'), + ], + [ + ('ValidPadding', 'VALID'), + ('SamePadding', 'SAME'), + ], + ] + ) + ) + def testConvFollowedByEinsum(self, conv_op, input_size, kernel_size, + einsum_eq, strides, eager_mode, input_sharding, + padding): + x_in = constant_op.constant( + np.random.random(size=input_size), dtype=dtypes.float32 + ) + kernel_in = constant_op.constant( + np.random.random(size=kernel_size), dtype=dtypes.float32 + ) + weight = constant_op.constant( + np.random.random(size=(2, 2)), dtype=dtypes.float32 + ) + + def conv_fn(inputs, img_kernel, layer_weights): + output = conv_op(inputs, img_kernel, strides=strides, padding=padding) + output = special_math_ops.einsum(einsum_eq, output, layer_weights) + return output + + if not eager_mode: + conv_fn = polymorphic_function.function(conv_fn) + + golden_result = conv_fn(x_in, kernel_in, weight) + + if input_sharding == 'replicated': + input_layout = Layout.replicated(self.mesh, len(input_size)) + output_layout = self.replicated_2d + elif input_sharding == 'batch_sharded': + input_layout = Layout.batch_sharded(self.mesh, BATCH_DIM, len(input_size)) + output_layout = self.batch_sharded_2d + + kernel_layout = Layout.replicated(self.mesh, len(kernel_size)) + + d_x_in = numpy_util.pack_numpy(x_in, input_layout) + d_kernel_in = numpy_util.pack_numpy(kernel_in, kernel_layout) + d_weight = numpy_util.pack_numpy(weight, self.replicated_2d) + d_result = conv_fn(d_x_in, d_kernel_in, d_weight) + + 
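+    # The sharded DTensor execution is expected to reproduce the single-device
+    # golden result under the chosen output layout.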
self.assertDTensorEqual(golden_result, output_layout, d_result) + + @parameterized.named_parameters( + test_util.product( + *[ + [ + ( + 'Conv2D', + nn_ops.conv2d_v2, + (BATCH_SIZE, HEIGHT, WIDTH, CHANNEL_IN), + (2, 2, CHANNEL_IN, CHANNEL_OUT), + 'bhwc,xy->by', + [1, 1, 1, 1], + ), + ( + 'Conv3D', + nn_ops.conv3d_v2, + (BATCH_SIZE, DEPTH, HEIGHT, WIDTH, CHANNEL_IN), + (2, 2, 2, CHANNEL_IN, CHANNEL_OUT), + 'bdhwc,xy->by', + [1, 1, 1, 1, 1], + ), + ], + [ + ('ReplicatedInput', 'replicated'), + ('BatchShardedInput', 'batch_sharded'), + ], + [ + ('ValidPadding', 'VALID'), + ('SamePadding', 'SAME'), + ], + ] + ) + ) + def testConvFollowedByEinsumWithGradient(self, conv_op, input_size, + kernel_size, einsum_eq, strides, + input_sharding, padding): + x_in = constant_op.constant( + np.random.random(size=input_size), dtype=dtypes.float32 + ) + kernel_in = constant_op.constant( + np.random.random(size=kernel_size), dtype=dtypes.float32 + ) + weight = constant_op.constant( + np.random.random(size=(2, 2)), dtype=dtypes.float32 + ) + + @polymorphic_function.function + def conv_fn(inputs, img_kernel, layer_weights): + with backprop.GradientTape() as tape: + tape.watch([inputs, img_kernel, layer_weights]) + output = conv_op(inputs, img_kernel, strides=strides, padding=padding) + output = special_math_ops.einsum(einsum_eq, output, layer_weights) + + inputs_grad, kernel_grad, weight_grad = tape.gradient( + output, [inputs, img_kernel, layer_weights]) + return output, inputs_grad, kernel_grad, weight_grad + + result, inputs_grad, kernel_grad, weight_grad = conv_fn( + x_in, kernel_in, weight) + + if input_sharding == 'replicated': + input_layout = Layout.replicated(self.mesh, len(input_size)) + output_layout = self.replicated_2d + elif input_sharding == 'batch_sharded': + input_layout = Layout.batch_sharded(self.mesh, BATCH_DIM, len(input_size)) + output_layout = self.batch_sharded_2d + + kernel_layout = Layout.replicated(self.mesh, len(kernel_size)) + + d_x_in = numpy_util.pack_numpy(x_in, input_layout) + d_kernel_in = numpy_util.pack_numpy(kernel_in, kernel_layout) + d_weight = numpy_util.pack_numpy(weight, self.replicated_2d) + d_result, d_inputs_grad, d_kernel_grad, d_weight_grad = conv_fn( + d_x_in, d_kernel_in, d_weight) + + self.assertDTensorEqual(result, output_layout, d_result) + # TODO(b/208700444): layout of input grads should match layout of input. + self.assertDTensorEqual( + inputs_grad, + Layout.replicated(self.mesh, len(input_size)), + d_inputs_grad, + ) + self.assertDTensorEqual(kernel_grad, kernel_layout, d_kernel_grad) + self.assertDTensorEqual(weight_grad, self.replicated_2d, d_weight_grad) + + +SPATIALLY_PARTITIONED_CONV_TEST_CASES = [ + [ + ('Case1', (BATCH_SIZE, 8, 16, CHANNEL_IN), (3, 5, CHANNEL_IN, + CHANNEL_OUT)), + ('Case2', (BATCH_SIZE, 8, 128, CHANNEL_IN), (3, 9, CHANNEL_IN, + CHANNEL_OUT)), + ], + [ + ('ValidPadding', 'VALID'), + ('SamePadding', 'SAME'), + ], + [ + ('Batch_1d_2x4', [BATCH_DIM, UNSHARDED, WIDTH_DIM, UNSHARDED], (2, 4)), + ('2d_2x4', [UNSHARDED, HEIGHT_DIM, WIDTH_DIM, UNSHARDED], (2, 4)), + ('Batch_2d_2x2x2', [BATCH_DIM, HEIGHT_DIM, WIDTH_DIM, + UNSHARDED], (2, 2, 2)), + ], +] + + +class SpatiallyPartitionedConvOpTest(test_util.DTensorBaseTest): + + def setUp(self): + super().setUp() + + # TODO(b/261485237): Enable CPU testing once CollectivePermute is supported + # on CPU's. 
+ if not test_util.is_tpu_present(): + self.skipTest('This test only runs on TPUs.') + + def _create_mesh(self, mesh_dims, topology): + global_ids = test_util.create_device_ids_array(topology) + local_ids = np.ravel(global_ids).tolist() + mesh_dict = {} + for device in ('CPU', 'GPU', 'TPU'): + mesh_dict[device] = Mesh( + mesh_dims, + global_ids, + local_ids, + test_util.create_device_list(topology, device), + ) + + return self.configTestMesh(mesh_dict) + + @parameterized.named_parameters( + test_util.product(*SPATIALLY_PARTITIONED_CONV_TEST_CASES)) + def testConv(self, input_shape, kernel_shape, padding, sharding_specs, + topology): + mesh_dims = [spec for spec in sharding_specs if spec != UNSHARDED] + mesh = self._create_mesh(mesh_dims, topology) + + x_in = constant_op.constant( + np.random.random(size=input_shape), dtype=dtypes.float32 + ) + kernel_in = constant_op.constant( + np.random.random(size=kernel_shape), dtype=dtypes.float32 + ) + + expected_output = nn_ops.conv2d_v2( + x_in, kernel_in, strides=[1, 1, 1, 1], padding=padding + ) + + input_layout = Layout(sharding_specs, mesh) + kernel_layout = Layout.replicated(mesh, 4) + + d_x_in = numpy_util.pack_numpy(x_in, input_layout) + d_kernel_in = numpy_util.pack_numpy(kernel_in, kernel_layout) + d_output = nn_ops.conv2d_v2( + d_x_in, d_kernel_in, strides=[1, 1, 1, 1], padding=padding + ) + + self.assertDTensorEqual(expected_output, input_layout, d_output) + + @parameterized.named_parameters( + test_util.product(*SPATIALLY_PARTITIONED_CONV_TEST_CASES)) + def testConvWithGradient(self, input_shape, kernel_shape, padding, + sharding_specs, topology): + # TODO(b/208700444): add support for SPMD expansion of spatially partitioned + # conv backprop. + self.skipTest( + 'b/208700444: Spatially partitioned conv backprop not implemented.') + + mesh_dims = [spec for spec in sharding_specs if spec != UNSHARDED] + mesh = self._create_mesh(mesh_dims, topology) + + x_in = constant_op.constant( + np.random.random(size=input_shape), dtype=dtypes.float32 + ) + kernel_in = constant_op.constant( + np.random.random(size=kernel_shape), dtype=dtypes.float32 + ) + + @polymorphic_function.function + def conv_fn(inputs, img_kernel, padding): + with backprop.GradientTape() as tape: + tape.watch([inputs, img_kernel]) + output = nn_ops.conv2d_v2( + inputs, img_kernel, strides=[1, 1, 1, 1], padding=padding + ) + inputs_grad, kernel_grad = tape.gradient(output, [inputs, img_kernel]) + return output, inputs_grad, kernel_grad + + expected_output, expected_inputs_grad, expected_kernel_grad = conv_fn( + x_in, kernel_in, padding) + + input_layout = Layout(sharding_specs, mesh) + kernel_layout = Layout.replicated(mesh, 4) + + d_x_in = numpy_util.pack_numpy(x_in, input_layout) + d_kernel_in = numpy_util.pack_numpy(kernel_in, kernel_layout) + + d_output, d_inputs_grad, d_kernel_grad = conv_fn(d_x_in, d_kernel_in, + padding) + + self.assertDTensorEqual(expected_output, input_layout, d_output) + self.assertDTensorEqual(expected_inputs_grad, input_layout, d_inputs_grad) + self.assertDTensorEqual(expected_kernel_grad, kernel_layout, d_kernel_grad) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/dtensor/python/tests/mnist_test.py b/tensorflow/dtensor/python/tests/mnist_test.py new file mode 100644 index 00000000000000..5fd08f18414ef0 --- /dev/null +++ b/tensorflow/dtensor/python/tests/mnist_test.py @@ -0,0 +1,197 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""DTensor MNIST test.""" + +from absl.testing import parameterized + +import numpy as np + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import d_variable +from tensorflow.dtensor.python import input_util +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import backprop +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import array_ops_stack +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import stateless_random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +_BATCH_DIM = 'batch' +_DEVICE_IDS = test_util.create_device_ids_array((2,)) +_ONE_D_MESH = layout_lib.Mesh( + [_BATCH_DIM], + _DEVICE_IDS, + np.ravel(_DEVICE_IDS).tolist(), + test_util.create_device_list((2,), 'CPU'), +) +_ONE_D_TPU_MESH = layout_lib.Mesh( + [_BATCH_DIM], + _DEVICE_IDS, + np.ravel(_DEVICE_IDS).tolist(), + test_util.create_device_list((2,), 'TPU'), +) +_BATCH_SIZE = 1024 +_STEPS = 5 +_LR = 1e-3 +_ATOL = 1 # absolute error becomes large as gradients approach zero. 
+_RTOL = 1e-3 +Layout = layout_lib.Layout + + +def mnist_fake_dataset(): + imgs = [] + labels = [] + for i in range(_STEPS * _BATCH_SIZE): + img = stateless_random_ops.stateless_random_uniform( + shape=(28, 28, 1), + seed=[1, i], + minval=0, + maxval=256, + dtype=dtypes.float32, + ) + imgs.append(img) + label = stateless_random_ops.stateless_random_uniform( + shape=(1,), seed=[2, i], minval=0, maxval=10, dtype=dtypes.int64 + ) + labels.append(label) + + return dataset_ops.DatasetV2.from_tensor_slices( + (array_ops_stack.stack(imgs), array_ops_stack.stack(labels)) + ) + + +def _run_step(inputs, w, b, k): + with backprop.GradientTape() as g: + g.watch([w, b]) + logits = nn_ops.conv2d_v2(inputs, k, strides=[1, 1, 1, 1], padding='SAME') + logits = array_ops.reshape(logits, [logits.shape[0], -1]) + logits = math_ops.matmul(logits, w) + logits = logits + b + loss = math_ops.reduce_sum(logits, axis=[0, 1]) + gw, gb = g.gradient(loss, [w, b]) + for v, v_grad in zip([w, b], [gw, gb]): + v.assign_sub(_LR * v_grad) + return gw, gb, loss + + +class DTensorMNISTTest(test_util.DTensorBaseTest): + + def setUp(self): + super(DTensorMNISTTest, self).setUp() + + global_ids = test_util.create_device_ids_array((2,)) + local_ids = np.ravel(global_ids).tolist() + mesh_dict = { + device: layout_lib.Mesh( + [_BATCH_DIM], + global_ids, + local_ids, + test_util.create_device_list((2,), device), + ) + for device in ['TPU', 'GPU', 'CPU'] + } + self.mesh = self.configTestMesh(mesh_dict) + + def init_var(self, mesh): + # Initialize TF randon normal variables(without using DTensor). + w_initializer = stateless_random_ops.stateless_random_normal( + shape=[28 * 28, 10], seed=[0, 1] + ) + b_initializer = stateless_random_ops.stateless_random_normal( + shape=[10], seed=[1, 2] + ) + # A filter with 3x3 shape, 1 input channel and 1 output channel. + k_initializer = stateless_random_ops.stateless_random_normal( + [3, 3, 1, 1], seed=[2, 3] + ) + + n_w = variables.Variable(w_initializer) + n_b = variables.Variable(b_initializer) + n_k = variables.Variable(k_initializer) + + # Initialize DTensor variables. + w_initializer_on_mesh = api.copy_to_mesh( + w_initializer, Layout.replicated(mesh, 2) + ) + b_initializer_on_mesh = api.copy_to_mesh( + b_initializer, Layout.replicated(mesh, rank=1) + ) + k_initializer_on_mesh = api.copy_to_mesh( + k_initializer, Layout.replicated(mesh, rank=4) + ) + + w = d_variable.DVariable(w_initializer_on_mesh) + b = d_variable.DVariable(b_initializer_on_mesh) + k = d_variable.DVariable(k_initializer_on_mesh) + + return (n_w, n_b, n_k), (w, b, k) + + @parameterized.named_parameters(('Eager', False), ('Function', True)) + def testMnist(self, on_function): + mnist_dataset = mnist_fake_dataset() + + (n_w, n_b, n_k), (w, b, k) = self.init_var(self.mesh) + + n_dataset = mnist_dataset.batch(_BATCH_SIZE, drop_remainder=True) + n_iter = iter(n_dataset) + + input_layout = Layout.batch_sharded(self.mesh, _BATCH_DIM, rank=4) + label_layout = Layout.batch_sharded(self.mesh, _BATCH_DIM, rank=2) + dtensor_dataset = input_util.DTensorDataset( + dataset=mnist_dataset, + global_batch_size=_BATCH_SIZE, + mesh=self.mesh, + layouts=(input_layout, label_layout), + batch_dim=_BATCH_DIM, + ) + dtensor_iter = iter(dtensor_dataset) + + step_fn = ( + polymorphic_function.function(_run_step) if on_function else _run_step + ) + + # Training loop. + for _ in range(_STEPS): + # Normal run without DTensor. 
+ n_input, _ = next(n_iter) + g_nw, g_nb, n_loss = step_fn(n_input, n_w, n_b, n_k) + + # DTensor Run + dtensor_input, _ = next(dtensor_iter) + with ops.device_v2(api.device_name()): + gw, gb, loss = step_fn(dtensor_input, w, b, k) + + loss_unpack = api.unpack(loss) + self.assertAllEqual(loss_unpack[0], loss_unpack[1]) + + self.assertAllClose(n_loss, loss, atol=_ATOL, rtol=_RTOL) + self.assertAllClose(g_nw, gw, atol=_ATOL, rtol=_RTOL) + self.assertAllClose(g_nb, gb, atol=_ATOL, rtol=_RTOL) + self.assertAllClose(n_w, w, atol=_ATOL, rtol=_RTOL) + self.assertAllClose(n_b, b, atol=_ATOL, rtol=_RTOL) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/dtensor/python/tests/multi_client_input_util_test.py b/tensorflow/dtensor/python/tests/multi_client_input_util_test.py new file mode 100644 index 00000000000000..3538cb4e09b4e6 --- /dev/null +++ b/tensorflow/dtensor/python/tests/multi_client_input_util_test.py @@ -0,0 +1,548 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Multi-client tests for input_util.""" + +import os +from typing import Any, List, Mapping, Optional, Tuple + +from absl import logging +from absl.testing import parameterized +import numpy as np + +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.dtensor.python import accelerator_util +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import config +from tensorflow.dtensor.python import input_util +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import mesh_util +from tensorflow.dtensor.python.tests import multi_client_test_util +from tensorflow.dtensor.python.tests import test_backend_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.data.experimental.service import server_lib +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.eager import context +from tensorflow.python.framework import config as tf_config +from tensorflow.python.framework import device_spec +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import tf_record +from tensorflow.python.ops import array_ops_stack +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_config +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import stateless_random_ops +from tensorflow.python.platform import test + + +mp_context = test_backend_util.get_mp_context() + +# Multi-client test constants. +JOB_NAME = 'worker' +TF_DATA_SERVICE_JOB_NAME = 'dtensor_tf_data' +NUM_CLIENTS = 4 +NUM_DEVICES_PER_CLIENT = 4 + +# Mesh constants. 
+MESH_DIM_BATCH = 'batch' +MESH_DIM_HEIGHT = 'height' +MESH_DIM_WIDTH = 'width' + +# Data constants. +IMG_HEIGHT = 8 +IMG_WIDTH = 8 +IMG_CHANNELS = 3 + +UNSHARDED = layout_lib.UNSHARDED +Mesh = layout_lib.Mesh +Layout = layout_lib.Layout + + +def redirect_output(file_name): + # Redirect stderr/stdout to undeclared outputs on sponge. + artifact_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', '') + if artifact_dir: + with open(os.path.join(artifact_dir, file_name), 'wb') as fp: + os.dup2(fp.fileno(), 1) + os.dup2(fp.fileno(), 2) + + +def create_dispatcher(test_name, worker_addresses, port, pipe=None): + dispatcher = server_lib.DispatchServer( + config=server_lib.DispatcherConfig( + port=port, protocol='grpc', worker_addresses=worker_addresses + ) + ) + dispatcher.start() + if pipe is None: + # Dispatcher is not within subprocess, so do not block. + return dispatcher, dispatcher._address + else: + redirect_output(f'test-{test_name}-dispatcher.log') + pipe.send(dispatcher._address) + signal = pipe.recv() # blocks until a 'stop' signal is received + if signal == 'stop': + dispatcher._stop() + pipe.send('stopped') + else: + raise ValueError('Got unknown signal %s' % signal) + + +def create_worker(test_name, dispatcher_address, port=None, pipe=None): + worker = server_lib.WorkerServer( + config=server_lib.WorkerConfig( + port=port, dispatcher_address=dispatcher_address, protocol='grpc' + ) + ) + worker.start() + if pipe is None: + # Worker is not within subprocess, so do not block. + return worker, worker._address + else: + redirect_output(f'test-{test_name}-worker.log') + pipe.send(worker._address) + signal = pipe.recv() # blocks until a 'stop' signal is received + if signal == 'stop': + worker._stop() + pipe.send('stopped') + else: + raise ValueError('Got unknown signal %s' % signal) + + +class TFDataServiceCluster: + """tf.data service cluster with dispatcher and workers as subprocesses. + + To run the cluster in co-located mode, set `num_workers` to 0 and create the + tf.data service workers manually in each client process. 
+  """
+
+  def __init__(self,
+               test_name,
+               num_workers,
+               worker_ports=None,
+               worker_addresses=None):
+    self._test_name = test_name
+    self._num_workers = num_workers
+    self._start_dispatcher(worker_addresses)
+    self._start_workers(worker_ports)
+
+  def _start_dispatcher(self, worker_addresses, port=0):
+    self._pipe_to_dispatcher, dispatcher_pipe = mp_context.Pipe(True)
+    logging.info(
+        'Starting remote dispatcher on port %d with worker addresses: %s', port,
+        worker_addresses)
+    self._dispatcher_process = mp_context.Process(
+        target=create_dispatcher,
+        args=(self._test_name, worker_addresses, port, dispatcher_pipe),
+    )
+    self._dispatcher_process.start()
+    self._dispatcher_address = self._pipe_to_dispatcher.recv()
+
+  def dispatcher_address(self):
+    return self._dispatcher_address
+
+  def _start_workers(self, worker_ports=None):
+    self._workers = []
+    self._worker_addresses = []
+    self._worker_pipes = []
+    for idx in range(self._num_workers):
+      port = worker_ports[idx] if worker_ports else None
+      self._start_worker(port)
+
+  def _start_worker(self, port=None):
+    pipe_to_worker, worker_pipe = mp_context.Pipe(True)
+    logging.info(
+        'Starting remote worker on port %d with dispatcher address: %s', port,
+        self._dispatcher_address)
+    worker_process = mp_context.Process(
+        target=create_worker,
+        args=(self._test_name, self._dispatcher_address, port, worker_pipe),
+    )
+    worker_process.start()
+    worker_address = pipe_to_worker.recv()
+    self._workers.append(worker_process)
+    self._worker_addresses.append(worker_address)
+    self._worker_pipes.append(pipe_to_worker)
+
+  def worker_addresses(self):
+    return self._worker_addresses
+
+  def stop(self):
+    # Segfault logs may still be printed because clean exit of child processes
+    # is not always possible. This will not affect the outcome of the test.
+    logging.info('Will try to stop TFDataServiceCluster!')
+
+    for idx in range(self._num_workers):
+      address = self._worker_addresses[idx]
+      pipe_to_worker = self._worker_pipes[idx]
+      logging.info('Stopping worker %s...', address)
+      pipe_to_worker.send('stop')
+      if pipe_to_worker.poll(2):
+        if pipe_to_worker.recv() == 'stopped':
+          logging.info('Successfully stopped worker %s', address)
+      self._workers[idx].terminate()
+
+    logging.info('Stopping dispatcher...')
+    self._pipe_to_dispatcher.send('stop')
+    if self._pipe_to_dispatcher.poll(2):
+      if self._pipe_to_dispatcher.recv() == 'stopped':
+        logging.info('Successfully stopped dispatcher')
+    self._dispatcher_process.terminate()
+
+
+def setup_local_devices(num_devices):
+  physical_cpus = tf_config.list_physical_devices('CPU')
+  tf_config.set_logical_device_configuration(
+      physical_cpus[0],
+      [context.LogicalDeviceConfiguration() for _ in range(num_devices)],
+  )
+
+
+def setup_client(client_id: int, test_name: str, env: Mapping[str, str],
+                 num_local_devices: int):
+  """Set up a DTensor client for use in multi-client tests.
+
+  Args:
+    client_id: the index of the client.
+    test_name: the name of the test under which this client is running, used to
+      identify the log file artifact containing the test output.
+    env: a dictionary of environment variables to update.
+    num_local_devices: number of local devices to set up.
+  """
+  # Redirect client's stderr/stdout to undeclared outputs on sponge.
+  redirect_output(f'test-{test_name}-process-{client_id}.log')
+
+  # Update any specified environment variables.
+  for var, val in env.items():
+    os.environ[var] = val
+
+  # Set up local devices.
+ setup_local_devices(num_local_devices) + + # Set up DTensor cluster and enable collectives. + accelerator_util.initialize_accelerator_system() + + +def run_client( + client_id: int, + test_name: str, + env: Mapping[str, str], + num_local_devices: int, + dispatcher_address: str, + worker_port: int, + batch_size: int, + dataset_paths: List[str], + mesh: Mesh, + batch_dim: Optional[str], + layouts: Tuple[Layout, Layout], +) -> List[Tuple[Any, Any]]: + # Co-located tf.data service mode. It is important to hold the worker object + # until the end otherwise it will get garbage collected. + worker, worker_address = create_worker( # pylint: disable=unused-variable + test_name, dispatcher_address, port=worker_port) + logging.info( + 'tf.data service worker running at %s', + worker_address, + ) + + setup_client(client_id, test_name, env, num_local_devices) + + def decode_fn(record_bytes): + decoded = parsing_ops.parse_single_example_v2( + serialized=record_bytes, + features={ + 'idx': parsing_config.FixedLenFeature([], dtype=dtypes.int64), + 'elem': parsing_config.FixedLenFeature([], dtype=dtypes.string), + }, + ) + parsed_elem = gen_parsing_ops.parse_tensor(decoded['elem'], dtypes.int32) + elem = check_ops.ensure_shape( + parsed_elem, [IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS] + ) + return decoded['idx'], elem + + dataset = dataset_ops.DatasetV2.from_tensor_slices(dataset_paths) + dataset = dataset.interleave(readers.TFRecordDatasetV2) + dataset = dataset.map(decode_fn) + + tf_data_service_config = input_util.TFDataServiceConfig( + dispatcher_address=dispatcher_address, job_name=TF_DATA_SERVICE_JOB_NAME + ) + d_dataset = input_util.DTensorDataset( + dataset=dataset, + global_batch_size=batch_size, + mesh=mesh, + layouts=layouts, + batch_dim=batch_dim, + tf_data_service_config=tf_data_service_config, + ) + + # Subprocesses cannot return a sharded DTensor as it triggers a copy and + # copying non-replicated DTensors is not supported. So instead we unpack it + # and return the component tensors. + ret = [] + for batch_idx, elem in d_dataset: + n_batch_idx = api.unpack(batch_idx) + n_elem = api.unpack(elem) + ret.append((n_batch_idx, n_elem)) + return ret + + +class MultiClientDTensorDatasetTest(test_util.DTensorBaseTest): + + def setUp(self): + super().setUp() + + logging.info('Check per client log in Test artifacts.') + + self.server_ports = [ + multi_client_test_util.pick_unused_port() for _ in range(NUM_CLIENTS) + ] + + self.worker_ports = [ + multi_client_test_util.pick_unused_port() for _ in range(NUM_CLIENTS) + ] + worker_addresses = [f'localhost:{port}' for port in self.worker_ports] + self.cluster = TFDataServiceCluster( + test_name=self._testMethodName, + num_workers=0, # Co-located mode. 
+ worker_addresses=worker_addresses) + + def tearDown(self): + super().tearDown() + self.cluster.stop() + + def write_dataset(self, dataset, num_files, num_elems): + """Writes a dataset_ops.DatasetV2 to multiple files.""" + dataset_paths = [] + dataset_iter = iter(dataset) + + for file_idx in range(num_files): + dataset_path = os.path.join(self.get_temp_dir(), + f'dataset-{file_idx}.tfrecords') + dataset_paths.append(dataset_path) + with tf_record.TFRecordWriter(dataset_path) as writer: + for _ in range(num_elems // num_files): + idx, elem = next(dataset_iter) + elem_bytes = example_pb2.Example( + features=feature_pb2.Features( + feature={ + 'idx': feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[idx]) + ), + 'elem': feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=[io_ops.serialize_tensor(elem).numpy()] + ) + ), + } + ) + ).SerializeToString() + writer.write(elem_bytes) + + return dataset_paths + + @parameterized.product( + ( + { + # batch=4 x height=2 x width=2 + # 1 replica per client. + 'mesh_dims': [(MESH_DIM_BATCH, 4), + (MESH_DIM_HEIGHT, 2), + (MESH_DIM_WIDTH, 2)], + }, { + # batch=4 x height=2 x width=2 (transposed) + # 1 replica per client with reordered local partitions. + 'mesh_dims': [(MESH_DIM_BATCH, 4), + (MESH_DIM_WIDTH, 2), + (MESH_DIM_HEIGHT, 2)], + }, { + # batch=8 x height=2 x width=1 + # 2 replicas per client. + 'mesh_dims': [(MESH_DIM_BATCH, 8), + (MESH_DIM_HEIGHT, 2), + (MESH_DIM_WIDTH, 1)], + }, { + # batch=8 x height=2 x width=1 (transposed) + # 2 replicas per client with reordered partitions. + 'mesh_dims': [(MESH_DIM_BATCH, 8), + (MESH_DIM_WIDTH, 1), + (MESH_DIM_HEIGHT, 2)], + }, { + # batch=2 x height=4 x width=2 + # 1 replica split over 2 clients. + 'mesh_dims': [(MESH_DIM_BATCH, 2), + (MESH_DIM_HEIGHT, 4), + (MESH_DIM_WIDTH, 2)], + }, { + # batch=2 x height=4 x width=2 (transposed) + # 1 replica split over 2 clients with reordered partitions. + 'mesh_dims': [(MESH_DIM_BATCH, 2), + (MESH_DIM_WIDTH, 2), + (MESH_DIM_HEIGHT, 4)], + }, + ), + ( + { + # Replicated + 'idx_sharding': [UNSHARDED], + 'images_sharding': [UNSHARDED, UNSHARDED, UNSHARDED, UNSHARDED], + }, { + # Batch sharded + 'idx_sharding': [MESH_DIM_BATCH], + 'images_sharding': + [MESH_DIM_BATCH, UNSHARDED, UNSHARDED, UNSHARDED], + }, { + # Spatially sharded + 'idx_sharding': [UNSHARDED], + 'images_sharding': + [UNSHARDED, MESH_DIM_HEIGHT, MESH_DIM_WIDTH, UNSHARDED], + }, { + # Batch and spatially sharded + 'idx_sharding': [MESH_DIM_BATCH], + 'images_sharding': + [MESH_DIM_BATCH, MESH_DIM_HEIGHT, MESH_DIM_WIDTH, UNSHARDED], + } + )) + def testMultiClientIter(self, mesh_dims, idx_sharding, images_sharding): + num_batches = 4 + batch_size = 16 + num_elems = num_batches * batch_size + + images = stateless_random_ops.stateless_random_uniform( + [num_elems, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS], + seed=(1, 2), + minval=0, + maxval=255, + dtype=dtypes.int32, + ) + dataset = dataset_ops.DatasetV2.from_tensor_slices(images) + + # Enumerate the dataset elements to make it easier to identify the batches + # returned by the DTensorDataset. + dataset = dataset.enumerate() + + # Store a mapping of index to dataset elements which can be looked up later + # to identify the batches returned by the DTensorDataset. + all_elems = {idx.numpy(): elem for idx, elem in dataset} + + # Write the dataset and shard it among multiple files. + dataset_paths = self.write_dataset( + dataset, num_files=8, num_elems=num_elems) + + # Construct args for starmap. 
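+    # One argument tuple is built per client, mirroring run_client's
+    # signature: the client id and env vars identify the process within the
+    # DTensor job, the device-id slice defines its portion of the mesh, and
+    # the worker port is used for its co-located tf.data service worker.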
+ args = [] + mesh_dim_names, mesh_dim_sizes = zip(*mesh_dims) + global_device_ids = test_util.create_device_ids_array(mesh_dim_sizes) + device_ids_split = np.split(np.ravel(global_device_ids), NUM_CLIENTS) + dtensor_jobs = [ + f'localhost:{self.server_ports[i]}' for i in range(NUM_CLIENTS) + ] + + for client_id in range(NUM_CLIENTS): + # Manually specify DTensor environment variables since we are in a test + # environment. + env = { + config._DT_CLIENT_ID: str(client_id), + config._DT_JOB_NAME: str(JOB_NAME), + config._DT_JOBS: ','.join(dtensor_jobs) + } + + local_device_ids = device_ids_split[client_id].tolist() + local_devices = [ + device_spec.DeviceSpecV2( # pylint: disable=g-complex-comprehension + job=JOB_NAME, + replica=0, + task=client_id, + device_type='CPU', + device_index=i, + ) + for i in range(len(local_device_ids)) + ] + mesh = Mesh( + dim_names=mesh_dim_names, + global_device_ids=global_device_ids, + local_device_ids=local_device_ids, + local_devices=local_devices, + ) + idx_layout = Layout(idx_sharding, mesh) + images_layout = Layout(images_sharding, mesh) + batch_dim = MESH_DIM_BATCH if MESH_DIM_BATCH in images_sharding else None + + args.append((client_id, self._testMethodName, env, NUM_DEVICES_PER_CLIENT, + self.cluster.dispatcher_address(), + self.worker_ports[client_id], batch_size, dataset_paths, + mesh, batch_dim, (idx_layout, images_layout))) + + def get_results(): + # Run the DTensor client processes and get the DTensor dataset components. + with mp_context.Pool(NUM_CLIENTS) as pool: + results = pool.starmap(run_client, args) + pool.close() + pool.join() + + return results + + # TODO(b/271162918): fix multi-client use case. + with self.assertRaises(NotImplementedError): + results = get_results() + + return + # pylint: disable=unreachable + + # Create a mesh on the main test process. The tensor components returned + # from each DTensor client subprocess will be packed onto this mesh to + # verify correctness. + test_mesh = mesh_util.create_mesh( + mesh_dims=mesh_dims, + devices=[ + 'CPU:%d' % i for i in range(NUM_CLIENTS * NUM_DEVICES_PER_CLIENT) + ]) + test_mesh = self.configTestMesh({'CPU': test_mesh}) + idx_test_layout = Layout(idx_sharding, test_mesh) + images_test_layout = Layout(images_sharding, test_mesh) + + for batch_elems in zip(*results): + # Collect the tensor components returned from each client. + idx_components = [] + images_components = [] + for client_id in range(NUM_CLIENTS): + local_idx, local_images = batch_elems[client_id] + idx_components.extend(local_idx) + images_components.extend(local_images) + + # Pack the dataset elements into a DTensor on the test mesh. + d_idx = api.pack(idx_components, idx_test_layout) + d_images = api.pack(images_components, images_test_layout) + + # Get the batch of elements from the original dataset using the element + # indices. 
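+      # Popping each index from all_elems lets the assertEmpty check at the
+      # end verify that every element was produced exactly once.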
+ batch_stack = [] + for elem_idx in d_idx: + batch_stack.append(all_elems.pop(elem_idx.numpy())) + batch = array_ops_stack.stack(batch_stack) + + self.assertDTensorEqual(batch, images_test_layout, d_images) + + self.assertEmpty( + all_elems, 'Not all batches were returned by DTensorDataset.') + + +if __name__ == '__main__': + test_backend_util.handle_test_main(test.main) diff --git a/tensorflow/dtensor/python/tests/multi_client_test_util.py b/tensorflow/dtensor/python/tests/multi_client_test_util.py index 35d3e7aa10ef98..dd4a69f14f77e4 100644 --- a/tensorflow/dtensor/python/tests/multi_client_test_util.py +++ b/tensorflow/dtensor/python/tests/multi_client_test_util.py @@ -31,6 +31,11 @@ 'Number of clients. 0 for local mode. 2 is the only allowed value for TPU.') +def pick_unused_port(): + """Helper function to return an unused port.""" + return portpicker.pick_unused_port() + + def multi_client_main(client_config_function): """Creates a Flock of TensorFlow Processes on localhost.""" flags.FLAGS(sys.argv, known_only=True) @@ -49,12 +54,11 @@ def multi_client_main(client_config_function): # Inverts the order of ports intentionally to rule out ordering bugs. server_ports = sorted( - [portpicker.pick_unused_port() for _ in range(num_process)], reverse=True) - - additional_ports = sorted( - [portpicker.pick_unused_port() for _ in range(num_process)] + [pick_unused_port() for _ in range(num_process)], reverse=True ) + additional_ports = sorted([pick_unused_port() for _ in range(num_process)]) + # Starts processes procs = [] for client_idx in range(num_process): @@ -138,4 +142,3 @@ def run_client(idx, num_clients, server_ports, additional_ports, # The following function call never returns. tf_test.main() - diff --git a/tensorflow/dtensor/python/tests/numerics_test.py b/tensorflow/dtensor/python/tests/numerics_test.py new file mode 100644 index 00000000000000..60bb7995adf6e2 --- /dev/null +++ b/tensorflow/dtensor/python/tests/numerics_test.py @@ -0,0 +1,125 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for numerics in DTensor Ops.""" + +import os + +from absl.testing import parameterized +import numpy as np + +from tensorflow.dtensor.python import accelerator_util +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import stateless_random_ops +from tensorflow.python.platform import test + +Layout = layout_lib.Layout +Mesh = layout_lib.Mesh +UNSHARDED = layout_lib.UNSHARDED +_MESH_DIM_X = 'x' +_MESH_DIM_Y = 'y' +_MESH_DIMS = [_MESH_DIM_X, _MESH_DIM_Y] + + +class NumericTest(test_util.DTensorBaseTest): + + def setUp(self): + super(NumericTest, self).setUp() + + self.skipForDeviceType(['TPU'], + 'all tests require 8 TPU cores.', + unless_device_count_equals_to=8) + + test_util.reset_logical_devices('CPU', 8) + accelerator_util.initialize_accelerator_system() + + self.stateless_random_seed = [0, 1] + + def _create_mesh(self, topology, device): + device_ids = test_util.create_device_ids_array(topology) + return Mesh( + _MESH_DIMS, + device_ids, + np.ravel(device_ids).tolist(), + test_util.create_device_list(topology, device), + ) + + # Tests AllReduce numerics with and without mixed precision reduce enabled, + # based on go/dtensor-numerics. + @parameterized.named_parameters(('_without_mixed_precision_reduce', False), + ('_with_mixed_precision_reduce', True)) + def test_all_reduce(self, enable_mixed_precision_reduce): + if enable_mixed_precision_reduce: + os.environ['DTENSOR_ENABLE_MIXED_PRECISION_REDUCE'] = '' + # Override group size since we are testing on smaller mesh. + os.environ['DTENSOR_REDUCE_IN_BFLOAT16_MAX_GROUP_SIZE'] = '4' + else: + if 'DTENSOR_ENABLE_MIXED_PRECISION_REDUCE' in os.environ: + del os.environ['DTENSOR_ENABLE_MIXED_PRECISION_REDUCE'] + + @polymorphic_function.function + def _compute_reduction(inp): + return math_ops.reduce_sum(inp, axis=[2]) + + input_tensor = stateless_random_ops.stateless_random_uniform( + shape=(8, 8, 8, 64), + seed=self.stateless_random_seed, + minval=-5.0, + maxval=5.0, + dtype=dtypes.bfloat16, + ) + expected = _compute_reduction(input_tensor) + + # Compute reduction on 8x1, since dim 2 is unsharded AllReduce will not be + # needed. + mesh_8x1 = self._create_mesh((8, 1), 'TPU') + input_8x1 = numpy_util.pack_numpy( + input_tensor, + Layout([_MESH_DIM_X, UNSHARDED, UNSHARDED, UNSHARDED], mesh_8x1), + ) + result_8x1 = _compute_reduction(input_8x1) + result_8x1_np = numpy_util.to_numpy(result_8x1) + + # Compute reduction on 1x8, AllReduce will be needed since dim 2 is sharded. + mesh_1x8 = self._create_mesh((1, 8), 'TPU') + input_1x8 = numpy_util.pack_numpy( + input_tensor, + Layout([_MESH_DIM_X, UNSHARDED, _MESH_DIM_Y, UNSHARDED], mesh_1x8), + ) + result_1x8 = _compute_reduction(input_1x8) + result_1x8_np = numpy_util.to_numpy(result_1x8) + + self.assertEqual(result_8x1.dtype, dtypes.bfloat16) + self.assertEqual(result_1x8.dtype, dtypes.bfloat16) + + # Mixed precision does not apply since AllReduce was not used, result will + # always be close to the expected value. + self.assertAllClose(result_8x1_np, expected, atol=1e-5, rtol=1e-5) + + # AllReduce was needed, so result will be more accurate if mixed precision + # is enabled. 
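+    # With mixed-precision reduce the AllReduce is expected to accumulate in
+    # higher precision before casting back to bfloat16, so only that case has
+    # to match `expected` within the tolerance below.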
+ if enable_mixed_precision_reduce: + self.assertAllClose(result_1x8_np, expected, atol=1e-5, rtol=1e-5) + else: + self.assertNotAllClose(result_1x8_np, expected, atol=1e-5, rtol=1e-5) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/dtensor/python/tests/rng_test.py b/tensorflow/dtensor/python/tests/rng_test.py new file mode 100644 index 00000000000000..41b61119372b0b --- /dev/null +++ b/tensorflow/dtensor/python/tests/rng_test.py @@ -0,0 +1,665 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from absl.testing import parameterized + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import d_variable +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.dtensor.python.tests import test_util_ops +from tensorflow.python.distribute import tpu_strategy +from tensorflow.python.distribute.cluster_resolver.tpu import tpu_cluster_resolver +from tensorflow.python.eager import remote +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_bitwise_ops +from tensorflow.python.ops import gen_stateful_random_ops +from tensorflow.python.ops import gen_stateless_random_ops_v2 +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.tpu import device_assignment as device_assignment_lib + +# pylint: enable=g-direct-tensorflow-import + +# Makes a 2-D mesh with dimensions as, X(2) and Y(4). +_MESH_DIM_X = 'x' +_MESH_DIM_Y = 'y' +_MESH_DIMS = [_MESH_DIM_X, _MESH_DIM_Y] + +Layout = layout_lib.Layout +Mesh = layout_lib.Mesh + +# Create a random local IDs to make tests more challenging. +_LOCAL_IDS = [7, 3, 1, 4, 2, 0, 6, 5] +# The row and col indices for each local id, e.g., 7 is (row=1, col=3) +_ROW_INDEX = [i / 4 for i in _LOCAL_IDS] +_COL_INDEX = [i % 4 for i in _LOCAL_IDS] + +# The index of local id for the row head. +# +# For example, local id 7 is on row 1, the head is local id 4, whose index in +# _LOCAL_IDS is 3, i.e., _LOCAL_IDS[3] == 4 +_ROW_0_HEAD = 3 +_ROW_1_HEAD = 5 +_ROW_HEAD = [3, 5, 5, 3, 5, 5, 3, 3] + +# The index of local id for the col head. Similar to row id before. 
+_COL_0_HEAD = 5 +_COL_1_HEAD = 2 +_COL_2_HEAD = 4 +_COL_3_HEAD = 1 +_COL_HEAD = [1, 1, 2, 5, 4, 5, 4, 2] + +_tpu_strategy = None + + +def _call_op(op, seed, shape, dtype, key, counter, alg, minval, maxval, + op_version): + if op_version == 'V1': + return op(shape=shape, seed=seed, dtype=dtype) + elif op_version == 'V2': + return op(shape=shape, key=key, counter=counter, alg=alg, dtype=dtype) + elif op_version == 'V2_RANGE': + return op( + shape=shape, + key=key, + counter=counter, + alg=alg, + minval=minval, + maxval=maxval) + else: + raise ValueError('op_version argument was invalid.') + + +def _call_dtensor_op(op, seed, shape, dtype, key, counter, alg, minval, maxval, + op_version, mesh): + if op_version == 'V1': + return op(shape=shape, seed=seed, dtype=dtype) + + shape = numpy_util.pack_numpy( + constant_op.constant(shape), Layout.replicated(mesh, 1) + ) + key = numpy_util.pack_numpy(key, Layout.replicated(mesh, 1)) + counter = numpy_util.pack_numpy(counter, Layout.replicated(mesh, 1)) + + if op_version == 'V2': + return op(shape=shape, key=key, counter=counter, alg=alg, dtype=dtype) + elif op_version == 'V2_RANGE': + return op( + shape=shape, + key=key, + counter=counter, + alg=alg, + minval=minval, + maxval=maxval) + else: + raise ValueError('op_version argument was invalid.') + + +def get_tpu_strategy(): + """Returns a single-core TPUStrategy.""" + global _tpu_strategy + if _tpu_strategy is not None: + return _tpu_strategy + + resolver = tpu_cluster_resolver.TPUClusterResolver(tpu='') + remote.connect_to_cluster(resolver) + topology = tpu_cluster_resolver.initialize_tpu_system(resolver) + device_assignment = device_assignment_lib.DeviceAssignment.build( + topology, num_replicas=1 + ) + strategy = tpu_strategy.TPUStrategyV2( + resolver, experimental_device_assignment=device_assignment + ) + _tpu_strategy = strategy + return strategy + + +def rng_op_spmd(op, + device_id, + seed, + shape, + dtype, + key, + counter, + alg, + minval, + maxval, + op_version, + device_index_fn, + full_replicated=False, + is_tpu=False): + + if not is_tpu: + return rng_op_spmd_fn( + op, + device_id, + seed, + shape, + dtype, + key, + counter, + alg, + minval, + maxval, + op_version, + device_index_fn, + full_replicated=full_replicated) + + # As of 2021-April, TPU eager and multi-device function produce different + # stateless rng results compared with bridge compiled function. As DTensor + # uses bridge to lower TPU function by default, we need to create a + # TPUStrategy for single core and invoke `run` on it. + @polymorphic_function.function + def tpu_fn(device_id, seed): + return rng_op_spmd_fn( + op, + device_id, + seed, + shape, + dtype, + key, + counter, + alg, + minval, + maxval, + op_version, + device_index_fn, + full_replicated=full_replicated) + + return get_tpu_strategy().run(tpu_fn, args=(device_id, seed)) + + +def rng_op_spmd_fn(op, + device_id, + seed, + shape, + dtype, + key, + counter, + alg, + minval, + maxval, + op_version, + device_index_fn, + full_replicated=False): + if full_replicated: + # TODO(bfontain,xiejw): Consider to make this consistent with non-replicated + # case. Seems very confusing. + new_seed, new_key = seed, key + else: + # Runs on TF2 non-DTensor pure eager. This code should align the same + # logic in RandomOpSPMDExpander. 
+ x_cord = device_id // 4 + y_cord = device_id % 4 + device_index = device_index_fn(x_cord, y_cord) + device_id_seed = device_index * 65536 + 65521 + new_seed = gen_bitwise_ops.bitwise_xor(seed, device_id_seed) + new_key = gen_bitwise_ops.bitwise_xor( + key, math_ops.cast(device_id_seed, dtype=dtypes.uint64) + ) + return _call_op( + op=op, + seed=new_seed, + shape=shape, + dtype=dtype, + key=new_key, + counter=counter, + alg=alg, + minval=minval, + maxval=maxval, + op_version=op_version) + + +class DTensorRNGTest(test_util.DTensorBaseTest): + + def setUp(self): + super(DTensorRNGTest, self).setUp() + global_ids = test_util.create_device_ids_array((2, 4)) + local_ids = _LOCAL_IDS + mesh_dict = { + device: Mesh( + [_MESH_DIM_X, _MESH_DIM_Y], + global_ids, + local_ids, + test_util.create_device_list((2, 4), device), + ) + for device in ('CPU', 'GPU', 'TPU') + } + self.mesh = self.configTestMesh(mesh_dict) + + # Creates a bunch of common layouts used by tests later. + self.replicated_layout_2d = Layout.replicated(self.mesh, rank=2) + self.shardings = { + 'batch': Layout.batch_sharded, + 'inner': Layout.inner_sharded + } + # Creates a bunch of parameters for rng V2 ops + self.key = constant_op.constant([123], dtype=dtypes.uint64) + self.counter = constant_op.constant([1, 1], dtype=dtypes.uint64) + self.alg = 1 + self.minval = 1 + self.maxval = 100 + + @parameterized.named_parameters(test_util_ops.RANDOM_OPS) + def testStatelessRNGWithFullyReplicated(self, op, dtype, op_version): + layout = self.replicated_layout_2d + shape = [16, 16] + seed = [123, 321] + + with ops.device_v2(api.device_name()): + with api._dtensor_device()._default_layout(layout): + b = _call_dtensor_op( + op=op, + seed=seed, + shape=shape, + dtype=dtype, + key=self.key, + counter=self.counter, + alg=self.alg, + minval=self.minval, + maxval=self.maxval, + op_version=op_version, + mesh=self.mesh) + + api.check_layout(b, layout) + self.assertListEqual(shape, list(b.shape)) + + b = [tensor.numpy() for tensor in api.unpack(b)] + for i in range(self.mesh.num_local_devices() - 1): + self.assertAllEqual(b[i], b[i + 1]) + + @parameterized.named_parameters(test_util_ops.RANDOM_OPS) + def testStatelessRNGWithFullyReplicatedComparingWithNonDTensor( + self, op, dtype, op_version): + + layout = self.replicated_layout_2d + shape = [16, 16] + seed = [123, 321] + + with ops.device_v2(api.device_name()): + with api._dtensor_device()._default_layout(layout): + b = _call_dtensor_op( + op=op, + seed=seed, + shape=shape, + dtype=dtype, + key=self.key, + counter=self.counter, + alg=self.alg, + minval=self.minval, + maxval=self.maxval, + op_version=op_version, + mesh=self.mesh) + + api.check_layout(b, layout) + self.assertListEqual(shape, list(b.shape)) + + b = [tensor.numpy() for tensor in api.unpack(b)] + + local_shape = shape + for index, device_id in enumerate(_LOCAL_IDS): + self.assertAllEqual( + b[index], + rng_op_spmd( + op, + device_id, + seed, + local_shape, + dtype, + key=self.key, + counter=self.counter, + alg=self.alg, + minval=self.minval, + maxval=self.maxval, + op_version=op_version, + device_index_fn=None, # not needed + full_replicated=True, + is_tpu=self.mesh.device_type().upper() == 'TPU')) + + @parameterized.named_parameters( + test_util_ops.expand_test_config( + test_util_ops.RANDOM_OPS, + [ + { + 'dim': _MESH_DIM_X, + 'shard_type': 'batch', + }, + { + 'dim': _MESH_DIM_Y, + 'shard_type': 'batch', + }, + { + 'dim': _MESH_DIM_X, + 'shard_type': 'inner', + }, + {'dim': _MESH_DIM_Y, 'shard_type': 'inner'}, + ], + ) + ) + def 
testStatelessRNGOpsWithSingleDimensionSharded(self, op, dtype, op_version, + dim, shard_type): + shape = [128, 128] + seed = [123, 321] + sharding = self.shardings[shard_type] + layout = sharding(self.mesh, dim, rank=2) + + # Raw rng Ops do not have inputs, so we need to place the Op DTensor device + # explicitly. + with ops.device_v2(api.device_name()): + with api._dtensor_device()._default_layout(layout): + b = _call_dtensor_op( + op=op, + seed=seed, + shape=shape, + dtype=dtype, + key=self.key, + counter=self.counter, + alg=self.alg, + minval=self.minval, + maxval=self.maxval, + op_version=op_version, + mesh=self.mesh) + + api.check_layout(b, layout) + b = [tensor.numpy() for tensor in api.unpack(b)] + + if dim == _MESH_DIM_X: + if shard_type == 'batch': + self.assertAllEqual(b[0].shape, [64, 128]) + else: + assert shard_type == 'inner' + self.assertAllEqual(b[0].shape, [128, 64]) + + # first check that each component is same as the row header. + for i in range(self.mesh.num_local_devices()): + self.assertAllEqual(b[i], b[_ROW_HEAD[i]]) + # then check the row header are NOT identital. + self.assertNotAllEqual(b[_ROW_0_HEAD], b[_ROW_1_HEAD]) + + elif dim == _MESH_DIM_Y: + if shard_type == 'batch': + self.assertAllEqual(b[0].shape, [32, 128]) + else: + assert shard_type == 'inner' + self.assertAllEqual(b[0].shape, [128, 32]) + + # first check elements in same columns are identical + for i in range(self.mesh.num_local_devices()): + self.assertAllEqual(b[i], b[_COL_HEAD[i]]) + + col_heads = [_COL_0_HEAD, _COL_1_HEAD, _COL_2_HEAD, _COL_3_HEAD] + # then check the column header are not identital (mutually) + for i in range(self.mesh.num_local_devices() - 1): + for j in range(self.mesh.num_local_devices()): + if i == j: + continue + if i in col_heads and j in col_heads: + self.assertNotAllEqual(b[i], b[j]) + + else: + self.fail('should not reach here.') + + @parameterized.named_parameters( + test_util_ops.expand_test_config( + test_util_ops.RANDOM_OPS, + [ + { + 'dim': _MESH_DIM_X, + 'shard_type': 'batch', + }, + { + 'dim': _MESH_DIM_Y, + 'shard_type': 'batch', + }, + { + 'dim': _MESH_DIM_X, + 'shard_type': 'inner', + }, + {'dim': _MESH_DIM_Y, 'shard_type': 'inner'}, + ], + ) + ) + def testStatelessRNGOpsWithSingleDimensionShardedComparingWithNonDTensor( + self, op, dtype, op_version, dim, shard_type): + + shape = [128, 128] + seed = [123, 321] + sharding = self.shardings[shard_type] + layout = sharding(self.mesh, dim, rank=2) + + # Raw rng Ops do not have inputs, so we need to place the Op DTensor device + # explicitly. + with ops.device_v2(api.device_name()): + with api._dtensor_device()._default_layout(layout): + b = _call_dtensor_op( + op=op, + seed=seed, + shape=shape, + dtype=dtype, + key=self.key, + counter=self.counter, + alg=self.alg, + minval=self.minval, + maxval=self.maxval, + op_version=op_version, + mesh=self.mesh) + + api.check_layout(b, layout) + b = [tensor.numpy() for tensor in api.unpack(b)] + + if dim == _MESH_DIM_X: + if shard_type == 'batch': + local_shape = [64, 128] + else: + local_shape = [128, 64] + + def device_index_fn(x_cord, y_cord): + # See todo of device_index_fn in 2d sharding case. 
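+        # When only the x dimension is sharded, which shard a device holds is
+        # determined by its x coordinate alone, so y_cord can be ignored.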
+ del y_cord + return x_cord + + for index, device_id in enumerate(_LOCAL_IDS): + self.assertAllEqual( + b[index], + rng_op_spmd( + op, + device_id, + seed, + local_shape, + dtype, + key=self.key, + counter=self.counter, + alg=self.alg, + minval=self.minval, + maxval=self.maxval, + op_version=op_version, + device_index_fn=device_index_fn, + is_tpu=self.mesh.device_type().upper() == 'TPU')) + elif dim == _MESH_DIM_Y: + if shard_type == 'batch': + local_shape = [32, 128] + else: + local_shape = [128, 32] + + def device_index_fn(x_cord, y_cord): + # See todo of device_index_fn in 2d sharding case. note this case is + # particulary interesting as 2*y_cord is more natual. + del x_cord + return y_cord + + for index, device_id in enumerate(_LOCAL_IDS): + self.assertAllEqual( + b[index], + rng_op_spmd( + op, + device_id, + seed, + local_shape, + dtype, + key=self.key, + counter=self.counter, + alg=self.alg, + minval=self.minval, + maxval=self.maxval, + op_version=op_version, + device_index_fn=device_index_fn, + is_tpu=self.mesh.device_type().upper() == 'TPU')) + + else: + self.fail('should not reach here.') + + @parameterized.named_parameters(test_util_ops.RANDOM_OPS) + def testStatelessRNGOpsWith2DSharding(self, op, dtype, op_version): + shape = [128, 128] + seed = [123, 321] + layout = Layout([_MESH_DIM_Y, _MESH_DIM_X], self.mesh) + + # Raw rng Ops do not have inputs, so we need to place the Op DTensor device + # explicitly. + with ops.device_v2(api.device_name()): + with api._dtensor_device()._default_layout(layout): + b = _call_dtensor_op( + op=op, + seed=seed, + shape=shape, + dtype=dtype, + key=self.key, + counter=self.counter, + alg=self.alg, + minval=self.minval, + maxval=self.maxval, + op_version=op_version, + mesh=self.mesh) + + api.check_layout(b, layout) + b = [tensor.numpy() for tensor in api.unpack(b)] + + # check all raw components are not identital (mutually) + for i in range(self.mesh.num_local_devices() - 1): + for j in range(self.mesh.num_local_devices()): + if i == j: + continue + self.assertNotAllEqual(b[i], b[j]) + + @parameterized.named_parameters(test_util_ops.RANDOM_OPS) + def testStatelessRNGOpsWith2DShardingComparingWithNonDTensor( + self, op, dtype, op_version): + shape = [128, 128] + seed = [123, 321] + layout = Layout([_MESH_DIM_Y, _MESH_DIM_X], self.mesh) + local_shape = [128 // 4, 128 // 2] + + # Raw rng Ops do not have inputs, so we need to place the Op DTensor device + # explicitly. + with ops.device_v2(api.device_name()): + with api._dtensor_device()._default_layout(layout): + b = _call_dtensor_op( + op=op, + seed=seed, + shape=shape, + dtype=dtype, + key=self.key, + counter=self.counter, + alg=self.alg, + minval=self.minval, + maxval=self.maxval, + op_version=op_version, + mesh=self.mesh) + + api.check_layout(b, layout) + b = [tensor.numpy() for tensor in api.unpack(b)] + + def device_index_fn(x_cord, y_cord): + # TODO(bfontain,xiejw): Currently, the device index is x+2y. But it is + # more natual to use 4x+y for a mesh. Consider to change this + # once all correctness tests are done. 
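+      # With x in [0, 2) and y in [0, 4) on the 2x4 mesh, x + 2*y still maps
+      # the 8 devices to unique indices; it simply makes x the fastest-varying
+      # coordinate instead of the row-major 4*x + y.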
+ return x_cord + 2 * y_cord + + for index, device_id in enumerate(_LOCAL_IDS): + self.assertAllEqual( + b[index], + rng_op_spmd( + op, + device_id, + seed, + local_shape, + dtype, + key=self.key, + counter=self.counter, + alg=self.alg, + minval=self.minval, + maxval=self.maxval, + op_version=op_version, + device_index_fn=device_index_fn, + is_tpu=self.mesh.device_type().upper() == 'TPU')) + + def testRNGReadAndSkip(self): + replicated_layout = Layout.replicated(self.mesh, 1) + a = constant_op.constant([1, 2, 3], dtype=dtypes.int64) + v = variables.Variable(a) + expected = gen_stateful_random_ops.rng_read_and_skip( + resource=v.handle, + alg=1, + delta=constant_op.constant(1, dtype=dtypes.uint64), + ) + + a = numpy_util.pack_numpy(a, replicated_layout) + v = d_variable.DVariable(a) + got = gen_stateful_random_ops.rng_read_and_skip( + resource=v.handle, + alg=1, + delta=constant_op.constant(1, dtype=dtypes.uint64), + ) + + self.assertDTensorEqual(expected, replicated_layout, got) + + def testStatelessRandomGetKeyCounter(self): + seed = constant_op.constant([7, 17], dtypes.int32) + + # TPU computation result is different from CPU computation. + # We force it to run on the TPU using tpu_strategy for TPU mesh + # so that we compare equal values. + @polymorphic_function.function + def tpu_fn(): + return gen_stateless_random_ops_v2.stateless_random_get_key_counter( + seed=seed + ) + + if self.mesh.device_type().upper() == 'TPU': + expected = get_tpu_strategy().run(tpu_fn) + else: + expected = gen_stateless_random_ops_v2.stateless_random_get_key_counter( + seed=seed + ) + + replicated_1d_layout = Layout.replicated(self.mesh, 1) + seed = numpy_util.pack_numpy(seed, replicated_1d_layout) + + got = gen_stateless_random_ops_v2.stateless_random_get_key_counter( + seed=seed + ) + self.assertDTensorEqual(expected[0], replicated_1d_layout, got[0]) + self.assertDTensorEqual(expected[1], replicated_1d_layout, got[1]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/dtensor/python/tests/save_restore_v2_test.py b/tensorflow/dtensor/python/tests/save_restore_v2_test.py new file mode 100644 index 00000000000000..e53f5003240da9 --- /dev/null +++ b/tensorflow/dtensor/python/tests/save_restore_v2_test.py @@ -0,0 +1,337 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import gc + +from absl.testing import parameterized + +import numpy as np + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import d_variable +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.checkpoint import checkpoint +from tensorflow.python.checkpoint import checkpoint_management +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.module import module +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import stateless_random_ops +from tensorflow.python.platform import test + +Mesh = layout_lib.Mesh +Layout = layout_lib.Layout +UNSHARDED = layout_lib.UNSHARDED + +# Makes a 2D mesh with dimension X(2) and dimension Y(4). +_MESH_DIM_X = 'x' +_MESH_DIM_Y = 'y' +_DEVICE_IDS = test_util.create_device_ids_array((2, 4)) +_TWO_D_CPU_MESH = Mesh( + [_MESH_DIM_X, _MESH_DIM_Y], + _DEVICE_IDS, + np.ravel(_DEVICE_IDS).tolist(), + test_util.create_device_list((2, 4), 'CPU'), +) +_TWO_D_TPU_MESH = Mesh( + [_MESH_DIM_X, _MESH_DIM_Y], + _DEVICE_IDS, + np.ravel(_DEVICE_IDS).tolist(), + test_util.create_device_list((2, 4), 'TPU'), +) +_TWO_D_GPU_MESH = Mesh( + [_MESH_DIM_X, _MESH_DIM_Y], + _DEVICE_IDS, + np.ravel(_DEVICE_IDS).tolist(), + test_util.create_device_list((2, 4), 'GPU'), +) + + +class DTensorSaveRestoreV2Test(test_util.DTensorBaseTest): + + def setUp(self): + super(DTensorSaveRestoreV2Test, self).setUp() + self.skipForDeviceType(['TPU'], + 'all tests require 8 TPU cores.', + unless_device_count_equals_to=8) + mesh_dict = { + 'CPU': _TWO_D_CPU_MESH, + 'GPU': _TWO_D_GPU_MESH, + 'TPU': _TWO_D_TPU_MESH, + } + self.mesh = self.configTestMesh(mesh_dict) + self.skipForTfrt( + 'b/235088250, DTensorCheckpointingV2 requires upcasting TF scalar ' + 'variables to replicated DTensor scalar variables, which is not ' + 'supported in TFRT.') + + @parameterized.named_parameters( + ('x_unsharded', [_MESH_DIM_X, UNSHARDED]), + ('unsharded_x', [UNSHARDED, _MESH_DIM_X]), + ('x_y', [_MESH_DIM_X, _MESH_DIM_Y]), + ('unsharded_unsharded', [UNSHARDED, UNSHARDED]), + ) + def test_checkpoint_simple(self, shard_spec): + tensor_a = stateless_random_ops.stateless_random_uniform( + shape=[4, 8], seed=[0, 1] + ) + tensor_b = stateless_random_ops.stateless_random_uniform( + shape=[2, 4], seed=[0, 1] + ) + + layout = Layout(shard_spec, self.mesh) + + dvariable_a = d_variable.DVariable(numpy_util.pack_numpy(tensor_a, layout)) + dvariable_b = d_variable.DVariable(numpy_util.pack_numpy(tensor_b, layout)) + + # Record a checkpoint with two dvariables. + ckpt = checkpoint.Checkpoint(a=dvariable_a, b=dvariable_b) + + saved_path = ckpt.save(self.get_temp_dir()) + + # Zero out the values of the DVariables so that we can restore + # and check that the values are restored to the initial random values. 
+ dvariable_a.assign( + numpy_util.pack_numpy( + array_ops.zeros([4, 8], dtype=dtypes.float32), layout + ) + ) + dvariable_b.assign( + numpy_util.pack_numpy( + array_ops.zeros([2, 4], dtype=dtypes.float32), layout + ) + ) + + ckpt.restore(saved_path) + + self.assertDTensorEqual(tensor_a, layout, dvariable_a.read_value()) + self.assertDTensorEqual(tensor_b, layout, dvariable_b.read_value()) + + @parameterized.named_parameters( + ('x_unsharded', [_MESH_DIM_X, UNSHARDED]), + ('unsharded_x', [UNSHARDED, _MESH_DIM_X]), + ('x_y', [_MESH_DIM_X, _MESH_DIM_Y]), + ('unsharded_unsharded', [UNSHARDED, UNSHARDED]), + ) + def test_checkpoint_write(self, shard_spec): + tensor_a = stateless_random_ops.stateless_random_uniform( + shape=[4, 8], seed=[0, 1] + ) + tensor_b = stateless_random_ops.stateless_random_uniform( + shape=[2, 4], seed=[0, 1] + ) + + layout = Layout(shard_spec, self.mesh) + + dvariable_a = d_variable.DVariable(numpy_util.pack_numpy(tensor_a, layout)) + dvariable_b = d_variable.DVariable(numpy_util.pack_numpy(tensor_b, layout)) + + ckpt = checkpoint.Checkpoint(a=dvariable_a, b=dvariable_b) + + saved_path = ckpt.write(self.get_temp_dir()) + + dvariable_a.assign( + numpy_util.pack_numpy( + array_ops.zeros([4, 8], dtype=dtypes.float32), layout + ) + ) + dvariable_b.assign( + numpy_util.pack_numpy( + array_ops.zeros([2, 4], dtype=dtypes.float32), layout + ) + ) + + ckpt.restore(saved_path) + + self.assertDTensorEqual(tensor_a, layout, dvariable_a.read_value()) + self.assertDTensorEqual(tensor_b, layout, dvariable_b.read_value()) + + @parameterized.named_parameters( + ('x_unsharded', [_MESH_DIM_X, UNSHARDED]), + ('unsharded_x', [UNSHARDED, _MESH_DIM_X]), + ('x_y', [_MESH_DIM_X, _MESH_DIM_Y]), + ('unsharded_unsharded', [UNSHARDED, UNSHARDED]), + ) + def test_checkpoint_manager(self, shard_spec): + tensor_a = stateless_random_ops.stateless_random_uniform( + shape=[8, 16], seed=[0, 1] + ) + tensor_b = stateless_random_ops.stateless_random_uniform( + shape=[4, 4], seed=[0, 1] + ) + + layout = Layout(shard_spec, self.mesh) + + dvariable_a = d_variable.DVariable(numpy_util.pack_numpy(tensor_a, layout)) + dvariable_b = d_variable.DVariable(numpy_util.pack_numpy(tensor_b, layout)) + + # Record a checkpoint with two dvariables. + ckpt = checkpoint.Checkpoint(a=dvariable_a, b=dvariable_b) + + checkpoint_manager = checkpoint_management.CheckpointManager( + ckpt, self.get_temp_dir(), max_to_keep=None + ) + + saved_path = checkpoint_manager.save() + + # Zero out the values of the DVariables so that we can restore + # and check that the values are restored to the initial random values. 
+ dvariable_a.assign( + numpy_util.pack_numpy( + array_ops.zeros([8, 16], dtype=dtypes.float32), layout + ) + ) + dvariable_b.assign( + numpy_util.pack_numpy( + array_ops.zeros([4, 4], dtype=dtypes.float32), layout + ) + ) + + ckpt.restore(saved_path) + + self.assertDTensorEqual(tensor_a, layout, dvariable_a.read_value()) + self.assertDTensorEqual(tensor_b, layout, dvariable_b.read_value()) + + @parameterized.named_parameters( + ('x_unsharded', [_MESH_DIM_X, UNSHARDED]), + ('unsharded_x', [UNSHARDED, _MESH_DIM_X]), + ('x_y', [_MESH_DIM_X, _MESH_DIM_Y]), + ('unsharded_unsharded', [UNSHARDED, UNSHARDED]), + ) + def test_checkpoint_restore_with_different_layout(self, shard_spec): + tensor_a = stateless_random_ops.stateless_random_uniform( + shape=[4, 8], seed=[0, 1] + ) + tensor_b = stateless_random_ops.stateless_random_uniform( + shape=[2, 4], seed=[0, 1] + ) + + layout = Layout(shard_spec, self.mesh) + + dvariable_a = d_variable.DVariable(numpy_util.pack_numpy(tensor_a, layout)) + dvariable_b = d_variable.DVariable(numpy_util.pack_numpy(tensor_b, layout)) + + # Record a checkpoint with two dvariables. + checkpoint_1 = checkpoint.Checkpoint(a=dvariable_a, b=dvariable_b) + + saved_path = checkpoint_1.save(self.get_temp_dir()) + + new_layout = Layout([_MESH_DIM_X, _MESH_DIM_Y], self.mesh) + + # Create new Dvariables, zero'd out with different layouts + # from the layouts we saved the tensors. + dvariable_a = d_variable.DVariable( + numpy_util.pack_numpy( + array_ops.zeros([4, 8], dtype=dtypes.float32), new_layout + ) + ) + dvariable_b = d_variable.DVariable( + numpy_util.pack_numpy( + array_ops.zeros([2, 4], dtype=dtypes.float32), new_layout + ) + ) + + checkpoint_2 = checkpoint.Checkpoint(a=dvariable_a, b=dvariable_b) + + checkpoint_2.restore(saved_path) + + self.assertDTensorEqual(tensor_a, new_layout, dvariable_a.read_value()) + self.assertDTensorEqual(tensor_b, new_layout, dvariable_b.read_value()) + + @parameterized.named_parameters( + ('x_unsharded', [_MESH_DIM_X, UNSHARDED]), + ('unsharded_x', [UNSHARDED, _MESH_DIM_X]), + ) + def test_checkpoint_in_a_train_loop(self, shard_dims): + # This test is a parallel test with save_restore_test's + # DTensorSaveRestoreTest.test_checkpoint + + class M(module.Module): + + # Pass in both replicated and sharded for better coverage. + def __init__(self, replicated_value, sharded_value): + # This is actually a DVariable. + self.r = d_variable.DVariable(replicated_value) + self.s = d_variable.DVariable(sharded_value) + + def __call__(self, x): + return math_ops.reduce_sum(x + self.r) + math_ops.reduce_sum(x + self.s) + + directory = self.get_temp_dir() + + sharded_np = np.arange(8).reshape((2, 4)).astype(np.float32) + replicated_np = np.arange(16).reshape((8, 2)).astype(np.float32) + + replicated_layout = Layout.replicated(self.mesh, rank=2) + one_d_sharded_layout = Layout(shard_dims, self.mesh) + + replicated_value = api.copy_to_mesh(replicated_np, replicated_layout) + replicated_zeros = api.copy_to_mesh( + np.zeros((8, 2)).astype(np.float32), replicated_layout + ) + + sharded_value = numpy_util.pack_numpy(sharded_np, one_d_sharded_layout) + sharded_zeros = numpy_util.pack_numpy( + np.zeros((2, 4)).astype(np.float32), one_d_sharded_layout) + + # Training loop that just increments the model's variable every "epoch" + # to test checkpointing. 
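+    # On the first iteration manager.latest_checkpoint is None, so restore()
+    # is effectively a no-op and the variables keep their initial values;
+    # every later iteration restores the values saved at the end of the
+    # previous epoch, which is why the assertions expect `epoch + <initial>`.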
+ for epoch in range(5): + m = M(replicated_value, sharded_value) + + ckpt = checkpoint.Checkpoint(model=m) + manager = checkpoint_management.CheckpointManager( + ckpt, directory=directory, max_to_keep=None + ) + + ckpt.restore(manager.latest_checkpoint) + + # Ensure that the variable is created + m(api.copy_to_mesh(1.0, Layout.replicated(self.mesh, rank=0))) + + self.assertDTensorEqual(epoch + replicated_np, replicated_layout, m.r) + self.assertDTensorEqual(epoch + sharded_np, one_d_sharded_layout, m.s) + + m.s.assign_add( + numpy_util.pack_numpy( + np.ones((2, 4), dtype=np.float32), one_d_sharded_layout)) + m.r.assign_add( + api.copy_to_mesh( + constant_op.constant(np.ones((8, 2), dtype=np.float32)), + replicated_layout, + ) + ) + + checkpoint_number = epoch + 1 + + stats1 = api._dtensor_device()._get_stats() + manager.save(checkpoint_number=checkpoint_number) + + gc.collect() + stats2 = api._dtensor_device()._get_stats() + keys = set(stats2.keys()) + keys.update(stats1.keys()) + diff = {k: stats2.get(k, 0) - stats1.get(k, 0) for k in keys} + diff = {k: v for k, v in diff.items() if v != 0} + + m.s.assign(sharded_zeros) + m.r.assign(replicated_zeros) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/dtensor/python/tests/sparse_test.py b/tensorflow/dtensor/python/tests/sparse_test.py new file mode 100644 index 00000000000000..b519da74e4ea57 --- /dev/null +++ b/tensorflow/dtensor/python/tests/sparse_test.py @@ -0,0 +1,141 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from absl.testing import parameterized +import numpy as np + +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +# Convenient constants to use for tests. 
+_BATCH_DIM = "batch" +_MESH_DIM_X = "x" + +# Shorter notation +Layout = layout_lib.Layout +Mesh = layout_lib.Mesh + + +class DTensorSPMDTest(test_util.DTensorBaseTest): + + def setUp(self): + super().setUp() + + self.skipForDeviceType(["GPU", "TPU"], + "SparseTensors only supported on CPU.") + + global_ids = test_util.create_device_ids_array((2, 2)) + local_ids = np.ravel(global_ids).tolist() + mesh_dict = { + device: Mesh( + [_BATCH_DIM, _MESH_DIM_X], + global_ids, + local_ids, + test_util.create_device_list((2, 2), device), + ) + for device in ("CPU", "GPU", "TPU") + } + self.mesh = self.configTestMesh(mesh_dict) + + @parameterized.parameters( + [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64] + ) + def testIdentityOpWithSparseTensorInputSimple(self, dtype): + inputs = array_ops.ones([6, 4], dtype=dtype) + layout = Layout.batch_sharded(self.mesh, _BATCH_DIM, rank=2) + + @polymorphic_function.function + def f(x): + return array_ops.identity(x) + + self.assertDTensorEqual( + inputs, layout, + f(numpy_util.pack_numpy(inputs, layout, make_sparse=True))) + + @parameterized.product( + dtype=[dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64], + is_sparse_a=[True, False], + is_sparse_b=[True, False], + ) + def testIdentityOpWithSparseTensorInputComplex(self, dtype, is_sparse_a, + is_sparse_b): + inputs_a = array_ops.ones([2, 1], dtype=dtype) + inputs_b = array_ops.ones([32, 16], dtype=dtype) + + layout_a = Layout.batch_sharded(self.mesh, _BATCH_DIM, rank=2) + layout_b = Layout.replicated(self.mesh, rank=2) + + @polymorphic_function.function + def f(x, y): + return array_ops.identity(x), array_ops.identity(y) + + got_a, got_b = f( + numpy_util.pack_numpy(inputs_a, layout_a, make_sparse=is_sparse_a), + numpy_util.pack_numpy(inputs_b, layout_b, make_sparse=is_sparse_b)) + + self.assertDTensorEqual(inputs_a, layout_a, got_a) + self.assertDTensorEqual(inputs_b, layout_b, got_b) + + def testMultipleIdentityOpFromOneSparseTensor(self): + inputs_a = array_ops.ones([2, 1]) + layout_a = Layout.batch_sharded(self.mesh, _BATCH_DIM, rank=2) + + @polymorphic_function.function + def f(x): + return array_ops.identity(x), array_ops.identity(x) + + got_a, got_b = f( + numpy_util.pack_numpy(inputs_a, layout_a, make_sparse=True)) + + self.assertDTensorEqual(inputs_a, layout_a, got_a) + self.assertDTensorEqual(inputs_a, layout_a, got_b) + + @parameterized.product( + is_sparse_a=[True, False], + is_sparse_b=[True, False], + shard_type=["Replicated", "Sharded"]) + def testSparseTensorDenseMatMul(self, is_sparse_a, is_sparse_b, shard_type): + inputs_a = array_ops.ones([16, 16]) + inputs_b = array_ops.ones([16, 16]) + + if shard_type == "Replicated": + layout_a = Layout.replicated(self.mesh, rank=2) + layout_b = Layout.replicated(self.mesh, rank=2) + else: + layout_a = Layout([_MESH_DIM_X, _BATCH_DIM], self.mesh) + layout_b = Layout(["unsharded", _MESH_DIM_X], self.mesh) + + expected = math_ops.matmul(inputs_a, inputs_b) + + @polymorphic_function.function + def f(x, y): + return math_ops.matmul(x, y) + + got = f( + numpy_util.pack_numpy(inputs_a, layout_a, make_sparse=is_sparse_a), + numpy_util.pack_numpy(inputs_b, layout_b, make_sparse=is_sparse_b)) + + self.assertDTensorEqual(expected, Layout.replicated(self.mesh, rank=2), got) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/dtensor/python/tests/tpu_device_assignment_test.py b/tensorflow/dtensor/python/tests/tpu_device_assignment_test.py new file mode 100644 index 00000000000000..08ece48382e52b --- /dev/null +++ 
b/tensorflow/dtensor/python/tests/tpu_device_assignment_test.py @@ -0,0 +1,889 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for TPU device assignment.""" + +from tensorflow.dtensor.python import accelerator_util +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python import numpy_util +from tensorflow.dtensor.python import tpu_util +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +Layout = layout_lib.Layout +Mesh = layout_lib.Mesh + + +class DeviceAssignmentTest(test_util.DTensorBaseTest): + + def setUp(self): + super().setUp() + accelerator_util.initialize_accelerator_system('TPU') + + def tearDown(self): + accelerator_util.shutdown_accelerator_system() + super().tearDown() + + def _build_all_reduce_ring(self, core_locations): + permutation = tpu_util._build_all_reduce_ring(core_locations) + return [core_locations[element] for element in permutation] + + # Picture of chips: + # 0 -- 1 + # | | + # 3 -- 2 + def testBuildAllReduceRing4Replicas(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 0), + ] + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 0), + ] + result = self._build_all_reduce_ring(core_locations) + self.assertAllEqual(result, expected) + + # Picture of chips with core0/core1 assignments: + # 0/1 -- 2/3 + # | | + # 6/7 -- 4/5 + def testBuildAllReduceRing8ReplicasUsingTwoCores(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 1), + ] + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + ] + result = self._build_all_reduce_ring(core_locations) + self.assertAllEqual(result, expected) + + # Picture of chips: + # 0 -- 1 -- 2 -- 3 + # | | + # 15 6 -- 5 -- 4 + # | | + # 14 7 -- 8 -- 9 + # | | + # 13-- 12-- 11-- 10 + def testBuildAllReduceRing32Replicas(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), 
+ tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + ] + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + ] + result = self._build_all_reduce_ring(core_locations) + self.assertAllEqual(result, expected) + + # Picture of chips: + # 7 -- 0 6 -- 5 + # | | + # 2 -- 1 3 -- 4 + def testBuildAllReduceRing3D(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 1), + tpu_util._CoreLocation(0, 1, 1, 0), + tpu_util._CoreLocation(0, 1, 1, 1), + tpu_util._CoreLocation(1, 0, 1, 0), + tpu_util._CoreLocation(1, 0, 1, 1), + tpu_util._CoreLocation(1, 1, 1, 0), + tpu_util._CoreLocation(1, 1, 1, 1), + ] + expected = [ + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(0, 1, 1, 1), + tpu_util._CoreLocation(0, 1, 1, 0), + tpu_util._CoreLocation(1, 1, 1, 1), + tpu_util._CoreLocation(1, 1, 1, 0), + 
tpu_util._CoreLocation(1, 0, 1, 1), + tpu_util._CoreLocation(1, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 1), + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + ] + result = self._build_all_reduce_ring(core_locations) + self.assertAllEqual(result, expected) + + # Picture of chips: + # 31-- 0 -- 1 -- 2 30--29--28--27 + # | | + # 14 5 -- 4 -- 3 15 24--25--26 + # | | | | + # 13 6 -- 7 -- 8 16 23--22--21 + # | | | | + # 12-- 11-- 10-- 9 17--18--19--20 + def testBuildAllReduceRing3DLarge(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + tpu_util._CoreLocation(0, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 1), + tpu_util._CoreLocation(1, 0, 1, 0), + tpu_util._CoreLocation(1, 0, 1, 1), + tpu_util._CoreLocation(2, 0, 1, 0), + tpu_util._CoreLocation(2, 0, 1, 1), + tpu_util._CoreLocation(3, 0, 1, 0), + tpu_util._CoreLocation(3, 0, 1, 1), + tpu_util._CoreLocation(0, 1, 1, 0), + tpu_util._CoreLocation(0, 1, 1, 1), + tpu_util._CoreLocation(1, 1, 1, 0), + tpu_util._CoreLocation(1, 1, 1, 1), + tpu_util._CoreLocation(2, 1, 1, 0), + tpu_util._CoreLocation(2, 1, 1, 1), + tpu_util._CoreLocation(3, 1, 1, 0), + tpu_util._CoreLocation(3, 1, 1, 1), + tpu_util._CoreLocation(0, 2, 1, 0), + tpu_util._CoreLocation(0, 2, 1, 1), + tpu_util._CoreLocation(1, 2, 1, 0), + tpu_util._CoreLocation(1, 2, 1, 1), + tpu_util._CoreLocation(2, 2, 1, 0), + tpu_util._CoreLocation(2, 2, 1, 1), + tpu_util._CoreLocation(3, 2, 1, 0), + tpu_util._CoreLocation(3, 2, 1, 1), + tpu_util._CoreLocation(0, 3, 1, 0), + tpu_util._CoreLocation(0, 3, 1, 1), + tpu_util._CoreLocation(1, 3, 1, 0), + tpu_util._CoreLocation(1, 3, 1, 1), + tpu_util._CoreLocation(2, 3, 1, 0), + tpu_util._CoreLocation(2, 3, 1, 1), + tpu_util._CoreLocation(3, 3, 1, 0), + tpu_util._CoreLocation(3, 3, 1, 1), + ] + expected = [ + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + 
tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(0, 1, 1, 1), + tpu_util._CoreLocation(0, 1, 1, 0), + tpu_util._CoreLocation(0, 2, 1, 1), + tpu_util._CoreLocation(0, 2, 1, 0), + tpu_util._CoreLocation(0, 3, 1, 1), + tpu_util._CoreLocation(0, 3, 1, 0), + tpu_util._CoreLocation(1, 3, 1, 1), + tpu_util._CoreLocation(1, 3, 1, 0), + tpu_util._CoreLocation(2, 3, 1, 1), + tpu_util._CoreLocation(2, 3, 1, 0), + tpu_util._CoreLocation(3, 3, 1, 1), + tpu_util._CoreLocation(3, 3, 1, 0), + tpu_util._CoreLocation(3, 2, 1, 1), + tpu_util._CoreLocation(3, 2, 1, 0), + tpu_util._CoreLocation(2, 2, 1, 1), + tpu_util._CoreLocation(2, 2, 1, 0), + tpu_util._CoreLocation(1, 2, 1, 1), + tpu_util._CoreLocation(1, 2, 1, 0), + tpu_util._CoreLocation(1, 1, 1, 1), + tpu_util._CoreLocation(1, 1, 1, 0), + tpu_util._CoreLocation(2, 1, 1, 1), + tpu_util._CoreLocation(2, 1, 1, 0), + tpu_util._CoreLocation(3, 1, 1, 1), + tpu_util._CoreLocation(3, 1, 1, 0), + tpu_util._CoreLocation(3, 0, 1, 1), + tpu_util._CoreLocation(3, 0, 1, 0), + tpu_util._CoreLocation(2, 0, 1, 1), + tpu_util._CoreLocation(2, 0, 1, 0), + tpu_util._CoreLocation(1, 0, 1, 1), + tpu_util._CoreLocation(1, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 1), + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + ] + result = self._build_all_reduce_ring(core_locations) + self.assertAllEqual(result, expected) + + # Picture of chips: + # 0 -- 1 4 -- 5 + # | | | | + # 3 -- 2 7 -- 6 + # + # 12-- 13 8 -- 9 + # | | | | + # 15-- 14 11-- 10 + def testBuildOrthogonalAllReduceRings(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + ] + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + 
tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + ] + result = tpu_util._build_orthogonal_rings( + core_locations, ring_size=8, rotate_ring_across_rings=False) + self.assertAllEqual(result, expected) + + # Picture of chips: + # 0 -- 1 12 -- 13 + # | | | | + # 3 -- 2 15 -- 14 + # + # 4 -- 5 8 -- 9 + # | | | | + # 7 -- 6 11-- 10 + def testBuildOrthogonalRotatedAllReduceRings(self): + core_locations = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + ] + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 
3, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + ] + result = tpu_util._build_orthogonal_rings( + core_locations, ring_size=8, rotate_ring_across_rings=True) + self.assertAllEqual(result, expected) + + # Create a 4x8 mesh on a 4x4 DF slice, disallowing splitting hosts. + def testCreateDFMeshNoSplittingHosts(self): + result = tpu_util._enumerate_core_locations( + [4, 4, 1, 2], [4, 4, 1, 2], ['core', 'y', 'z', 'x'], + can_split_host_across_rings=False, + ring_size=8) + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + ] + self.assertAllEqual(result, expected) + + # Create a 4x8 mesh on a 4x4 DF slice with at most 2, 2, 1, 2 devices from + # each dimension, disallowing splitting hosts. 
+ def testCreateDFMeshWithRingBoundsNoSplittingHosts(self): + result = tpu_util._enumerate_core_locations( + [4, 4, 1, 2], [2, 2, 1, 2], ['core', 'x', 'y', 'z'], + can_split_host_across_rings=False, + ring_size=8) + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + ] + self.assertAllEqual(result, expected) + + # Create a 4x8 mesh on a 4x4 DF slice, allowing splitting hosts. + def testCreateDFMeshSplittingHosts(self): + result = tpu_util._enumerate_core_locations( + [4, 4, 1, 2], [4, 4, 1, 2], ['core', 'y', 'z', 'x'], + can_split_host_across_rings=True, + ring_size=8) + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + ] + self.assertAllEqual(result, expected) + + # Create a 2x64 mesh on a 4x4x4 PF slice, allowing splitting hosts. 
+ def testCreateMeshPFSplittingHosts(self): + result = tpu_util._enumerate_core_locations( + [4, 4, 4, 2], [4, 4, 4, 2], ['core', 'x', 'y', 'z'], + can_split_host_across_rings=True, + ring_size=64) + expected = [ + tpu_util._CoreLocation(0, 0, 0, 0), + tpu_util._CoreLocation(0, 0, 0, 1), + tpu_util._CoreLocation(1, 0, 0, 0), + tpu_util._CoreLocation(1, 0, 0, 1), + tpu_util._CoreLocation(2, 0, 0, 0), + tpu_util._CoreLocation(2, 0, 0, 1), + tpu_util._CoreLocation(3, 0, 0, 0), + tpu_util._CoreLocation(3, 0, 0, 1), + tpu_util._CoreLocation(0, 1, 0, 0), + tpu_util._CoreLocation(0, 1, 0, 1), + tpu_util._CoreLocation(1, 1, 0, 0), + tpu_util._CoreLocation(1, 1, 0, 1), + tpu_util._CoreLocation(2, 1, 0, 0), + tpu_util._CoreLocation(2, 1, 0, 1), + tpu_util._CoreLocation(3, 1, 0, 0), + tpu_util._CoreLocation(3, 1, 0, 1), + tpu_util._CoreLocation(0, 2, 0, 0), + tpu_util._CoreLocation(0, 2, 0, 1), + tpu_util._CoreLocation(1, 2, 0, 0), + tpu_util._CoreLocation(1, 2, 0, 1), + tpu_util._CoreLocation(2, 2, 0, 0), + tpu_util._CoreLocation(2, 2, 0, 1), + tpu_util._CoreLocation(3, 2, 0, 0), + tpu_util._CoreLocation(3, 2, 0, 1), + tpu_util._CoreLocation(0, 3, 0, 0), + tpu_util._CoreLocation(0, 3, 0, 1), + tpu_util._CoreLocation(1, 3, 0, 0), + tpu_util._CoreLocation(1, 3, 0, 1), + tpu_util._CoreLocation(2, 3, 0, 0), + tpu_util._CoreLocation(2, 3, 0, 1), + tpu_util._CoreLocation(3, 3, 0, 0), + tpu_util._CoreLocation(3, 3, 0, 1), + tpu_util._CoreLocation(0, 0, 1, 0), + tpu_util._CoreLocation(0, 0, 1, 1), + tpu_util._CoreLocation(1, 0, 1, 0), + tpu_util._CoreLocation(1, 0, 1, 1), + tpu_util._CoreLocation(2, 0, 1, 0), + tpu_util._CoreLocation(2, 0, 1, 1), + tpu_util._CoreLocation(3, 0, 1, 0), + tpu_util._CoreLocation(3, 0, 1, 1), + tpu_util._CoreLocation(0, 1, 1, 0), + tpu_util._CoreLocation(0, 1, 1, 1), + tpu_util._CoreLocation(1, 1, 1, 0), + tpu_util._CoreLocation(1, 1, 1, 1), + tpu_util._CoreLocation(2, 1, 1, 0), + tpu_util._CoreLocation(2, 1, 1, 1), + tpu_util._CoreLocation(3, 1, 1, 0), + tpu_util._CoreLocation(3, 1, 1, 1), + tpu_util._CoreLocation(0, 2, 1, 0), + tpu_util._CoreLocation(0, 2, 1, 1), + tpu_util._CoreLocation(1, 2, 1, 0), + tpu_util._CoreLocation(1, 2, 1, 1), + tpu_util._CoreLocation(2, 2, 1, 0), + tpu_util._CoreLocation(2, 2, 1, 1), + tpu_util._CoreLocation(3, 2, 1, 0), + tpu_util._CoreLocation(3, 2, 1, 1), + tpu_util._CoreLocation(0, 3, 1, 0), + tpu_util._CoreLocation(0, 3, 1, 1), + tpu_util._CoreLocation(1, 3, 1, 0), + tpu_util._CoreLocation(1, 3, 1, 1), + tpu_util._CoreLocation(2, 3, 1, 0), + tpu_util._CoreLocation(2, 3, 1, 1), + tpu_util._CoreLocation(3, 3, 1, 0), + tpu_util._CoreLocation(3, 3, 1, 1), + tpu_util._CoreLocation(0, 0, 2, 0), + tpu_util._CoreLocation(0, 0, 2, 1), + tpu_util._CoreLocation(1, 0, 2, 0), + tpu_util._CoreLocation(1, 0, 2, 1), + tpu_util._CoreLocation(2, 0, 2, 0), + tpu_util._CoreLocation(2, 0, 2, 1), + tpu_util._CoreLocation(3, 0, 2, 0), + tpu_util._CoreLocation(3, 0, 2, 1), + tpu_util._CoreLocation(0, 1, 2, 0), + tpu_util._CoreLocation(0, 1, 2, 1), + tpu_util._CoreLocation(1, 1, 2, 0), + tpu_util._CoreLocation(1, 1, 2, 1), + tpu_util._CoreLocation(2, 1, 2, 0), + tpu_util._CoreLocation(2, 1, 2, 1), + tpu_util._CoreLocation(3, 1, 2, 0), + tpu_util._CoreLocation(3, 1, 2, 1), + tpu_util._CoreLocation(0, 2, 2, 0), + tpu_util._CoreLocation(0, 2, 2, 1), + tpu_util._CoreLocation(1, 2, 2, 0), + tpu_util._CoreLocation(1, 2, 2, 1), + tpu_util._CoreLocation(2, 2, 2, 0), + tpu_util._CoreLocation(2, 2, 2, 1), + tpu_util._CoreLocation(3, 2, 2, 0), + tpu_util._CoreLocation(3, 2, 2, 
1), + tpu_util._CoreLocation(0, 3, 2, 0), + tpu_util._CoreLocation(0, 3, 2, 1), + tpu_util._CoreLocation(1, 3, 2, 0), + tpu_util._CoreLocation(1, 3, 2, 1), + tpu_util._CoreLocation(2, 3, 2, 0), + tpu_util._CoreLocation(2, 3, 2, 1), + tpu_util._CoreLocation(3, 3, 2, 0), + tpu_util._CoreLocation(3, 3, 2, 1), + tpu_util._CoreLocation(0, 0, 3, 0), + tpu_util._CoreLocation(0, 0, 3, 1), + tpu_util._CoreLocation(1, 0, 3, 0), + tpu_util._CoreLocation(1, 0, 3, 1), + tpu_util._CoreLocation(2, 0, 3, 0), + tpu_util._CoreLocation(2, 0, 3, 1), + tpu_util._CoreLocation(3, 0, 3, 0), + tpu_util._CoreLocation(3, 0, 3, 1), + tpu_util._CoreLocation(0, 1, 3, 0), + tpu_util._CoreLocation(0, 1, 3, 1), + tpu_util._CoreLocation(1, 1, 3, 0), + tpu_util._CoreLocation(1, 1, 3, 1), + tpu_util._CoreLocation(2, 1, 3, 0), + tpu_util._CoreLocation(2, 1, 3, 1), + tpu_util._CoreLocation(3, 1, 3, 0), + tpu_util._CoreLocation(3, 1, 3, 1), + tpu_util._CoreLocation(0, 2, 3, 0), + tpu_util._CoreLocation(0, 2, 3, 1), + tpu_util._CoreLocation(1, 2, 3, 0), + tpu_util._CoreLocation(1, 2, 3, 1), + tpu_util._CoreLocation(2, 2, 3, 0), + tpu_util._CoreLocation(2, 2, 3, 1), + tpu_util._CoreLocation(3, 2, 3, 0), + tpu_util._CoreLocation(3, 2, 3, 1), + tpu_util._CoreLocation(0, 3, 3, 0), + tpu_util._CoreLocation(0, 3, 3, 1), + tpu_util._CoreLocation(1, 3, 3, 0), + tpu_util._CoreLocation(1, 3, 3, 1), + tpu_util._CoreLocation(2, 3, 3, 0), + tpu_util._CoreLocation(2, 3, 3, 1), + tpu_util._CoreLocation(3, 3, 3, 0), + tpu_util._CoreLocation(3, 3, 3, 1), + ] + self.assertAllEqual(result, expected) + + def testCreateMeshNoSplittingHostsUnfulfillable(self): + with self.assertRaises(ValueError): + tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_unfulfillable_without_splitting_hosts', + can_split_host_across_rings=False) + + def testCreateMeshWithDefaultOptions(self): + mesh = tpu_util.create_tpu_mesh(['x'], [2], 'mesh_with_default_options') + self.assertAllEqual(mesh.shape(), [2]) + self.assertEqual(mesh.num_local_devices(), 2) + + def testCreateMeshWithWrongShape(self): + with self.assertRaises(ValueError): + tpu_util.create_tpu_mesh(['x'], [1], 'mesh_with_wrong_shape') + + # Build rings for the batch dimension. + def testCreateMeshWithPositiveRingDims(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_with_positive_ring_dims', + ring_dims=1) + self.assertAllEqual(mesh.shape(), [2, 1]) + self.assertEqual(mesh.num_local_devices(), 2) + + # Build rings for all non-batch dimensions. + def testCreateMeshWithNegativeRingDims(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y', 'z'], [1, 2, 1], + 'mesh_with_negative_ring_dims', + ring_dims=-2) + self.assertAllEqual(mesh.shape(), [1, 2, 1]) + self.assertEqual(mesh.num_local_devices(), 2) + + # Build single-core rings. + def testCreateMeshWithZeroRingDims(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_with_zero_ring_dims', + ring_dims=0) + self.assertAllEqual(mesh.shape(), [2, 1]) + self.assertEqual(mesh.num_local_devices(), 2) + + def testCreateMeshWithCustomAxes(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_with_custom_axes', + ring_axes=['x', 'z', 'y', 'core']) + self.assertAllEqual(mesh.shape(), [2, 1]) + self.assertEqual(mesh.num_local_devices(), 2) + + # More cores (2 cores) on the first axis (core) than ring size (1). 
+ def testCreateMeshWithDividedAxis(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_with_divided_axis', + ring_dims=-1, + ring_axes=['core', 'z', 'y', 'x']) + self.assertAllEqual(mesh.shape(), [2, 1]) + self.assertEqual(mesh.num_local_devices(), 2) + + # Both meshes should produce the same result despite different `ring_dim`. + def testCreateMultipleMeshes(self): + a = constant_op.constant([[0, 1], [2, 3]], dtype=dtypes.int32) + b_expected = math_ops.reduce_sum(a) + + mesh_1 = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], 'mesh_1', ring_dims=1) + a_1 = numpy_util.pack_numpy(a, Layout(['x', 'y'], mesh_1)) + b_1 = math_ops.reduce_sum(a_1) + self.assertDTensorEqual(b_expected, Layout.replicated(mesh_1, rank=0), b_1) + + mesh_2 = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_2', + ring_dims=-1) + a_2 = numpy_util.pack_numpy(a, Layout(['x', 'y'], mesh_2)) + b_2 = math_ops.reduce_sum(a_2) + self.assertDTensorEqual(b_expected, Layout.replicated(mesh_2, rank=0), b_2) + + def testCreateMeshWithEmptyName(self): + tpu_util.create_tpu_mesh(['x'], [2], '') + + def testCreateMeshWithExistingName(self): + tpu_util.create_tpu_mesh(['x'], [2], 'mesh_with_existing_name') + with self.assertRaises(ValueError): + tpu_util.create_tpu_mesh(['x'], [2], 'mesh_with_existing_name') + + def testGetDeviceIDs(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_to_get_device_ids') + self.assertAllEqual(tpu_util.get_device_ids(mesh), [0, 1]) + + def testGetDeviceLocations(self): + mesh = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], + 'mesh_to_get_device_locations') + self.assertAllEqual( + tpu_util.get_device_locations(mesh), [{ + 'x': 0, + 'y': 0 + }, { + 'x': 1, + 'y': 0 + }]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/dtensor/python/tpu_util.py b/tensorflow/dtensor/python/tpu_util.py index 317d49e37ab59b..a191eac8a6532c 100644 --- a/tensorflow/dtensor/python/tpu_util.py +++ b/tensorflow/dtensor/python/tpu_util.py @@ -142,7 +142,8 @@ def _shutdown_tpu_system(): def tpu_system_init_helper(task_id, num_tasks, num_devices, - use_tfrt_host_runtime=True): + use_tfrt_host_runtime=True, + use_megacore=False): """A helper function to initialize multi-client tpu system.""" @def_function.function @@ -156,6 +157,10 @@ def _set_global_tpu_array_fn(topology_proto): with ops.device("/job:" + config.full_job_name() + "/device:TPU_SYSTEM:0"): # pylint: disable=protected-access my_core_ids = _tpu_init_fn() + + if use_megacore: + logging.info("Using TPU megacore") + my_core_ids = my_core_ids * 2 logging.info("TPU core IDs: %s", my_core_ids) # `my_core_ids` contains the IDs of TPU cores attached to this host. 
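The hunk above threads a new `use_megacore` flag through `tpu_system_init_helper` (when set, the reported core IDs are multiplied by two), and the following hunk exposes the same flag on `initialize_tpu_system`. A minimal usage sketch of the APIs exercised by the tests earlier in this change, assuming a DTensor-enabled multi-client TPU job with two local TPU cores (the function and parameter names come from this patch; the surrounding setup is illustrative only, not an official example):

# Illustrative sketch only; requires a TPU runtime with DTensor support.
from tensorflow.dtensor.python import tpu_util

# Initialize the TPU system. Passing use_megacore=True exercises the new
# code path above, which scales the reported core IDs by two.
tpu_util.initialize_tpu_system(use_megacore=False)

# Build a 2x1 mesh with rings on the leading ('x') dimension, mirroring
# testCreateMeshWithPositiveRingDims.
mesh = tpu_util.create_tpu_mesh(['x', 'y'], [2, 1], 'example_mesh',
                                ring_dims=1)
print(mesh.shape())              # [2, 1]
print(mesh.num_local_devices())  # 2 on a single two-core host

Per the test comments above, a positive `ring_dims` builds rings over the leading mesh dimensions, while a negative value builds them over the trailing (non-batch) dimensions.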
@@ -240,7 +245,7 @@ def _set_global_tpu_array_fn(topology_proto): return tpu_topology, device -def initialize_tpu_system(): +def initialize_tpu_system(use_megacore=False): """Initializes the TPU system.""" # Make sure the server change is fully propagated before attempting to run @@ -260,7 +265,8 @@ def initialize_tpu_system(): task_id, num_tasks, num_devices, - use_tfrt_host_runtime=use_tfrt_host_runtime) + use_tfrt_host_runtime=use_tfrt_host_runtime, + use_megacore=use_megacore) global _tpu_topology _tpu_topology = tpu_topology logging.vlog(1, "TPU Topology: %s, %s", tpu_topology.mesh_shape, diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index dffdc7d1c386ac..1c1bc5a47fac85 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -2515,6 +2515,22 @@ func BatchMatMulAdjY(value bool) BatchMatMulAttr { } } +// BatchMatMulGradX sets the optional grad_x attribute to value. +// If not specified, defaults to false +func BatchMatMulGradX(value bool) BatchMatMulAttr { + return func(m optionalAttr) { + m["grad_x"] = value + } +} + +// BatchMatMulGradY sets the optional grad_y attribute to value. +// If not specified, defaults to false +func BatchMatMulGradY(value bool) BatchMatMulAttr { + return func(m optionalAttr) { + m["grad_y"] = value + } +} + // Multiplies slices of two tensors in batches. // // Multiplies all slices of `Tensor` `x` and `y` (each slice can be @@ -2584,6 +2600,22 @@ func BatchMatMulV2AdjY(value bool) BatchMatMulV2Attr { } } +// BatchMatMulV2GradX sets the optional grad_x attribute to value. +// If not specified, defaults to false +func BatchMatMulV2GradX(value bool) BatchMatMulV2Attr { + return func(m optionalAttr) { + m["grad_x"] = value + } +} + +// BatchMatMulV2GradY sets the optional grad_y attribute to value. +// If not specified, defaults to false +func BatchMatMulV2GradY(value bool) BatchMatMulV2Attr { + return func(m optionalAttr) { + m["grad_y"] = value + } +} + // Multiplies slices of two tensors in batches. // // Multiplies all slices of `Tensor` `x` and `y` (each slice can be @@ -2657,6 +2689,22 @@ func BatchMatMulV3AdjY(value bool) BatchMatMulV3Attr { } } +// BatchMatMulV3GradX sets the optional grad_x attribute to value. +// If not specified, defaults to false +func BatchMatMulV3GradX(value bool) BatchMatMulV3Attr { + return func(m optionalAttr) { + m["grad_x"] = value + } +} + +// BatchMatMulV3GradY sets the optional grad_y attribute to value. +// If not specified, defaults to false +func BatchMatMulV3GradY(value bool) BatchMatMulV3Attr { + return func(m optionalAttr) { + m["grad_y"] = value + } +} + // Multiplies slices of two tensors in batches. // // Multiplies all slices of `Tensor` `x` and `y` (each slice can be @@ -24661,6 +24709,22 @@ func MatMulTransposeB(value bool) MatMulAttr { } } +// MatMulGradA sets the optional grad_a attribute to value. +// If not specified, defaults to false +func MatMulGradA(value bool) MatMulAttr { + return func(m optionalAttr) { + m["grad_a"] = value + } +} + +// MatMulGradB sets the optional grad_b attribute to value. +// If not specified, defaults to false +func MatMulGradB(value bool) MatMulAttr { + return func(m optionalAttr) { + m["grad_b"] = value + } +} + // Multiply the matrix "a" by the matrix "b". 
// // The inputs must be two-dimensional matrices and the inner dimension of diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 4c42cccbbb82b2..d1d3d6bb123958 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -220,7 +220,6 @@ cc_library( copts = tflite_copts_warnings(), deps = [ ":graph_info", - ":kernel_api", ":memory_planner", ":simple_memory_arena", ":util", @@ -237,7 +236,6 @@ cc_library( copts = tflite_copts_warnings() + ["-DTF_LITE_TENSORFLOW_PROFILER"], deps = [ ":graph_info", - ":kernel_api", ":memory_planner", ":simple_memory_arena_with_profiler", ":util", @@ -256,9 +254,10 @@ cc_test( ":arena_planner_with_profiler", ":builtin_ops", ":graph_info", - "//tensorflow/core:tflite_portable_logging", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core/c:common", - "//tensorflow/lite/testing:util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest_main", ], ) @@ -1053,7 +1052,6 @@ cc_test( deps = [ ":simple_memory_arena", "//tensorflow/lite/core/c:common", - "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/acceleration/configuration/BUILD b/tensorflow/lite/acceleration/configuration/BUILD index 4f1b2fe568cb97..6b6aab6638d033 100644 --- a/tensorflow/lite/acceleration/configuration/BUILD +++ b/tensorflow/lite/acceleration/configuration/BUILD @@ -277,11 +277,8 @@ cc_library( "//conditions:default": [], }), visibility = [ - "//tensorflow/lite/acceleration/configuration/c:__pkg__", "//tensorflow/lite/core/acceleration/configuration/c:__pkg__", - "//tensorflow/lite/core/experimental/acceleration/configuration/c:__pkg__", "//tensorflow/lite/experimental/acceleration/configuration:__pkg__", - "//tensorflow/lite/experimental/acceleration/configuration/c:__pkg__", ], deps = [ ":configuration_fbs", diff --git a/tensorflow/lite/acceleration/configuration/c/delegate_plugin.h b/tensorflow/lite/acceleration/configuration/c/delegate_plugin.h index 3c186b0345b741..46c20dbfb40f93 100644 --- a/tensorflow/lite/acceleration/configuration/c/delegate_plugin.h +++ b/tensorflow/lite/acceleration/configuration/c/delegate_plugin.h @@ -15,6 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_DELEGATE_PLUGIN_H_ #define TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_DELEGATE_PLUGIN_H_ +/// For documentation, see +/// third_party/tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h + #include "tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h" #endif // TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_DELEGATE_PLUGIN_H_ diff --git a/tensorflow/lite/acceleration/configuration/c/gpu_plugin.h b/tensorflow/lite/acceleration/configuration/c/gpu_plugin.h index 8a8202ef2bd1f6..6c83d2b2bded37 100644 --- a/tensorflow/lite/acceleration/configuration/c/gpu_plugin.h +++ b/tensorflow/lite/acceleration/configuration/c/gpu_plugin.h @@ -15,6 +15,9 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_GPU_PLUGIN_H_ #define TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_GPU_PLUGIN_H_ +/// For documentation, see +/// third_party/tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h + #include "tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h" #endif // TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_GPU_PLUGIN_H_ diff --git a/tensorflow/lite/acceleration/configuration/c/nnapi_plugin.h b/tensorflow/lite/acceleration/configuration/c/nnapi_plugin.h index 74be5f9b3a96dc..f2406e8311860b 100644 --- a/tensorflow/lite/acceleration/configuration/c/nnapi_plugin.h +++ b/tensorflow/lite/acceleration/configuration/c/nnapi_plugin.h @@ -15,6 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_NNAPI_PLUGIN_H_ #define TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_NNAPI_PLUGIN_H_ +/// For documentation, see +/// third_party/tensorflow/lite/core/acceleration/configuration/c/nnapi_plugin.h + #include "tensorflow/lite/core/acceleration/configuration/c/nnapi_plugin.h" #endif // TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_NNAPI_PLUGIN_H_ diff --git a/tensorflow/lite/acceleration/configuration/c/stable_delegate.h b/tensorflow/lite/acceleration/configuration/c/stable_delegate.h index 2b34a32f4bf611..f3589c58cc9562 100644 --- a/tensorflow/lite/acceleration/configuration/c/stable_delegate.h +++ b/tensorflow/lite/acceleration/configuration/c/stable_delegate.h @@ -15,6 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_STABLE_DELEGATE_H_ #define TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_STABLE_DELEGATE_H_ +/// For documentation, see +/// third_party/tensorflow/lite/core/acceleration/configuration/c/stable_delegate.h + #include "tensorflow/lite/core/acceleration/configuration/c/stable_delegate.h" #endif // TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_STABLE_DELEGATE_H_ diff --git a/tensorflow/lite/acceleration/configuration/c/xnnpack_plugin.h b/tensorflow/lite/acceleration/configuration/c/xnnpack_plugin.h index 9ced18f3dc5a86..ae44009e4b816e 100644 --- a/tensorflow/lite/acceleration/configuration/c/xnnpack_plugin.h +++ b/tensorflow/lite/acceleration/configuration/c/xnnpack_plugin.h @@ -15,6 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_XNNPACK_PLUGIN_H_ #define TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_XNNPACK_PLUGIN_H_ +/// For documentation, see +/// third_party/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h + #include "tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h" #endif // TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_C_XNNPACK_PLUGIN_H_ diff --git a/tensorflow/lite/acceleration/configuration/delegate_registry.h b/tensorflow/lite/acceleration/configuration/delegate_registry.h index b1064054f30d25..a6ed2b0636b937 100644 --- a/tensorflow/lite/acceleration/configuration/delegate_registry.h +++ b/tensorflow/lite/acceleration/configuration/delegate_registry.h @@ -15,7 +15,10 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_DELEGATE_REGISTRY_H_ #define TENSORFLOW_LITE_ACCELERATION_CONFIGURATION_DELEGATE_REGISTRY_H_ -#include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" +/// For documentation, see +/// third_party/tensorflow/lite/core/acceleration/configuration/delegate_registry.h + +#include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" // IWYU pragma: export namespace tflite { namespace delegates { diff --git a/tensorflow/lite/arena_planner.cc b/tensorflow/lite/arena_planner.cc index b63c682d7cb046..8fd1a794369b50 100644 --- a/tensorflow/lite/arena_planner.cc +++ b/tensorflow/lite/arena_planner.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include -#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/graph_info.h" #include "tensorflow/lite/simple_memory_arena.h" @@ -42,6 +41,7 @@ ArenaPlanner::ArenaPlanner(TfLiteContext* context, : context_(context), graph_info_(std::move(graph_info)), arena_(kDefaultArenaAlignment, subgraph_index), + has_nonpersistent_memory_(false), persistent_arena_(kDefaultArenaAlignment, subgraph_index), preserve_all_tensors_(preserve_all_tensors), tensor_alignment_(tensor_alignment), @@ -380,6 +380,7 @@ TfLiteStatus ArenaPlanner::ExecuteAllocations(int first_node, int last_node) { TfLiteStatus ArenaPlanner::ReleaseNonPersistentMemory() { // Clear non-persistent arena's buffer. TF_LITE_ENSURE_STATUS(arena_.ReleaseBuffer()); + has_nonpersistent_memory_ = false; // Set data pointers for all non-persistent tensors to nullptr. TfLiteTensor* tensors = graph_info_->tensors(); for (int i = 0; i < static_cast(graph_info_->num_tensors()); ++i) { @@ -394,7 +395,8 @@ TfLiteStatus ArenaPlanner::ReleaseNonPersistentMemory() { TfLiteStatus ArenaPlanner::AcquireNonPersistentMemory() { // First commit arena_ to allocate underlying buffer. bool reallocated; - TF_LITE_ENSURE_STATUS(arena_.Commit(context_, &reallocated)); + TF_LITE_ENSURE_STATUS(arena_.Commit(&reallocated)); + has_nonpersistent_memory_ = true; // Resolve allocations for all tensors not on the persistent arena. TfLiteTensor* tensors = graph_info_->tensors(); for (int i = 0; i < static_cast(graph_info_->num_tensors()); ++i) { @@ -407,7 +409,7 @@ TfLiteStatus ArenaPlanner::AcquireNonPersistentMemory() { } bool ArenaPlanner::HasNonPersistentMemory() { - return arena_.GetBufferSize() != 0; + return has_nonpersistent_memory_; } void ArenaPlanner::DumpDebugInfo(const std::vector& execution_plan) const { @@ -424,9 +426,10 @@ void ArenaPlanner::GetAllocInfo(size_t* arena_size, TfLiteStatus ArenaPlanner::Commit(bool* reallocated) { bool arena_reallocated, persistent_arena_reallocated; - TF_LITE_ENSURE_STATUS(arena_.Commit(context_, &arena_reallocated)); + TF_LITE_ENSURE_STATUS(arena_.Commit(&arena_reallocated)); + has_nonpersistent_memory_ = true; TF_LITE_ENSURE_STATUS( - persistent_arena_.Commit(context_, &persistent_arena_reallocated)); + persistent_arena_.Commit(&persistent_arena_reallocated)); *reallocated = arena_reallocated; *reallocated |= persistent_arena_reallocated; return kTfLiteOk; diff --git a/tensorflow/lite/arena_planner.h b/tensorflow/lite/arena_planner.h index f8547c352a8fc5..f4644d15986fab 100644 --- a/tensorflow/lite/arena_planner.h +++ b/tensorflow/lite/arena_planner.h @@ -15,6 +15,7 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_ARENA_PLANNER_H_ #define TENSORFLOW_LITE_ARENA_PLANNER_H_ +#include #include #include #include @@ -30,7 +31,6 @@ limitations under the License. namespace tflite { constexpr const int kDefaultArenaAlignment = 64; -struct AllocationInfo; // A memory planner that makes all the allocations using arenas. // @@ -141,6 +141,8 @@ class ArenaPlanner : public MemoryPlanner { // Raw memory buffer that is allocated for all temporary and graph outputs // that are declared kTfLiteArenaRw. SimpleMemoryArena arena_; + // True when the arena_ has allocated memory (Commit was called). + bool has_nonpersistent_memory_; // Raw memory buffer that is allocated for persistent tensors that are // declared as kTfLiteArenaRwPersistent. diff --git a/tensorflow/lite/arena_planner_test.cc b/tensorflow/lite/arena_planner_test.cc index 2a434d734f0ec9..2021ac0797654c 100644 --- a/tensorflow/lite/arena_planner_test.cc +++ b/tensorflow/lite/arena_planner_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -25,11 +26,12 @@ limitations under the License. #include #include -#include "tensorflow/core/platform/logging.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/graph_info.h" -#include "tensorflow/lite/testing/util.h" namespace tflite { @@ -1079,10 +1081,10 @@ TEST_F(ArenaPlannerTest, SimpleProfilerTest) { SetGraph(&graph); Execute(0, graph.nodes().size() - 1); - EXPECT_EQ(gNumAlloc, 2); + EXPECT_EQ(gNumAlloc, 1); EXPECT_EQ(gNumDealloc, 0); Destroy(); - EXPECT_EQ(gNumDealloc, 2); + EXPECT_EQ(gNumDealloc, 1); } } // namespace diff --git a/tensorflow/lite/async/c/async_kernel.h b/tensorflow/lite/async/c/async_kernel.h index f9bc9dcd5866d9..d49c72f4a5342b 100644 --- a/tensorflow/lite/async/c/async_kernel.h +++ b/tensorflow/lite/async/c/async_kernel.h @@ -14,6 +14,7 @@ limitations under the License. /// For documentation, see /// third_party/tensorflow/lite/core/async/c/async_kernel.h. + #include "tensorflow/lite/core/async/c/async_kernel.h" // IWYU pragma: export #endif // TENSORFLOW_LITE_ASYNC_C_ASYNC_KERNEL_H_ diff --git a/tensorflow/lite/async/c/async_signature_runner.h b/tensorflow/lite/async/c/async_signature_runner.h index 84ea7085cc6ced..7eacd0cb8ebfc1 100644 --- a/tensorflow/lite/async/c/async_signature_runner.h +++ b/tensorflow/lite/async/c/async_signature_runner.h @@ -14,6 +14,7 @@ limitations under the License. /// For documentation, see /// third_party/tensorflow/lite/core/async/c/async_signature_runner.h. + #include "tensorflow/lite/core/async/c/async_signature_runner.h" // IWYU pragma: export #endif // TENSORFLOW_LITE_ASYNC_C_ASYNC_SIGNATURE_RUNNER_H_ diff --git a/tensorflow/lite/async/c/task.h b/tensorflow/lite/async/c/task.h index 0fa56b3358302d..891e4183f4514e 100644 --- a/tensorflow/lite/async/c/task.h +++ b/tensorflow/lite/async/c/task.h @@ -12,8 +12,10 @@ limitations under the License. ==============================================================================*/ #ifndef TENSORFLOW_LITE_ASYNC_C_TASK_H_ #define TENSORFLOW_LITE_ASYNC_C_TASK_H_ + /// For documentation, see /// third_party/tensorflow/lite/core/async/c/task.h. 
+ #include "tensorflow/lite/core/async/c/task.h" // IWYU pragma: export #endif // TENSORFLOW_LITE_ASYNC_C_TASK_H_ diff --git a/tensorflow/lite/async/c/types.h b/tensorflow/lite/async/c/types.h index a606c75536b5b1..6b509427111de3 100644 --- a/tensorflow/lite/async/c/types.h +++ b/tensorflow/lite/async/c/types.h @@ -11,7 +11,10 @@ limitations under the License. ==============================================================================*/ #ifndef TENSORFLOW_LITE_ASYNC_C_TYPES_H_ #define TENSORFLOW_LITE_ASYNC_C_TYPES_H_ + /// For documentation, see /// tensorflow/lite/core/async/c/types.h. + #include "tensorflow/lite/core/async/c/types.h" // IWYU pragma: export + #endif // TENSORFLOW_LITE_ASYNC_C_TYPES_H_ diff --git a/tensorflow/lite/async/interop/c/attribute_map.h b/tensorflow/lite/async/interop/c/attribute_map.h index c1b41b6292ccfc..7da44462e99a30 100644 --- a/tensorflow/lite/async/interop/c/attribute_map.h +++ b/tensorflow/lite/async/interop/c/attribute_map.h @@ -17,4 +17,4 @@ limitations under the License. #include "tensorflow/lite/core/async/interop/c/attribute_map.h" // IWYU pragma: export -#endif // TENSORFLOW_LITE_ASYNC_INTEROP_C_ATTRIBUTE_MAP_H_ \ No newline at end of file +#endif // TENSORFLOW_LITE_ASYNC_INTEROP_C_ATTRIBUTE_MAP_H_ diff --git a/tensorflow/lite/async/interop/c/constants.h b/tensorflow/lite/async/interop/c/constants.h index 07365bf9f41dd1..6b151dde5fd3bd 100644 --- a/tensorflow/lite/async/interop/c/constants.h +++ b/tensorflow/lite/async/interop/c/constants.h @@ -12,6 +12,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_ASYNC_INTEROP_C_CONSTANTS_H_ #define TENSORFLOW_LITE_ASYNC_INTEROP_C_CONSTANTS_H_ +/// For documentation, see +/// third_party/tensorflow/lite/core/async/interop/c/constants.h + #include "tensorflow/lite/core/async/interop/c/constants.h" // IWYU pragma: export #endif // TENSORFLOW_LITE_ASYNC_INTEROP_C_CONSTANTS_H_ diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h index 7628e5ad1f9997..0606819288b6e5 100644 --- a/tensorflow/lite/c/builtin_op_data.h +++ b/tensorflow/lite/c/builtin_op_data.h @@ -15,6 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_C_BUILTIN_OP_DATA_H_ #define TENSORFLOW_LITE_C_BUILTIN_OP_DATA_H_ +/// For documentation, see +/// third_party/tensorflow/lite/core/c/builtin_op_data.h + #include "tensorflow/lite/core/c/builtin_op_data.h" #endif // TENSORFLOW_LITE_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/lite/c/c_api_experimental.h b/tensorflow/lite/c/c_api_experimental.h index 2bf6add77f3c02..84cd4b030506af 100644 --- a/tensorflow/lite/c/c_api_experimental.h +++ b/tensorflow/lite/c/c_api_experimental.h @@ -15,6 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_C_C_API_EXPERIMENTAL_H_ #define TENSORFLOW_LITE_C_C_API_EXPERIMENTAL_H_ +/// For documentation, see +/// third_party/tensorflow/lite/core/c/c_api_experimental.h + #include "tensorflow/lite/core/c/c_api_experimental.h" #endif // TENSORFLOW_LITE_C_C_API_EXPERIMENTAL_H_ diff --git a/tensorflow/lite/c/c_api_opaque.h b/tensorflow/lite/c/c_api_opaque.h index 0cafb763f83cdf..7e4d401a46466e 100644 --- a/tensorflow/lite/c/c_api_opaque.h +++ b/tensorflow/lite/c/c_api_opaque.h @@ -15,6 +15,9 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_C_C_API_OPAQUE_H_ #define TENSORFLOW_LITE_C_C_API_OPAQUE_H_ +/// For documentation, see +/// third_party/tensorflow/lite/core/c/c_api_opaque.h + #include "tensorflow/lite/core/c/c_api_opaque.h" #endif // TENSORFLOW_LITE_C_C_API_OPAQUE_H_ diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h index f5a31d3f0cd88c..8a8b51331c476b 100644 --- a/tensorflow/lite/c/common.h +++ b/tensorflow/lite/c/common.h @@ -21,6 +21,9 @@ limitations under the License. /// interpreter and the operations are C. /// /// For documentation, see tensorflow/lite/core/c/common.h. +/// +/// See also c_api_opaque.h which has more ABI-stable variants of some of these +/// APIs. #ifndef TENSORFLOW_LITE_C_COMMON_H_ #define TENSORFLOW_LITE_C_COMMON_H_ diff --git a/tensorflow/lite/c/jni/jni_utils.h b/tensorflow/lite/c/jni/jni_utils.h index a425dcf4788f40..355b7a4a83bbf9 100644 --- a/tensorflow/lite/c/jni/jni_utils.h +++ b/tensorflow/lite/c/jni/jni_utils.h @@ -22,6 +22,14 @@ limitations under the License. extern "C" { #endif +/// Checks whether the TFLite API has been initialized, throwing a Java exception +/// otherwise. +/// +/// @param env The JNIEnv for the current thread (which has to be attached to the +/// JVM). +/// @return Whether or not the TFLite API has been initialized. If this method +/// returns false, no other JNI method should be called until the pending +/// exception has been handled (typically by returning to Java). bool TfLiteCheckInitializedOrThrow(JNIEnv* env); #ifdef __cplusplus diff --git a/tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h b/tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h index 900a2666934186..3f02a3fe267fc1 100644 --- a/tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h +++ b/tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h @@ -12,18 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// NOLINTBEGIN(whitespace/line_length) // WARNING: Users of TensorFlow Lite should not include this file directly, // but should instead include // "third_party/tensorflow/lite/acceleration/configuration/c/delegate_plugin.h". -// Only the TensorFlow Lite implementation itself should include this -// file directly. -// NOLINTEND(whitespace/line_length) +// Only the TensorFlow Lite implementation itself should include this file +// directly. + #ifndef TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_DELEGATE_PLUGIN_H_ #define TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_DELEGATE_PLUGIN_H_ /// C API types for TF Lite delegate plugins. +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// \code +/// #include "tensorflow/lite/acceleration/configuration/c/delegate_plugin.h" +/// \endcode +/// to access the APIs documented on this page. 
+// NOLINTEND(whitespace/line_length) +// clang-format on + #include "tensorflow/lite/core/c/common.h" #ifdef __cplusplus @@ -32,7 +41,7 @@ extern "C" { // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup delegate_plugin tensorflow/lite/acceleration/configuration/c/delegate_plugin.h +/** \defgroup delegate_plugin lite/acceleration/configuration/c/delegate_plugin.h * @{ */ // NOLINTEND(whitespace/line_length) diff --git a/tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h b/tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h index c30ce4dcdf4452..c1e42c935f974a 100644 --- a/tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h +++ b/tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h @@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// NOLINTBEGIN(whitespace/line_length) -// WARNING: Users of TensorFlow Lite should not include this file directly, -// but should instead include +// WARNING: Users of TensorFlow Lite should not include this file directly, but +// should instead include // "third_party/tensorflow/lite/acceleration/configuration/c/gpu_plugin.h". -// Only the TensorFlow Lite implementation itself should include this -// file directly. -// NOLINTEND(whitespace/line_length) +// Only the TensorFlow Lite implementation itself should include this file +// directly. + #ifndef TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_GPU_PLUGIN_H_ #define TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_GPU_PLUGIN_H_ @@ -32,6 +31,16 @@ limitations under the License. /// /// But to provide a C API to access the GPU delegate plugin, we do expose /// some functions, which are declared below. +/// +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// \code +/// #include "tensorflow/lite/acceleration/configuration/c/gpu_plugin.h" +/// \endcode +/// to access the APIs documented on this page. +// NOLINTEND(whitespace/line_length) +// clang-format on #include "tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h" @@ -41,7 +50,7 @@ extern "C" { // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup gpu_plugin tensorflow/lite/acceleration/configuration/c/gpu_plugin.h +/** \defgroup gpu_plugin lite/acceleration/configuration/c/gpu_plugin.h * @{ */ // NOLINTEND(whitespace/line_length) diff --git a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h index fce48ff8622288..d7c51a9b5afc7a 100644 --- a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h +++ b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h @@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// NOLINTBEGIN(whitespace/line_length) -// WARNING: Users of TensorFlow Lite should not include this file directly, -// but should instead include +// WARNING: Users of TensorFlow Lite should not include this file directly, but +// should instead include // "third_party/tensorflow/lite/acceleration/configuration/c/xnnpack_plugin.h". 
-// Only the TensorFlow Lite implementation itself should include this -// file directly. -// NOLINTEND(whitespace/line_length) +// Only the TensorFlow Lite implementation itself should include this file +// directly. + #ifndef TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_XNNPACK_PLUGIN_H_ #define TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_XNNPACK_PLUGIN_H_ @@ -32,6 +31,16 @@ limitations under the License. /// /// But to provide a C API to access the XNNPACK delegate plugin, we do expose /// some functions, which are declared below. +/// +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// \code +/// #include "tensorflow/lite/acceleration/configuration/c/xnnpack_plugin.h" +/// \endcode +/// to access the APIs documented on this page. +// NOLINTEND(whitespace/line_length) +// clang-format on #include "tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h" @@ -41,7 +50,7 @@ extern "C" { // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup xnnpack_plugin tensorflow/lite/acceleration/configuration/c/xnnpack_plugin.h +/** \defgroup xnnpack_plugin lite/acceleration/configuration/c/xnnpack_plugin.h * @{ */ // NOLINTEND(whitespace/line_length) diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index f37e38a9c144fd..8b7f0e522acf21 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -918,6 +918,9 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, *builtin_data = params.release(); return kTfLiteOk; } + case BuiltinOperator_STABLEHLO_PAD: { + return ParseStablehloPad(op, error_reporter, allocator, builtin_data); + } // TODO: skip param parsing for now since ops below don't have kernels case BuiltinOperator_STABLEHLO_SLICE: case BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM: @@ -952,7 +955,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_STABLEHLO_IOTA: case BuiltinOperator_STABLEHLO_COMPARE: case BuiltinOperator_STABLEHLO_CONVERT: - case BuiltinOperator_STABLEHLO_PAD: case BuiltinOperator_STABLEHLO_DOT_GENERAL: case BuiltinOperator_STABLEHLO_SORT: case BuiltinOperator_STABLEHLO_WHILE: @@ -2123,7 +2125,8 @@ TfLiteStatus ParseStablehloReduceWindow(const Operator* op, const size_t rank = schema_params->window_dimensions()->size(); auto LoadAttr = [&error_reporter]( - auto& params_array, auto* const flatbuffer_vector, + int64_t* params_array, size_t params_array_size_bytes, + const flatbuffers::Vector* flatbuffer_vector, const char* attr_name, const size_t expected_size, const int64_t fill_value) -> TfLiteStatus { if (flatbuffer_vector && flatbuffer_vector->size()) { @@ -2136,7 +2139,7 @@ TfLiteStatus ParseStablehloReduceWindow(const Operator* op, return kTfLiteError; } TfLiteStatus status = FlatBufferIntVectorToArray( - sizeof(params_array), flatbuffer_vector, params_array, + params_array_size_bytes, flatbuffer_vector, params_array, error_reporter, "stablehlo.reduce_window"); if (status != kTfLiteOk) { TF_LITE_REPORT_ERROR(error_reporter, "Check the '%s' attribute.", @@ -2144,43 +2147,32 @@ TfLiteStatus ParseStablehloReduceWindow(const Operator* op, return status; } } else { - std::fill_n(params_array, - TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT, + std::fill_n(params_array, params_array_size_bytes / sizeof(int64_t), fill_value); } return kTfLiteOk; }; - if (TfLiteStatus 
status = LoadAttr( - params->window_dimensions, schema_params->window_dimensions(), - "window_dimensions", /*expected_size=*/rank, /*fill_value=*/1); - status != kTfLiteOk) { - return status; - } - if (TfLiteStatus status = LoadAttr( - params->window_strides, schema_params->window_strides(), - "window_strides", /*expected_size=*/rank, /*fill_value=*/1); - status != kTfLiteOk) { - return status; - } - if (TfLiteStatus status = LoadAttr( - params->base_dilations, schema_params->base_dilations(), - "base_dilations", /*expected_size=*/rank, /*fill_value=*/1); - status != kTfLiteOk) { - return status; - } - if (TfLiteStatus status = LoadAttr( - params->window_dilations, schema_params->window_dilations(), - "window_dilations", /*expected_size=*/rank, /*fill_value=*/1); - status != kTfLiteOk) { - return status; - } - if (TfLiteStatus status = - LoadAttr(params->padding, schema_params->padding(), "padding", - /*expected_size=*/2 * rank, /*fill_value=*/0); - status != kTfLiteOk) { - return status; - } + TF_LITE_ENSURE_STATUS( + LoadAttr(params->window_dimensions, sizeof(params->window_dimensions), + schema_params->window_dimensions(), "window_dimensions", + /*expected_size=*/rank, /*fill_value=*/1)); + TF_LITE_ENSURE_STATUS( + LoadAttr(params->window_strides, sizeof(params->window_strides), + schema_params->window_strides(), "window_strides", + /*expected_size=*/rank, /*fill_value=*/1)); + TF_LITE_ENSURE_STATUS( + LoadAttr(params->base_dilations, sizeof(params->base_dilations), + schema_params->base_dilations(), "base_dilations", + /*expected_size=*/rank, /*fill_value=*/1)); + TF_LITE_ENSURE_STATUS( + LoadAttr(params->window_dilations, sizeof(params->window_dilations), + schema_params->window_dilations(), "window_dilations", + /*expected_size=*/rank, /*fill_value=*/1)); + TF_LITE_ENSURE_STATUS(LoadAttr(params->padding, sizeof(params->padding), + schema_params->padding(), "padding", + /*expected_size=*/2 * rank, + /*fill_value=*/0)); params->body_subgraph_index = schema_params->body_subgraph_index(); *builtin_data = params.release(); @@ -2209,27 +2201,34 @@ TfLiteStatus ParseStablehloScatter(const Operator* op, if (schema_params) { params->indices_are_sorted = schema_params->indices_are_sorted(); - TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray( - schema_params->update_window_dims()->size() * sizeof(int64_t), - schema_params->update_window_dims(), params->update_window_dims, - error_reporter, "stablehlo_scatter")); - params->num_update_window_dims = - schema_params->update_window_dims()->size(); + if (schema_params->update_window_dims()) { + TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray( + schema_params->update_window_dims()->size() * sizeof(int64_t), + schema_params->update_window_dims(), params->update_window_dims, + error_reporter, "stablehlo_scatter")); + params->num_update_window_dims = + schema_params->update_window_dims()->size(); + } - TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray( - schema_params->inserted_window_dims()->size() * sizeof(int64_t), - schema_params->inserted_window_dims(), params->inserted_window_dims, - error_reporter, "stablehlo_scatter")); - params->num_inserted_window_dims = - schema_params->inserted_window_dims()->size(); + if (schema_params->inserted_window_dims()) { + TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray( + schema_params->inserted_window_dims()->size() * sizeof(int64_t), + schema_params->inserted_window_dims(), params->inserted_window_dims, + error_reporter, "stablehlo_scatter")); + params->num_inserted_window_dims = + 
schema_params->inserted_window_dims()->size(); + } - TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray( - schema_params->scatter_dims_to_operand_dims()->size() * sizeof(int64_t), - schema_params->scatter_dims_to_operand_dims(), - params->scatter_dims_to_operand_dims, error_reporter, - "stablehlo_scatter")); - params->num_scatter_dims_to_operand_dims = - schema_params->scatter_dims_to_operand_dims()->size(); + if (schema_params->scatter_dims_to_operand_dims()) { + TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray( + schema_params->scatter_dims_to_operand_dims()->size() * + sizeof(int64_t), + schema_params->scatter_dims_to_operand_dims(), + params->scatter_dims_to_operand_dims, error_reporter, + "stablehlo_scatter")); + params->num_scatter_dims_to_operand_dims = + schema_params->scatter_dims_to_operand_dims()->size(); + } params->index_vector_dim = schema_params->index_vector_dim(); params->unique_indices = schema_params->unique_indices(); @@ -2326,6 +2325,59 @@ TfLiteStatus ParseStablehloGather(const Operator* op, return kTfLiteOk; } +TfLiteStatus ParseStablehloPad(const Operator* op, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data) { + CheckParsePointerParams(op, error_reporter, allocator, builtin_data); + + SafeBuiltinDataAllocator safe_allocator(allocator); + auto params = safe_allocator.Allocate(); + const StablehloPadOptions* schema_params = + op->builtin_options_2_as_StablehloPadOptions(); + + if (schema_params) { + auto LoadAttr = + [&error_reporter]( + int64_t* params_array, const size_t params_array_size_bytes, + const flatbuffers::Vector* const flatbuffer_vector, + const char* const attr_name) -> TfLiteStatus { + TfLiteStatus status = FlatBufferIntVectorToArray( + params_array_size_bytes, flatbuffer_vector, params_array, + error_reporter, "stablehlo.pad"); + if (status != kTfLiteOk) { + TF_LITE_REPORT_ERROR(error_reporter, "Check the '%s' attribute.", + attr_name); + } + return status; + }; + + TF_LITE_ENSURE_STATUS( + LoadAttr(params->edge_padding_low, sizeof(params->edge_padding_low), + schema_params->edge_padding_low(), "edge_padding_low")); + TF_LITE_ENSURE_STATUS( + LoadAttr(params->edge_padding_high, sizeof(params->edge_padding_high), + schema_params->edge_padding_high(), "edge_padding_high")); + TF_LITE_ENSURE_STATUS( + LoadAttr(params->interior_padding, sizeof(params->interior_padding), + schema_params->interior_padding(), "interior_padding")); + if (schema_params->edge_padding_low()->size() != + schema_params->edge_padding_high()->size() || + schema_params->edge_padding_low()->size() != + schema_params->interior_padding()->size()) { + TF_LITE_REPORT_ERROR(error_reporter, + "'stablehlo.pad' operation parameter array sizes " + "are not consistent."); + return kTfLiteError; + } + *builtin_data = params.release(); + return kTfLiteOk; + } + TF_LITE_REPORT_ERROR(error_reporter, + "Could not get 'stablehlo.pad' operation parameters."); + return kTfLiteError; +} + // We have this parse function instead of directly returning kTfLiteOk from the // switch-case in ParseOpData because this function is used as part of the // selective registration for the OpResolver implementation in micro. 
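The new ParseStablehloPad above copies three equally sized per-dimension arrays (edge_padding_low, edge_padding_high, interior_padding) into TfLiteStablehloPadParams and rejects operators whose array lengths disagree. As a rough illustration of what these parameters encode under StableHLO pad semantics (the helper below is hypothetical and not part of this patch):

# Hedged sketch of StableHLO pad shape arithmetic; padded_shape is a
# hypothetical helper, not code from the TFLite sources touched here.
def padded_shape(input_shape, edge_low, edge_high, interior):
    if not (len(input_shape) == len(edge_low) == len(edge_high) == len(interior)):
        raise ValueError("padding parameter arrays need one entry per dimension")
    out = []
    for dim, lo, hi, inner in zip(input_shape, edge_low, edge_high, interior):
        # `inner` elements are inserted between existing elements, then `lo`
        # and `hi` elements are added at the edges (negative values trim).
        out.append(lo + hi + dim + max(dim - 1, 0) * inner)
    return out

# A 2x3 operand padded by 1 on both edges of dim 0, with one interior
# element between the columns of dim 1, yields a 4x5 result.
print(padded_shape([2, 3], [1, 0], [1, 0], [0, 1]))  # [4, 5]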
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h index 11e70a601077de..1c90e9fd9bdd68 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/lite/core/api/flatbuffer_conversions.h @@ -440,6 +440,11 @@ TfLiteStatus ParseStablehloReduceWindow(const Operator* op, BuiltinDataAllocator* allocator, void** builtin_data); +TfLiteStatus ParseStablehloPad(const Operator* op, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data); + } // namespace tflite #endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ diff --git a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc index 1fbe440404607f..6e08e6880e5522 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc @@ -26,7 +26,6 @@ limitations under the License. #include #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers -#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -226,8 +225,7 @@ class StablehloReduceWindowFlatbufferConversionsTest auto EmptyAttr() { return builder_.CreateVector({}); } }; -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindow) { +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, Succeeds) { const Operator* stablehlo_reduce_window_op = BuildTestOperator( BuiltinOptions2_StablehloReduceWindowOptions, CreateStablehloReduceWindowOptions( @@ -260,40 +258,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowDeathTests) { - const Operator* stablehlo_reduce_window_op = BuildTestOperator( - BuiltinOptions2_StablehloReduceWindowOptions, - CreateStablehloReduceWindowOptions( - builder_, /*window_dimensions=*/ValidAttr(), - /*window_strides=*/ValidAttr(), - /*base_dilations=*/ValidAttr(), - /*window_dilations=*/ValidAttr(), - /*padding=*/ValidPaddingAttr(), /*body_subgraph_index=*/13) - .Union()); - TfLiteStablehloReduceWindowParams* output_data = nullptr; -#ifdef NDEBUG - GTEST_SKIP(); -#endif - EXPECT_DEATH( - ParseOpData(nullptr, BuiltinOperator_STABLEHLO_REDUCE_WINDOW, - &mock_reporter_, &mock_allocator_, (void**)&output_data), - ""); - EXPECT_DEATH(ParseOpData(stablehlo_reduce_window_op, - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, nullptr, - &mock_allocator_, (void**)&output_data), - ""); - EXPECT_DEATH(ParseOpData(stablehlo_reduce_window_op, - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, - &mock_reporter_, nullptr, (void**)&output_data), - ""); - EXPECT_DEATH(ParseOpData(stablehlo_reduce_window_op, - BuiltinOperator_STABLEHLO_REDUCE_WINDOW, - &mock_reporter_, &mock_allocator_, nullptr), - ""); -} - -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowFailsWithNoWindowDimensions) { + FailsWithNoWindowDimensions) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -315,7 +280,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowSucceedsWithNoWindowStrides) { + SucceedsWithNoWindowStrides) { 
TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -345,7 +310,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowSucceedsWithNoBaseDilations) { + SucceedsWithNoBaseDilations) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -375,7 +340,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowSucceedsWithNoWindowDilations) { + SucceedsWithNoWindowDilations) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -405,8 +370,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, EXPECT_THAT(output_data->body_subgraph_index, Eq(13)); } -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowSucceedsWithNoPadding) { +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, SucceedsWithNoPadding) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -436,7 +400,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowFailsWithEmptyWindowDimensions) { + FailsWithEmptyWindowDimensions) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -458,7 +422,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowSucceedsWithEmptyWindowStrides) { + SucceedsWithEmptyWindowStrides) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -488,7 +452,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowSucceedsWithEmptyBaseDilations) { + SucceedsWithEmptyBaseDilations) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -518,7 +482,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowSucceedsWithEmptyWindowDilations) { + SucceedsWithEmptyWindowDilations) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -549,7 +513,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowSucceedsWithEmptyPadding) { + SucceedsWithEmptyPadding) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -579,7 +543,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowSucceedsWithParamsAtMaxDims) { + SucceedsWithParamsAtMaxDims) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( 
BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -599,7 +563,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowFailsWhenWindowDimensionsHasMoreThanMaxDims) { + FailsWhenWindowDimensionsHasMoreThanMaxDims) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -622,7 +586,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowFailsWhenWindowStridesHasWrongDimCount) { + FailsWhenWindowStridesHasWrongDimCount) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -645,7 +609,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowFailsWhenBaseDilationsHasWrongDimCount) { + FailsWhenBaseDilationsHasWrongDimCount) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -668,7 +632,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowFailsWhenWindowDilationsHasWrongDimCount) { + FailsWhenWindowDilationsHasWrongDimCount) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -692,7 +656,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, } TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowFailsWhenPaddingHasWrongDimCount) { + FailsWhenPaddingHasWrongDimCount) { TfLiteStablehloReduceWindowParams* output_data = nullptr; EXPECT_EQ(ParseOpData( BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, @@ -713,8 +677,7 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, "not have the expected size")); } -TEST_F(StablehloReduceWindowFlatbufferConversionsTest, - ParseStablehloReduceWindowFailsWithWrongOptions) { +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, FailsWithWrongOptions) { const Operator* stablehlo_reduce_window_op = BuildTestOperator(BuiltinOptions2_StablehloReduceWindowOptions, 0); TfLiteStablehloReduceWindowParams* output_data = nullptr; @@ -729,4 +692,179 @@ TEST_F(StablehloReduceWindowFlatbufferConversionsTest, "Could not get 'stablehlo.reduce_window' operation parameters.")); } +TEST_F(StablehloReduceWindowFlatbufferConversionsTest, DeathTests) { + const Operator* stablehlo_reduce_window_op = BuildTestOperator( + BuiltinOptions2_StablehloReduceWindowOptions, + CreateStablehloReduceWindowOptions( + builder_, /*window_dimensions=*/ValidAttr(), + /*window_strides=*/ValidAttr(), + /*base_dilations=*/ValidAttr(), + /*window_dilations=*/ValidAttr(), + /*padding=*/ValidPaddingAttr(), /*body_subgraph_index=*/13) + .Union()); + TfLiteStablehloReduceWindowParams* output_data = nullptr; +#ifdef NDEBUG + GTEST_SKIP(); +#endif + EXPECT_DEATH( + ParseOpData(nullptr, BuiltinOperator_STABLEHLO_REDUCE_WINDOW, + &mock_reporter_, &mock_allocator_, (void**)&output_data), + ""); + EXPECT_DEATH(ParseOpData(stablehlo_reduce_window_op, + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, nullptr, + &mock_allocator_, (void**)&output_data), + ""); + EXPECT_DEATH(ParseOpData(stablehlo_reduce_window_op, + 
BuiltinOperator_STABLEHLO_REDUCE_WINDOW, + &mock_reporter_, nullptr, (void**)&output_data), + ""); + EXPECT_DEATH(ParseOpData(stablehlo_reduce_window_op, + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, + &mock_reporter_, &mock_allocator_, nullptr), + ""); +} + +class StablehloPadFlatbufferConversionsTest : public FlatbufferConversionsTest { + public: + static constexpr int kMaxDims = + TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT; + static constexpr int64_t kValidValue = 5; +}; + +TEST_F(StablehloPadFlatbufferConversionsTest, Succeeds) { + const Operator* stablehlo_pad_op = BuildTestOperator( + BuiltinOptions2_StablehloPadOptions, + CreateStablehloPadOptions( + builder_, + /*edge_padding_low=*/builder_.CreateVector({1, 0, -1}), + /*edge_padding_high=*/builder_.CreateVector({2, 0, -2}), + /*interior_padding=*/builder_.CreateVector({3, 0, 3})) + .Union()); + TfLiteStablehloPadParams* output_data = nullptr; + EXPECT_EQ( + ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_reporter_, &mock_allocator_, (void**)&output_data), + kTfLiteOk); + EXPECT_THAT(std::make_tuple(output_data->edge_padding_low, 3), + ElementsAre(1, 0, -1)); + EXPECT_THAT(std::make_tuple(output_data->edge_padding_high, 3), + ElementsAre(2, 0, -2)); + EXPECT_THAT(std::make_tuple(output_data->interior_padding, 3), + ElementsAre(3, 0, 3)); +} + +TEST_F(StablehloPadFlatbufferConversionsTest, FailsWithMissingLowPadding) { + const Operator* stablehlo_pad_op = BuildTestOperator( + BuiltinOptions2_StablehloPadOptions, + CreateStablehloPadOptions( + builder_, + /*edge_padding_low=*/0, + /*edge_padding_high=*/builder_.CreateVector({2, 0, -2}), + /*interior_padding=*/builder_.CreateVector({3, 0, 3})) + .Union()); + TfLiteStablehloPadParams* output_data = nullptr; + EXPECT_EQ( + ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_reporter_, &mock_allocator_, (void**)&output_data), + kTfLiteError); + EXPECT_THAT( + mock_reporter_.GetString(), + AllOf( + HasSubstr("Input array not provided for operation 'stablehlo.pad'."), + HasSubstr("Check the 'edge_padding_low' attribute."))); +} + +TEST_F(StablehloPadFlatbufferConversionsTest, FailsWithMissingHighPadding) { + const Operator* stablehlo_pad_op = BuildTestOperator( + BuiltinOptions2_StablehloPadOptions, + CreateStablehloPadOptions( + builder_, + /*edge_padding_low=*/builder_.CreateVector({1, 0, -1}), + /*edge_padding_high=*/0, + /*interior_padding=*/builder_.CreateVector({3, 0, 3})) + .Union()); + TfLiteStablehloPadParams* output_data = nullptr; + EXPECT_EQ( + ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_reporter_, &mock_allocator_, (void**)&output_data), + kTfLiteError); + EXPECT_THAT( + mock_reporter_.GetString(), + AllOf( + HasSubstr("Input array not provided for operation 'stablehlo.pad'."), + HasSubstr("Check the 'edge_padding_high' attribute."))); +} + +TEST_F(StablehloPadFlatbufferConversionsTest, FailsWithMissingInteriorPadding) { + const Operator* stablehlo_pad_op = BuildTestOperator( + BuiltinOptions2_StablehloPadOptions, + CreateStablehloPadOptions( + builder_, + /*edge_padding_low=*/builder_.CreateVector({1, 0, -1}), + /*edge_padding_high=*/builder_.CreateVector({2, 0, -2}), + /*interior_padding=*/0) + .Union()); + TfLiteStablehloPadParams* output_data = nullptr; + EXPECT_EQ( + ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_reporter_, &mock_allocator_, (void**)&output_data), + kTfLiteError); + EXPECT_THAT( + mock_reporter_.GetString(), + AllOf( + HasSubstr("Input array not provided for operation 
'stablehlo.pad'."), + HasSubstr("Check the 'interior_padding' attribute."))); +} + +TEST_F(StablehloPadFlatbufferConversionsTest, FailsInconsistentSizes) { + const Operator* stablehlo_pad_op = BuildTestOperator( + BuiltinOptions2_StablehloPadOptions, + CreateStablehloPadOptions( + builder_, + /*edge_padding_low=*/builder_.CreateVector({1, 0, -1}), + /*edge_padding_high=*/builder_.CreateVector({2, 0, -2}), + /*interior_padding=*/builder_.CreateVector({3, 0, -3, 5})) + .Union()); + TfLiteStablehloPadParams* output_data = nullptr; + EXPECT_EQ( + ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_reporter_, &mock_allocator_, (void**)&output_data), + kTfLiteError); + EXPECT_THAT(mock_reporter_.GetString(), + HasSubstr("'stablehlo.pad' operation parameter array sizes are " + "not consistent.")); +} + +TEST_F(StablehloPadFlatbufferConversionsTest, FailsWithWrongOptions) { + const Operator* stablehlo_pad_op = BuildTestOperator(BuiltinOptions_NONE, 0); + TfLiteStablehloPadParams* output_data = nullptr; + EXPECT_EQ( + ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_reporter_, &mock_allocator_, (void**)&output_data), + kTfLiteError); + EXPECT_THAT(mock_reporter_.GetString(), + HasSubstr("Could not get 'stablehlo.pad' operation parameters.")); +} + +TEST_F(StablehloPadFlatbufferConversionsTest, DeathTests) { + const Operator* stablehlo_pad_op = BuildTestOperator(BuiltinOptions_NONE, 0); + TfLiteStablehloPadParams* output_data = nullptr; +#ifdef NDEBUG + GTEST_SKIP(); +#endif + EXPECT_DEATH( + ParseOpData(nullptr, BuiltinOperator_STABLEHLO_PAD, &mock_reporter_, + &mock_allocator_, (void**)&output_data), + ""); + EXPECT_DEATH(ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + nullptr, &mock_allocator_, (void**)&output_data), + ""); + EXPECT_DEATH(ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_reporter_, nullptr, (void**)&output_data), + ""); + EXPECT_DEATH(ParseOpData(stablehlo_pad_op, BuiltinOperator_STABLEHLO_PAD, + &mock_reporter_, &mock_allocator_, nullptr), + ""); +} + } // namespace tflite diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index e9e7fb17cf3936..70999f0b24cf7a 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -1,9 +1,9 @@ +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load( "//tensorflow/lite:build_def.bzl", "tflite_cc_library_with_c_headers_test", "tflite_copts", ) -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load( "//tensorflow/lite/core/c:special_rules.bzl", "c_api_experimental_visibility_allowlist", @@ -190,6 +190,33 @@ cc_test( ], ) +cc_test( + name = "c_api_test_with_opaque_delegate", + size = "small", + srcs = ["c_api_test.cc"], + copts = tflite_copts(), + data = [ + "//tensorflow/lite:testdata/2_subgraphs.bin", + "//tensorflow/lite:testdata/add.bin", + "//tensorflow/lite:testdata/add_quantized.bin", + "//tensorflow/lite:testdata/custom_sinh.bin", + ], + local_defines = ["TFLITE_USE_OPAQUE_DELEGATE"], + deps = [ + ":c_api", + ":c_api_experimental", + ":c_api_types", + ":common", + "//tensorflow/lite:string_util", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/core:subgraph", + "//tensorflow/lite/delegates:delegate_test_util", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "selectively_built_c_api_test", size = "small", @@ -350,6 +377,7 @@ tflite_cc_library_with_c_headers_test( 
"//tensorflow/lite:signature_runner", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_opaque_internal", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/profiling/telemetry:profiler", @@ -387,6 +415,7 @@ tflite_cc_library_with_c_headers_test( "//tensorflow/lite:util", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_opaque_internal", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/profiling/telemetry:profiler", @@ -438,6 +467,7 @@ tflite_cc_library_with_c_headers_test( "//tensorflow/lite:util", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_opaque_internal_without_alwayslink", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/profiling/telemetry:profiler", @@ -563,6 +593,7 @@ cc_test( ":common", "//tensorflow/lite:kernel_api", "//tensorflow/lite:util", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/delegates:delegate_test_util", "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/lite/core/c/builtin_op_data.h b/tensorflow/lite/core/c/builtin_op_data.h index b96350f45e2af5..1ac385b932b15e 100644 --- a/tensorflow/lite/core/c/builtin_op_data.h +++ b/tensorflow/lite/core/c/builtin_op_data.h @@ -35,6 +35,7 @@ extern "C" { #define TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT 8 #define TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT 8 #define TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT 8 // TODO(aselle): Consider using "if this then that" for testing. @@ -636,6 +637,14 @@ typedef struct { enum TfLiteReduceWindowFunction reduce_function; } TfLiteReduceWindowParams; +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad + int64_t edge_padding_low[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; + int64_t edge_padding_high[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; + int64_t interior_padding[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; +} TfLiteStablehloPadParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/core/c/c_api.h b/tensorflow/lite/core/c/c_api.h index b98fddf2569744..f7504a315f1bff 100644 --- a/tensorflow/lite/core/c/c_api.h +++ b/tensorflow/lite/core/c/c_api.h @@ -12,11 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// \warning Note: Users of TensorFlow Lite should not include this file -// directly, but should instead include -// "third_party/tensorflow/lite/c/c_api.h". Only the TensorFlow Lite -// implementation itself should include this -// file directly. +// WARNING: Users of TensorFlow Lite should not include this file directly, but +// should instead include "third_party/tensorflow/lite/c/c_api.h". +// Only the TensorFlow Lite implementation itself should include this file +// directly. #ifndef TENSORFLOW_LITE_CORE_C_C_API_H_ #define TENSORFLOW_LITE_CORE_C_C_API_H_ @@ -76,6 +75,16 @@ limitations under the License. 
/// TfLiteInterpreterOptionsDelete(options); /// TfLiteModelDelete(model); /// +/// +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// \code +/// #include "tensorflow/lite/c/c_api.h" +/// \endcode +/// to access the APIs documented on this page. +// NOLINTEND(whitespace/line_length) +// clang-format on #ifdef __cplusplus extern "C" { @@ -83,7 +92,7 @@ extern "C" { // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup c_api tensorflow/lite/c/c_api.h +/** \defgroup c_api lite/c/c_api.h * @{ */ // NOLINTEND(whitespace/line_length) @@ -276,8 +285,6 @@ TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsAddRegistrationExternal( /// /// By default it is disabled and calling to `TfLiteInterpreterCancel` will /// return kTfLiteError. See `TfLiteInterpreterCancel`. -/// -/// \warning This is an experimental API and subject to change. TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterOptionsEnableCancellation( TfLiteInterpreterOptions* options, bool enable); @@ -448,8 +455,6 @@ TfLiteTensor* TfLiteInterpreterGetTensor(const TfLiteInterpreter* interpreter, /// /// Returns kTfLiteError if cancellation is not enabled via /// `TfLiteInterpreterOptionsEnableCancellation`. -/// -/// \warning This is an experimental API and subject to change. TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterCancel( const TfLiteInterpreter* interpreter); diff --git a/tensorflow/lite/core/c/c_api_experimental.cc b/tensorflow/lite/core/c/c_api_experimental.cc index f7e117a64c53e0..f88349efc9f7f3 100644 --- a/tensorflow/lite/core/c/c_api_experimental.cc +++ b/tensorflow/lite/core/c/c_api_experimental.cc @@ -17,12 +17,14 @@ limitations under the License. #include +#include #include #include #include #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/c_api.h" #include "tensorflow/lite/core/interpreter.h" #include "tensorflow/lite/profiling/telemetry/profiler.h" @@ -166,11 +168,6 @@ void TfLiteInterpreterOptionsSetEnableDelegateFallback( options->enable_delegate_fallback = enable; } -void TfLiteSetAllowBufferHandleOutput(const TfLiteInterpreter* interpreter, - bool allow_buffer_handle_output) { - interpreter->impl->SetAllowBufferHandleOutput(allow_buffer_handle_output); -} - TfLiteStatus TfLiteInterpreterModifyGraphWithDelegate( const TfLiteInterpreter* interpreter, TfLiteDelegate* delegate) { return interpreter->impl->ModifyGraphWithDelegate(delegate); @@ -191,6 +188,26 @@ int32_t TfLiteInterpreterGetSignatureCount( return static_cast(interpreter->impl->signature_keys().size()); } +TfLiteStatus TfLiteInterpreterSetBufferHandle(TfLiteInterpreter* interpreter, + TfLiteTensor* tensor, + TfLiteBufferHandle buffer_handle, + TfLiteOpaqueDelegate* delegate) { + return interpreter->impl->SetBufferHandle(tensor, buffer_handle, delegate); +} + +TfLiteStatus TfLiteInterpreterGetBufferHandle(TfLiteInterpreter* interpreter, + int tensor_index, + TfLiteBufferHandle* buffer_handle, + TfLiteOpaqueDelegate** delegate) { + return interpreter->impl->GetBufferHandle(tensor_index, buffer_handle, + delegate); +} + +void TfLiteSetAllowBufferHandleOutput(const TfLiteInterpreter* interpreter, + bool allow_buffer_handle_output) { + interpreter->impl->SetAllowBufferHandleOutput(allow_buffer_handle_output); +} + TfLiteStatus TfLiteInterpreterSetCustomAllocationForTensor( TfLiteInterpreter* interpreter, int tensor_index, const TfLiteCustomAllocation* 
allocation, int64_t flags) { @@ -201,6 +218,11 @@ TfLiteStatus TfLiteInterpreterSetCustomAllocationForTensor( *allocation, flags); } +TfLiteStatus TfLiteInterpreterEnsureTensorDataIsReadable( + TfLiteInterpreter* interpreter, int tensor_index) { + return interpreter->impl->EnsureTensorDataIsReadable(tensor_index); +} + const char* TfLiteInterpreterGetSignatureKey( const TfLiteInterpreter* interpreter, int32_t signature_index) { int32_t signature_count = TfLiteInterpreterGetSignatureCount(interpreter); diff --git a/tensorflow/lite/core/c/c_api_experimental.h b/tensorflow/lite/core/c/c_api_experimental.h index 95528110167369..9de042491ead7d 100644 --- a/tensorflow/lite/core/c/c_api_experimental.h +++ b/tensorflow/lite/core/c/c_api_experimental.h @@ -21,6 +21,7 @@ limitations under the License. #define TENSORFLOW_LITE_CORE_C_C_API_EXPERIMENTAL_H_ #include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/c_api.h" #include "tensorflow/lite/core/c/common.h" @@ -266,17 +267,6 @@ TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetUseNNAPI( TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetEnableDelegateFallback( TfLiteInterpreterOptions* options, bool enable); -// Set if buffer handle output is allowed. -// -/// When using hardware delegation, Interpreter will make the data of output -/// tensors available in `tensor->data` by default. If the application can -/// consume the buffer handle directly (e.g. reading output from OpenGL -/// texture), it can set this flag to false, so Interpreter won't copy the -/// data from buffer handle to CPU memory. WARNING: This is an experimental -/// API and subject to change. -TFL_CAPI_EXPORT extern void TfLiteSetAllowBufferHandleOutput( - const TfLiteInterpreter* interpreter, bool allow_buffer_handle_output); - /// Allow a delegate to look at the graph and modify the graph to handle /// parts of the graph themselves. After this is called, the graph may /// contain new nodes that replace 1 more nodes. @@ -332,6 +322,41 @@ TfLiteInterpreterSetCustomAllocationForTensor( TfLiteInterpreter* interpreter, int tensor_index, const TfLiteCustomAllocation* allocation, int64_t flags); +/// -------------------------------------------------------------------------- +/// BufferHandle APIs + +/// Sets the delegate buffer handle for the given tensor. +/// +/// This function sets the buffer handle for a tensor that is used by other +/// computing hardware such as EdgeTpu. For example, EdgeTpu delegate imports a +/// tensor's memory into EdgeTpu's virtual address and returns a buffer handle. +/// Then EdgeTpu delegate calls this API to associate the tensor with the buffer +/// handle. +/// +/// WARNING: This is an experimental API and subject to change. +TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterSetBufferHandle( + TfLiteInterpreter* interpreter, TfLiteTensor* tensor, + TfLiteBufferHandle buffer_handle, TfLiteOpaqueDelegate* delegate); + +/// Gets the delegate buffer handle, and the delegate which can process +/// the buffer handle. +/// +/// WARNING: This is an experimental API and subject to change. +TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterGetBufferHandle( + TfLiteInterpreter* interpreter, int tensor_index, + TfLiteBufferHandle* buffer_handle, TfLiteOpaqueDelegate** delegate); + +/// Sets whether buffer handle output is allowed. +/// When using hardware delegation, Interpreter will make the data of output +/// tensors available in `tensor->data` by default. 
If the application can +/// consume the buffer handle directly (e.g. reading output from OpenGL +/// texture), it can set this flag to false, so Interpreter won't copy the +/// data from buffer handle to CPU memory. +/// +/// WARNING: This is an experimental API and subject to change. +TFL_CAPI_EXPORT extern void TfLiteSetAllowBufferHandleOutput( + const TfLiteInterpreter* interpreter, bool allow_buffer_handle_output); + /// -------------------------------------------------------------------------- /// SignatureRunner APIs @@ -360,6 +385,16 @@ TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetTelemetryProfiler( TfLiteInterpreterOptions* options, struct TfLiteTelemetryProfilerStruct* profiler); +/// Ensures the data of the tensor at the given index is readable. +/// Note: If a delegate has been used, and `SetAllowBufferHandleOutput(true)` +/// has been called, tensor outputs may be stored as delegate buffer handles +/// whose data is not directly readable until this method has been called. In +/// such cases, this method will copy the data from the delegate buffer handle +/// to CPU memory. +/// +/// WARNING: This is an experimental API and subject to change. +TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterEnsureTensorDataIsReadable( + TfLiteInterpreter* interpreter, int tensor_index); #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/core/c/c_api_experimental_test.cc b/tensorflow/lite/core/c/c_api_experimental_test.cc index 2425dd97676ba3..c52ad232b6a9e7 100644 --- a/tensorflow/lite/core/c/c_api_experimental_test.cc +++ b/tensorflow/lite/core/c/c_api_experimental_test.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tensorflow/lite/core/c/c_api_experimental.h" +#include #include #include #include +#include #include #include #include @@ -25,6 +27,7 @@ limitations under the License. #include #include #include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/c_api.h" #include "tensorflow/lite/core/c/c_api_opaque.h" #include "tensorflow/lite/core/c/common.h" @@ -566,6 +569,356 @@ TEST(CApiExperimentalTest, SetCustomAllocationForOutputTensorSuccess) { TfLiteModelDelete(model); } +TEST(CApiExperimentalTest, SetAndGetBufferHandleSuccess) { + TfLiteModel* model = + TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin"); + ASSERT_NE(model, nullptr); + + auto simple_delegate = std::make_unique( + // The delegate will handle the first (index 0) and the second (index 1) + // op nodes in the TfLiteModel. + /*nodes=*/std::vector({0, 1}), + /*delegate_flags=*/kTfLiteDelegateFlagsNone, + /*fail_node_prepare=*/false, /*min_ops_per_subset=*/0, + /*fail_node_invoke=*/false, + /* automatic_shape_propagation=*/false, /*custom_op=*/false, + /* set_output_tensor_dynamic =*/false); + TfLiteDelegate* delegate = simple_delegate->get_tf_lite_delegate(); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + TfLiteInterpreterOptionsAddDelegate(options, delegate); + TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options); + ASSERT_NE(interpreter, nullptr); + EXPECT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk); + + // Tensor index is set to the input tensor (index 1) of the TfLiteModel. 
+ int tensor_index = 1; + TfLiteTensor* tensor = TfLiteInterpreterGetTensor(interpreter, tensor_index); + ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); + ASSERT_EQ(tensor->delegate, nullptr); + + // Use of an arbitrary non-negative int value for the buffer handle. + TfLiteBufferHandle buffer_handle = 1234; + + TfLiteDelegate* expected_delegate = delegate; + TfLiteBufferHandle expected_buffer_handle = buffer_handle; + ASSERT_EQ(TfLiteInterpreterSetBufferHandle(interpreter, tensor, buffer_handle, + delegate), + kTfLiteOk); + ASSERT_EQ(tensor->delegate, expected_delegate); + ASSERT_EQ(tensor->buffer_handle, expected_buffer_handle); + + TfLiteOpaqueDelegate* fetched_delegate; + TfLiteBufferHandle fetched_buffer_handle; + ASSERT_EQ( + TfLiteInterpreterGetBufferHandle( + interpreter, tensor_index, &fetched_buffer_handle, &fetched_delegate), + kTfLiteOk); + ASSERT_EQ(fetched_delegate, expected_delegate); + ASSERT_EQ(fetched_buffer_handle, expected_buffer_handle); + + EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + +// A utility struct, intended to be used to record the interaction between a +// test delegate and the runtime. +struct DelegateState { + bool delegate_prepared; + bool copy_from_buffer_handle_called; + bool free_buffer_handle_called; + int buffer_handle; + + void Reset() { + delegate_prepared = false; + copy_from_buffer_handle_called = false; + free_buffer_handle_called = false; + buffer_handle = -1; + } +}; + +struct OpaqueTestDelegate { + static constexpr int kTestDelegateOutput = 42; + + static inline TfLiteStatus Prepare(TfLiteOpaqueContext* opaque_context, + TfLiteOpaqueDelegate* opaque_delegate, + void* data) { + DelegateState* delegate_state = reinterpret_cast(data); + delegate_state->delegate_prepared = true; + + // The buffer handle is set to one greater than the last allocated buffer + // handle. 
+ delegate_state->buffer_handle++; + + TfLiteRegistration registration{}; + registration.registration_external = TfLiteRegistrationExternalCreate( + kTfLiteBuiltinDelegate, "OpaqueTestDelegate delegate kernel", + /* version = */ 1); + + TfLiteRegistrationExternalSetPrepare( + registration.registration_external, + [](TfLiteOpaqueContext* context, + TfLiteOpaqueNode* node) -> TfLiteStatus { return kTfLiteOk; }); + + TfLiteRegistrationExternalSetInvoke( + registration.registration_external, + [](TfLiteOpaqueContext*, TfLiteOpaqueNode*) -> TfLiteStatus { + return kTfLiteOk; + }); + + TfLiteIntArray* execution_plan; + TfLiteOpaqueContextGetExecutionPlan(opaque_context, &execution_plan); + + TfLiteOpaqueContextReplaceNodeSubsetsWithDelegateKernels( + opaque_context, registration.registration_external, execution_plan, + opaque_delegate); + return kTfLiteOk; + } + + static TfLiteStatus CopyFromBufferHandle(TfLiteOpaqueContext* context, + TfLiteOpaqueDelegate* delegate, + void* data, + TfLiteBufferHandle buffer_handle, + TfLiteOpaqueTensor* opaque_tensor) { + DelegateState* delegate_state = reinterpret_cast(data); + delegate_state->copy_from_buffer_handle_called = true; + delegate_state->buffer_handle = buffer_handle; + + auto* output = + reinterpret_cast(TfLiteOpaqueTensorData(opaque_tensor)); + int total_num_elements = 1; + for (int i = 0; i < TfLiteOpaqueTensorNumDims(opaque_tensor); ++i) { + total_num_elements *= TfLiteOpaqueTensorDim(opaque_tensor, i); + } + std::vector meaning_of_life(total_num_elements, kTestDelegateOutput); + memcpy(output, meaning_of_life.data(), + meaning_of_life.size() * sizeof(float)); + return kTfLiteOk; + } + + static inline void FreeBufferHandle(TfLiteOpaqueContext* context, + TfLiteOpaqueDelegate* delegate, + void* data, + TfLiteBufferHandle* buffer_handle) { + DelegateState* delegate_state = reinterpret_cast(data); + delegate_state->free_buffer_handle_called = true; + } +}; + +TEST(CApiExperimentalTest, SetAllowBufferHandleOutputFalse) { + DelegateState delegate_state; + delegate_state.Reset(); + + TfLiteModel* model = + TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin"); + ASSERT_NE(model, nullptr); + int kNumTensorElements = 1 * 8 * 8 * 3; + + TfLiteOpaqueDelegateBuilder opaque_delegate_builder{}; + opaque_delegate_builder.data = &delegate_state; + opaque_delegate_builder.CopyFromBufferHandle = + OpaqueTestDelegate::CopyFromBufferHandle; + opaque_delegate_builder.FreeBufferHandle = + OpaqueTestDelegate::FreeBufferHandle; + opaque_delegate_builder.Prepare = OpaqueTestDelegate::Prepare; + + TfLiteOpaqueDelegate* tflite_delegate = + TfLiteOpaqueDelegateCreate(&opaque_delegate_builder); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + TfLiteInterpreterOptionsAddDelegate(options, tflite_delegate); + TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options); + ASSERT_NE(interpreter, nullptr); + + // Allocate tensor buffers. + EXPECT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk); + + // Fill input buffers + TfLiteTensor* input_tensor = TfLiteInterpreterGetInputTensor(interpreter, 0); + float* input = reinterpret_cast(input_tensor->data.raw); + std::fill(input, input + kNumTensorElements, 1); + + // We set the buffer handle of the output tensor and mark its data as stale. + // This will make the interpreter call 'CopyFromBufferHandle' to refresh the + // output tensor's data. + int first_buffer_handle = 0; + + // Tensor index is set to the output tensor (index 2) of the TfLite model. 
+  int tensor_index = 2;
+
+  TfLiteTensor* output_tensor =
+      TfLiteInterpreterGetTensor(interpreter, tensor_index);
+
+  ASSERT_EQ(
+      TfLiteInterpreterSetBufferHandle(interpreter, output_tensor,
+                                       first_buffer_handle, tflite_delegate),
+      kTfLiteOk);
+
+  output_tensor->data_is_stale = true;
+
+  TfLiteSetAllowBufferHandleOutput(interpreter,
+                                   /*allow_buffer_handle_output=*/false);
+
+  // Run inference
+  EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk);
+  EXPECT_TRUE(delegate_state.delegate_prepared);
+  EXPECT_TRUE(delegate_state.copy_from_buffer_handle_called);
+  EXPECT_EQ(delegate_state.buffer_handle, first_buffer_handle);
+  EXPECT_FALSE(delegate_state.free_buffer_handle_called);
+  float* outputs = reinterpret_cast(output_tensor->data.raw);
+  for (int i = 0; i < kNumTensorElements; ++i) {
+    EXPECT_EQ(outputs[i], OpaqueTestDelegate::kTestDelegateOutput);
+  }
+  ASSERT_EQ(output_tensor->buffer_handle, first_buffer_handle);
+  ASSERT_EQ(output_tensor->delegate, tflite_delegate);
+
+  // Destroying the interpreter will release any buffer handles that are
+  // associated with the tensors owned by the interpreter.
+  delegate_state.Reset();
+  TfLiteInterpreterDelete(interpreter);
+  TfLiteOpaqueDelegateDelete(tflite_delegate);
+  TfLiteInterpreterOptionsDelete(options);
+  TfLiteModelDelete(model);
+  EXPECT_FALSE(delegate_state.copy_from_buffer_handle_called);
+  EXPECT_TRUE(delegate_state.free_buffer_handle_called);
+}
+
+TEST(CApiExperimentalTest, SetAllowBufferHandleOutputTrue) {
+  DelegateState delegate_state;
+  delegate_state.Reset();
+
+  TfLiteModel* model =
+      TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");
+  ASSERT_NE(model, nullptr);
+  int kNumTensorElements = 1 * 8 * 8 * 3;
+
+  TfLiteOpaqueDelegateBuilder opaque_delegate_builder{};
+  opaque_delegate_builder.data = &delegate_state;
+  opaque_delegate_builder.CopyFromBufferHandle =
+      OpaqueTestDelegate::CopyFromBufferHandle;
+  opaque_delegate_builder.FreeBufferHandle =
+      OpaqueTestDelegate::FreeBufferHandle;
+  opaque_delegate_builder.Prepare = OpaqueTestDelegate::Prepare;
+
+  TfLiteOpaqueDelegate* tflite_delegate =
+      TfLiteOpaqueDelegateCreate(&opaque_delegate_builder);
+
+  TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
+  TfLiteInterpreterOptionsAddDelegate(options, tflite_delegate);
+  TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
+  ASSERT_NE(interpreter, nullptr);
+
+  // Allocate tensor buffers.
+  EXPECT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk);
+
+  // Fill input buffers
+  TfLiteTensor* input_tensor = TfLiteInterpreterGetInputTensor(interpreter, 0);
+  float* input = reinterpret_cast(input_tensor->data.raw);
+  std::fill(input, input + kNumTensorElements, 1);
+
+  // We set the buffer handle of the output tensor and mark its data as stale.
+  // This will make the interpreter call 'CopyFromBufferHandle' to refresh the
+  // output tensor's data.
+  EXPECT_FALSE(delegate_state.free_buffer_handle_called);
+  int first_buffer_handle = 0;
+
+  // Tensor index is set to the output tensor (index 2) of the TfLite model.
+  int tensor_index = 2;
+
+  TfLiteTensor* output_tensor =
+      TfLiteInterpreterGetTensor(interpreter, tensor_index);
+
+  ASSERT_EQ(
+      TfLiteInterpreterSetBufferHandle(interpreter, output_tensor,
+                                       first_buffer_handle, tflite_delegate),
+      kTfLiteOk);
+
+  output_tensor->data_is_stale = true;
+
+  TfLiteSetAllowBufferHandleOutput(interpreter,
+                                   /*allow_buffer_handle_output=*/true);
+
+  // Run inference
+  EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk);
+  EXPECT_TRUE(delegate_state.delegate_prepared);
+  EXPECT_FALSE(delegate_state.copy_from_buffer_handle_called);
+  EXPECT_EQ(delegate_state.buffer_handle, first_buffer_handle);
+  EXPECT_FALSE(delegate_state.free_buffer_handle_called);
+  ASSERT_EQ(output_tensor->buffer_handle, first_buffer_handle);
+  ASSERT_EQ(output_tensor->delegate, tflite_delegate);
+
+  // Destroying the interpreter will release any buffer handles that are
+  // associated with the tensors owned by the interpreter.
+  delegate_state.Reset();
+  TfLiteInterpreterDelete(interpreter);
+  TfLiteOpaqueDelegateDelete(tflite_delegate);
+  TfLiteInterpreterOptionsDelete(options);
+  TfLiteModelDelete(model);
+  EXPECT_FALSE(delegate_state.copy_from_buffer_handle_called);
+  EXPECT_TRUE(delegate_state.free_buffer_handle_called);
+}
+
+TEST(CApiExperimentalTest, SetInvalidHandleToTensor) {
+  TfLiteModel* model =
+      TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");
+  ASSERT_NE(model, nullptr);
+
+  auto simple_delegate = std::make_unique(
+      // The delegate will handle the first (index 0) and the second (index 1)
+      // op nodes in the TfLiteModel.
+      /*nodes=*/std::vector({0, 1}),
+      /*delegate_flags=*/kTfLiteDelegateFlagsNone,
+      /*fail_node_prepare=*/false, /*min_ops_per_subset=*/0,
+      /*fail_node_invoke=*/false,
+      /* automatic_shape_propagation=*/false, /*custom_op=*/false,
+      /* set_output_tensor_dynamic =*/false);
+  TfLiteDelegate* delegate = simple_delegate->get_tf_lite_delegate();
+
+  TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
+  TfLiteInterpreterOptionsAddDelegate(options, delegate);
+  TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
+  ASSERT_NE(interpreter, nullptr);
+
+  EXPECT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk);
+  EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk);
+
+  auto another_simple_delegate = std::make_unique(
+      // The delegate will handle the 0th, 1st and the 2nd indexed nodes in
+      // the TfLiteModel.
+      /*nodes=*/std::vector({0, 1, 2}),
+      /*delegate_flags=*/kTfLiteDelegateFlagsNone,
+      /*fail_node_prepare=*/false, /*min_ops_per_subset=*/0,
+      /*fail_node_invoke=*/false, /* automatic_shape_propagation=*/false,
+      /*custom_op=*/false, /*set_output_tensor_dynamic=*/false);
+
+  // Tensor index is set to the output tensor (index 2) of the TfLite model.
+  int tensor_index = 2;
+  TfLiteTensor* tensor = TfLiteInterpreterGetTensor(interpreter, tensor_index);
+
+  // Before setting the buffer handle, the tensor's `delegate` is already set
+  // because it will be written by the delegate.
+  ASSERT_EQ(tensor->delegate, delegate);
+  ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);
+
+  // Buffer handle is set to one greater than the last allocated buffer handle.
+  TfLiteBufferHandle buffer_handle = kTfLiteNullBufferHandle + 1;
+
+  // Setting a buffer handle to a tensor with another delegate will fail.
+ ASSERT_EQ(TfLiteInterpreterSetBufferHandle( + interpreter, tensor, buffer_handle, + another_simple_delegate->get_tf_lite_delegate()), + kTfLiteError); + EXPECT_EQ(tensor->delegate, delegate); + EXPECT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + void AllocateAndSetInputs(TfLiteInterpreter* interpreter) { std::array input_dims = {2}; ASSERT_EQ(TfLiteInterpreterResizeInputTensor( diff --git a/tensorflow/lite/core/c/c_api_opaque.cc b/tensorflow/lite/core/c/c_api_opaque.cc index 13cf85cfb967bb..926d0a4714fef3 100644 --- a/tensorflow/lite/core/c/c_api_opaque.cc +++ b/tensorflow/lite/core/c/c_api_opaque.cc @@ -147,7 +147,10 @@ size_t TfLiteOpaqueTensorByteSize(const TfLiteOpaqueTensor* opaque_tensor) { } void* TfLiteOpaqueTensorData(const TfLiteOpaqueTensor* opaque_tensor) { - return TfLiteTensorData(reinterpret_cast(opaque_tensor)); + return opaque_tensor != nullptr + ? TfLiteTensorData( + reinterpret_cast(opaque_tensor)) + : nullptr; } TfLiteAllocationType TfLiteOpaqueTensorGetAllocationType( diff --git a/tensorflow/lite/core/c/c_api_opaque.h b/tensorflow/lite/core/c/c_api_opaque.h index 06bdc194b221f6..0a012bfebae087 100644 --- a/tensorflow/lite/core/c/c_api_opaque.h +++ b/tensorflow/lite/core/c/c_api_opaque.h @@ -12,6 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// WARNING: Users of TensorFlow Lite should not include this file directly, but +// should instead include "third_party/tensorflow/lite/c/c_api_opaque.h". +// Only the TensorFlow Lite implementation itself should include this file +// directly. + #ifndef TENSORFLOW_LITE_CORE_C_C_API_OPAQUE_H_ #define TENSORFLOW_LITE_CORE_C_C_API_OPAQUE_H_ @@ -36,10 +41,20 @@ extern "C" { /// potentially including non-backwards-compatible changes, on a different /// schedule than for the other TensorFlow Lite APIs. See /// https://www.tensorflow.org/guide/versions#separate_version_number_for_tensorflow_lite_extension_apis. +/// +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// \code +/// #include "tensorflow/lite/c/c_api_opaque.h" +/// \endcode +/// to access the APIs documented on this page. +// NOLINTEND(whitespace/line_length) +// clang-format on // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup c_api_opaque tensorflow/lite/c/c_api_opaque.h +/** \defgroup c_api_opaque lite/c/c_api_opaque.h * @{ */ // NOLINTEND(whitespace/line_length) @@ -93,6 +108,7 @@ TFL_CAPI_EXPORT extern size_t TfLiteOpaqueTensorByteSize( const TfLiteOpaqueTensor* opaque_tensor); /// Returns a pointer to the underlying data buffer. +/// Returns nullptr if input is also nullptr. 
TFL_CAPI_EXPORT extern void* TfLiteOpaqueTensorData( const TfLiteOpaqueTensor* opaque_tensor); diff --git a/tensorflow/lite/core/c/c_api_opaque_test.cc b/tensorflow/lite/core/c/c_api_opaque_test.cc index ab2c00b2604a4f..f59a35c3c0feb4 100644 --- a/tensorflow/lite/core/c/c_api_opaque_test.cc +++ b/tensorflow/lite/core/c/c_api_opaque_test.cc @@ -166,6 +166,18 @@ TEST(TestTfLiteOpaqueTensorGetBufferAddressStability, TfLiteTensorGetBufferAddressStability(&t)); } +TEST(TestTfLiteOpaqueTensorData, ValidInput) { + TfLiteTensor t; + char data[] = "data"; + t.data.raw = data; + EXPECT_EQ(TfLiteOpaqueTensorData(reinterpret_cast(&t)), + data); +} + +TEST(TestTfLiteOpaqueTensorData, NullInput) { + EXPECT_EQ(TfLiteOpaqueTensorData(nullptr), nullptr); +} + TEST(TestTfLiteOpaqueTensorGetDataStability, WithMemNoneBehavesAsTfLiteTensorGetDataStability) { TfLiteTensor t; diff --git a/tensorflow/lite/core/c/c_api_test.cc b/tensorflow/lite/core/c/c_api_test.cc index abb0083e12578c..189cd9815f8ebf 100644 --- a/tensorflow/lite/core/c/c_api_test.cc +++ b/tensorflow/lite/core/c/c_api_test.cc @@ -291,6 +291,7 @@ TEST(CApiSimple, TfLiteInterpreterGetTensor) { TfLiteInterpreterDelete(interpreter); } +#if !TFLITE_USE_OPAQUE_DELEGATE TEST(CApiSimple, Delegate) { TfLiteModel* model = TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin"); @@ -316,6 +317,7 @@ TEST(CApiSimple, Delegate) { EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); TfLiteInterpreterDelete(interpreter); } +#endif TEST(CApiSimple, DelegateExternal_GetExecutionPlan) { TfLiteModel* model = @@ -409,6 +411,7 @@ TEST(CApiSimple, DelegateExternal_MarkSubgraphAsDelegationSkippable) { TfLiteOpaqueDelegateDelete(opaque_delegate); } +#if !TFLITE_USE_OPAQUE_DELEGATE TEST(CApiSimple, DelegateFails) { TfLiteModel* model = TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin"); @@ -428,6 +431,7 @@ TEST(CApiSimple, DelegateFails) { TfLiteInterpreterOptionsDelete(options); TfLiteModelDelete(model); } +#endif struct DelegateState { bool delegate_prepared; diff --git a/tensorflow/lite/core/c/c_api_types.h b/tensorflow/lite/core/c/c_api_types.h index c1f0c568fcf04a..1170025cbab9a2 100644 --- a/tensorflow/lite/core/c/c_api_types.h +++ b/tensorflow/lite/core/c/c_api_types.h @@ -12,16 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// WARNING: Users of TensorFlow Lite should not include this file directly, but +// should instead include "third_party/tensorflow/lite/c/c_api_types.h". +// Only the TensorFlow Lite implementation itself should include this file +// directly. /// This file declares types used by the pure C inference API defined in /// c_api.h, some of which are also used in the C++ and C kernel and interpreter /// APIs. - -// WARNING: Users of TensorFlow Lite should not include this file directly, -// but should instead include -// "third_party/tensorflow/lite/c/c_api_types.h". -// Only the TensorFlow Lite implementation itself should include this -// file directly. +/// +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// \code +/// #include "tensorflow/lite/c/c_api_types.h" +/// \endcode +/// to access the APIs documented on this page. 
+// NOLINTEND(whitespace/line_length) +// clang-format on // IWYU pragma: private, include "third_party/tensorflow/lite/c/c_api_types.h" @@ -36,7 +44,7 @@ extern "C" { // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup c_api_types tensorflow/lite/c/c_api_types.h +/** \defgroup c_api_types lite/c/c_api_types.h * @{ */ // NOLINTEND(whitespace/line_length) diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h index 0ebba76e948f33..ca29104f203954 100644 --- a/tensorflow/lite/core/c/common.h +++ b/tensorflow/lite/core/c/common.h @@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// WARNING: Users of TensorFlow Lite should not include this file directly, but +// should instead include "third_party/tensorflow/lite/c/common.h". +// Only the TensorFlow Lite implementation itself should include this file +// directly. /// This file defines common C types and APIs for implementing operations, /// delegates and other constructs in TensorFlow Lite. The actual operations and @@ -32,12 +36,17 @@ limitations under the License. /// /// NOTE: The order of values in these structs are "semi-ABI stable". New values /// should be added only to the end of structs and never reordered. +/// +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// \code +/// #include "tensorflow/lite/c/common.h" +/// \endcode +/// to access the APIs documented on this page. +// NOLINTEND(whitespace/line_length) +// clang-format on -// WARNING: Users of TensorFlow Lite should not include this file directly, -// but should instead include -// "third_party/tensorflow/lite/c/common.h". -// Only the TensorFlow Lite implementation itself should include this -// file directly. // IWYU pragma: private, include "third_party/tensorflow/lite/c/common.h" #ifndef TENSORFLOW_LITE_CORE_C_COMMON_H_ @@ -56,7 +65,7 @@ extern "C" { // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup common tensorflow/lite/c/common.h +/** \defgroup common lite/c/common.h * @{ */ // NOLINTEND(whitespace/line_length) diff --git a/tensorflow/lite/core/interpreter.cc b/tensorflow/lite/core/interpreter.cc index ee9748c031e87d..5c2917e8be9f24 100644 --- a/tensorflow/lite/core/interpreter.cc +++ b/tensorflow/lite/core/interpreter.cc @@ -225,8 +225,8 @@ TfLiteStatus Interpreter::Invoke() { ScopedRuntimeInstrumentationProfile scoped_runtime_event(root_profiler_.get(), "invoke"); - // "Resets" cancellation flag so cancellation happens before this invoke will - // not take effect. + // "Resets" cancellation flag so cancellation that happens before this invoke + // will not take effect. if (cancellation_enabled_) (void)continue_invocation_.test_and_set(); // Denormal floating point numbers could cause significant slowdown on diff --git a/tensorflow/lite/core/interpreter.h b/tensorflow/lite/core/interpreter.h index 98a2fd67f4da90..ed9d798f34753b 100644 --- a/tensorflow/lite/core/interpreter.h +++ b/tensorflow/lite/core/interpreter.h @@ -580,6 +580,7 @@ class Interpreter { /// 5. kTfLiteError: Unexpected/runtime failure. \n /// \warning This is an experimental API and subject to change. 
\n TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate); + TfLiteStatus ModifyGraphWithDelegate(TfLiteOpaqueDelegateStruct* delegate); // Owning handle to a TfLiteDelegate instance. using TfLiteDelegatePtr = @@ -611,9 +612,12 @@ class Interpreter { std::unique_ptr delegate) = delete; /// \warning This is an experimental API and subject to change. \n - /// \brief Ensure the data in `tensor.data` is readable. In case delegate is - /// used, it might require to copy the data from delegate buffer to raw - /// memory. + /// \brief Ensure the data in `tensor.data` is readable. If a + /// delegate has been used, and `SetAllowBufferHandleOutput(true)` has been + /// called, tensor outputs may be stored as delegate buffer handles whose data + /// is not directly readable until this method has been called. + /// In such cases, this method will copy the data from the delegate buffer + /// handle to CPU memory. TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) { return primary_subgraph().EnsureTensorDataIsReadable(tensor_index); } diff --git a/tensorflow/lite/core/interpreter_builder.h b/tensorflow/lite/core/interpreter_builder.h index fcdae0b2543de8..6233d6561ab29c 100644 --- a/tensorflow/lite/core/interpreter_builder.h +++ b/tensorflow/lite/core/interpreter_builder.h @@ -76,7 +76,7 @@ class InterpreterBuilder { /// For this constructor, the ErrorReporter will be extracted from the /// FlatBufferModel. /// `options` object is copied during construction. So caller can release it - // after calling the constructor. + /// after calling the constructor. InterpreterBuilder(const FlatBufferModel& model, const OpResolver& op_resolver, const InterpreterOptions* options_experimental = nullptr); @@ -84,7 +84,7 @@ class InterpreterBuilder { /// of a FlatBufferModel). Mostly used for testing. /// If `error_reporter` is null, then DefaultErrorReporter() is used. /// `options` object is copied during construction. So caller can release it - // after calling the constructor. + /// after calling the constructor. 
InterpreterBuilder(const ::tflite::Model* model, const OpResolver& op_resolver, ErrorReporter* error_reporter = DefaultErrorReporter(), diff --git a/tensorflow/lite/core/interpreter_experimental.cc b/tensorflow/lite/core/interpreter_experimental.cc index e04b1d3e7c675d..016d45df977955 100644 --- a/tensorflow/lite/core/interpreter_experimental.cc +++ b/tensorflow/lite/core/interpreter_experimental.cc @@ -84,6 +84,12 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { return ModifyGraphWithDelegateImpl(delegate); } +TfLiteStatus Interpreter::ModifyGraphWithDelegate( + TfLiteOpaqueDelegateStruct* delegate) { + return ModifyGraphWithDelegateImpl( + reinterpret_cast(delegate)); +} + bool Interpreter::HasDelegates() { return primary_subgraph().HasDelegates(); } TfLiteStatus Interpreter::SetBufferHandle(int tensor_index, diff --git a/tensorflow/lite/core/kernels/builtin_op_kernels.h b/tensorflow/lite/core/kernels/builtin_op_kernels.h index 20362ada18e65c..e0dcbf8d4b0605 100644 --- a/tensorflow/lite/core/kernels/builtin_op_kernels.h +++ b/tensorflow/lite/core/kernels/builtin_op_kernels.h @@ -291,9 +291,9 @@ Register_STABLEHLO_DYNAMIC_SLICE(); // WARNING: not implemented, using this TfLiteRegistration* Register_STABLEHLO_DYNAMIC_UPDATE_SLICE(); // WARNING: not implemented, using // this op will crash the runtime -TfLiteRegistration* -Register_STABLEHLO_PAD(); // WARNING: not implemented, using this - // op will crash the runtime + +TfLiteRegistration* Register_STABLEHLO_PAD(); + TfLiteRegistration* Register_STABLEHLO_IOTA(); // WARNING: not implemented, using this // op will crash the runtime diff --git a/tensorflow/lite/core/kernels/register.cc b/tensorflow/lite/core/kernels/register.cc index cb53c20558106e..0e3eacf4d65017 100644 --- a/tensorflow/lite/core/kernels/register.cc +++ b/tensorflow/lite/core/kernels/register.cc @@ -297,7 +297,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N()); AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND(), /* min_version = */ 1, - /* max_version = */ 4); + /* max_version = */ 5); AddBuiltin(BuiltinOperator_WHERE, Register_WHERE(), /* min_version = */ 1, /* max_version = */ 2); @@ -380,6 +380,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_STABLEHLO_MULTIPLY, Register_STABLEHLO_MULTIPLY()); AddBuiltin(BuiltinOperator_STABLEHLO_MAXIMUM, Register_STABLEHLO_MAXIMUM()); AddBuiltin(BuiltinOperator_STABLEHLO_MINIMUM, Register_STABLEHLO_MINIMUM()); + AddBuiltin(BuiltinOperator_STABLEHLO_PAD, Register_STABLEHLO_PAD()); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/lite/core/model_builder.cc b/tensorflow/lite/core/model_builder.cc index 05822832cc2f93..e044c6da7e65c9 100644 --- a/tensorflow/lite/core/model_builder.cc +++ b/tensorflow/lite/core/model_builder.cc @@ -17,11 +17,12 @@ limitations under the License. 
#include #include +#include #include #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers #include "tensorflow/lite/allocation.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/verifier.h" @@ -386,6 +387,12 @@ std::map FlatBufferModel::ReadAllMetadata( } bool FlatBufferModel::CheckModelIdentifier() const { + if (allocation_->bytes() < 7) { + TF_LITE_REPORT_ERROR( + error_reporter_, + "Model provided must have at least 7 bytes to hold identifier.\n"); + return false; + } if (!tflite::ModelBufferHasIdentifier(allocation_->base())) { const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base()); TF_LITE_REPORT_ERROR( diff --git a/tensorflow/lite/create_op_resolver.h b/tensorflow/lite/create_op_resolver.h index 41012171b07b03..853505f1c786e6 100644 --- a/tensorflow/lite/create_op_resolver.h +++ b/tensorflow/lite/create_op_resolver.h @@ -15,9 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_LITE_CREATE_OP_RESOLVER_H_ #define TENSORFLOW_LITE_CREATE_OP_RESOLVER_H_ +/// For documentation, see third_party/tensorflow/lite/core/create_op_resolver.h + #include -#include "tensorflow/lite/core/create_op_resolver.h" +#include "tensorflow/lite/core/create_op_resolver.h" // IWYU pragma: export namespace tflite { using ::tflite::CreateOpResolver; diff --git a/tensorflow/lite/delegates/coreml/BUILD b/tensorflow/lite/delegates/coreml/BUILD index 08466b7ac02360..2868fa658b216a 100644 --- a/tensorflow/lite/delegates/coreml/BUILD +++ b/tensorflow/lite/delegates/coreml/BUILD @@ -29,12 +29,11 @@ objc_library( srcs = ["coreml_executor.mm"], hdrs = ["coreml_executor.h"], copts = ["-std=c++17"], - features = ["-layering_check"], sdk_frameworks = [ "CoreML", "Foundation", ], - deps = [":mlmodel_proto_cc"], + deps = ["@coremltools//:mlmodel_cc_proto"], ) cc_library( diff --git a/tensorflow/lite/delegates/delegate_test.cc b/tensorflow/lite/delegates/delegate_test.cc index 078aa0863a7d55..560b2b4c65b940 100644 --- a/tensorflow/lite/delegates/delegate_test.cc +++ b/tensorflow/lite/delegates/delegate_test.cc @@ -51,7 +51,8 @@ using test_utils::TestTwoDelegates; namespace { TEST_F(TestDelegate, NullDelegate) { - EXPECT_EQ(interpreter_->ModifyGraphWithDelegate(nullptr), + TfLiteOpaqueDelegate* delegate = nullptr; + EXPECT_EQ(interpreter_->ModifyGraphWithDelegate(delegate), kTfLiteDelegateError); } @@ -178,14 +179,14 @@ TEST_F(TestDelegate, SetBufferHandleToInput) { TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); interpreter_->ModifyGraphWithDelegate(delegate); - constexpr int kOutputTensorIndex = 0; - TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + constexpr int kInputTensorIndex = 0; + TfLiteTensor* tensor = interpreter_->tensor(kInputTensorIndex); ASSERT_EQ(tensor->delegate, nullptr); ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); TfLiteBufferHandle handle = AllocateBufferHandle(); TfLiteStatus status = - interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate); + interpreter_->SetBufferHandle(kInputTensorIndex, handle, delegate); ASSERT_EQ(status, kTfLiteOk); EXPECT_EQ(tensor->delegate, delegate); EXPECT_EQ(tensor->buffer_handle, handle); @@ -1488,7 +1489,8 @@ TEST_P(TestFP16Delegation, NonDelegatedInterpreterWorks) { } TEST_F(TestFP16Delegation, NullDelegate) { - EXPECT_EQ(interpreter_->ModifyGraphWithDelegate(nullptr), + TfLiteOpaqueDelegate* delegate = nullptr; + EXPECT_EQ(interpreter_->ModifyGraphWithDelegate(delegate), 
kTfLiteDelegateError); // Verify that resulting interpreter still works, despite null delegate. ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 6e767beb635cfb..bd47abe905729c 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -1,12 +1,12 @@ load("@bazel_skylib//lib:selects.bzl", "selects") -load("//tensorflow/lite:special_rules.bzl", "tflite_extra_gles_deps", "tflite_portable_test_suite") -load("//tensorflow/lite/delegates/gpu:build_defs.bzl", "gpu_delegate_linkopts") -load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework") -load("@build_bazel_rules_apple//apple:macos.bzl", "macos_dylib") load( "//tensorflow/core/platform:build_config_root.bzl", "tf_gpu_tests_tags", ) +load("//tensorflow/lite:special_rules.bzl", "tflite_extra_gles_deps", "tflite_portable_test_suite") +load("//tensorflow/lite/delegates/gpu:build_defs.bzl", "gpu_delegate_linkopts") +load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework") +load("@build_bazel_rules_apple//apple:macos.bzl", "macos_dylib") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -33,19 +33,6 @@ config_setting( }, ) -# copybara:uncomment_begin(google-only) -# config_setting( -# name = "tflite_gpu_angle", -# flag_values = { -# "//tools/cpp:cc_target_os": "linux-google", -# "//third_party/angle:use_angle": "True", -# }, -# values = { -# "cpu": "k8", -# }, -# ) -# copybara:uncomment_end - cc_library( name = "gl_delegate", srcs = ["gl_delegate.cc"], @@ -92,7 +79,6 @@ objc_library( srcs = ["metal_delegate.mm"], hdrs = ["metal_delegate.h"], copts = ["-std=c++17"], - features = ["-layering_check"], module_name = "TensorFlowLiteCMetal", sdk_frameworks = ["Metal"], deps = [ @@ -108,11 +94,13 @@ objc_library( "//tensorflow/lite/delegates/gpu/common:quantization_util", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", - "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/metal:buffer_convert", + "//tensorflow/lite/delegates/gpu/metal:common", "//tensorflow/lite/delegates/gpu/metal:inference_context", "//tensorflow/lite/delegates/gpu/metal:metal_spatial_tensor", + "//tensorflow/lite/kernels:kernel_util", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:span", ], ) @@ -172,7 +160,7 @@ ios_static_framework( "metal_delegate.h", "metal_delegate_internal.h", ], - minimum_os_version = "11.4", + minimum_os_version = "12.0", deps = [":metal_delegate"], ) @@ -184,7 +172,7 @@ macos_dylib( "-all_load", "-dead_strip", ], - minimum_os_version = "10.13", + minimum_os_version = "12.0", tags = [ "manual", "nobuilder", @@ -269,6 +257,7 @@ cc_library( ], "//conditions:default": [], }) + [ + ":android_hardware_buffer", ":api", ":delegate_options", ":tflite_profile", @@ -304,11 +293,27 @@ cc_library( ], ) +cc_library( + name = "android_hardware_buffer", + srcs = ["android_hardware_buffer.cc"], + hdrs = ["android_hardware_buffer.h"], +) + +cc_test( + name = "android_hardware_buffer_test", + srcs = ["android_hardware_buffer_test.cc"], + deps = [ + ":android_hardware_buffer", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "async_buffers", srcs = ["async_buffers.cc"], hdrs = ["async_buffers.h"], deps = [ + ":android_hardware_buffer", ":api", "//tensorflow/lite/delegates/gpu/common:data_type", 
"//tensorflow/lite/delegates/gpu/gl:gl_errors", @@ -326,6 +331,7 @@ cc_test( "tflite_not_portable_ios", ], deps = [ + ":android_hardware_buffer", ":async_buffers", ":delegate", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/lite/delegates/gpu/android_hardware_buffer.cc b/tensorflow/lite/delegates/gpu/android_hardware_buffer.cc new file mode 100644 index 00000000000000..e9bf3040b8f72a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/android_hardware_buffer.cc @@ -0,0 +1,53 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/android_hardware_buffer.h" + +#include + +namespace tflite::gpu { + +OptionalAndroidHardwareBuffer::OptionalAndroidHardwareBuffer() { +#ifdef __ANDROID__ + dlopen_handle_ = dlopen("libnativewindow.so", RTLD_NOW); + if (dlopen_handle_ == nullptr) { + supported_ = false; + return; + } + allocate_ = reinterpret_cast( + dlsym(dlopen_handle_, "AHardwareBuffer_allocate")); + acquire_ = reinterpret_cast( + dlsym(dlopen_handle_, "AHardwareBuffer_acquire")); + release_ = reinterpret_cast( + dlsym(dlopen_handle_, "AHardwareBuffer_release")); + describe_ = reinterpret_cast( + dlsym(dlopen_handle_, "AHardwareBuffer_describe")); + is_supported_ = reinterpret_cast( + dlsym(dlopen_handle_, "AHardwareBuffer_isSupported")); + supported_ = + (allocate_ != nullptr && acquire_ != nullptr && release_ != nullptr && + describe_ != nullptr && is_supported_ != nullptr); +#else + dlopen_handle_ = nullptr; + allocate_ = nullptr; + acquire_ = nullptr; + release_ = nullptr; + describe_ = nullptr; + is_supported_ = nullptr; + supported_ = false; +#endif +} + +} // namespace tflite::gpu diff --git a/tensorflow/lite/delegates/gpu/android_hardware_buffer.h b/tensorflow/lite/delegates/gpu/android_hardware_buffer.h new file mode 100644 index 00000000000000..dc272f6975ca06 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/android_hardware_buffer.h @@ -0,0 +1,130 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_ANDROID_HARDWARE_BUFFER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_ANDROID_HARDWARE_BUFFER_H_ + +#include + +#ifdef __ANDROID__ +#include +#else +extern "C" { +typedef struct AHardwareBuffer AHardwareBuffer; + +// struct is a copy of the Android NDK AHardwareBuffer_Desc struct in the link +// below +// https://developer.android.com/ndk/reference/struct/a-hardware-buffer-desc +typedef struct AHardwareBuffer_Desc AHardwareBuffer_Desc; +struct AHardwareBuffer_Desc { + uint32_t width; + uint32_t height; + uint32_t layers; + uint32_t format; + uint64_t usage; + uint32_t stride; + uint32_t rfu0; + uint64_t rfu1; +}; +} // extern "C" +#endif // __ANDROID__ + +namespace tflite::gpu { + +// This header file and singleton class encapsulates the following Android NDK +// features +// - header +// - opaque struct type AHardwareBuffer +// - struct type AHardwareBuffer_Desc +// - function AHardwareBuffer_isSupported +// - function AHardwareBuffer_allocate +// - function AHardwareBuffer_acquire +// - function AHardwareBuffer_release +// - function AHardwareBuffer_describe +// - library libnativewindow.so (for the above features) +// +// For documentation on these features, see +// : +// +// Unlike using the native NDK functionality directly, this class only has a +// run-time dependency on API level 26, not a build-time dependency. So it can +// be used even when building with NDK min SDK level < 26, as long as you are +// very careful to check that Supported() returns true before calling any other +// methods. +class OptionalAndroidHardwareBuffer { + public: + static OptionalAndroidHardwareBuffer& Instance() { + static OptionalAndroidHardwareBuffer instance; + return instance; + } + + // Returns true if the functionality in this class is supported. + bool Supported() { return supported_; } + + // Like AHardwareBuffer_isSupported. + // Caller must check that Supported() returns true before calling this + // function. + int IsSupported(const AHardwareBuffer_Desc* description) { + return is_supported_(description); + } + + // Like AHardwareBuffer_allocate. + // Caller must check that Supported() returns true before calling this + // function. + int Allocate(const AHardwareBuffer_Desc* description, + AHardwareBuffer** buffer) { + return allocate_(description, buffer); + } + + // Like AHardwareBuffer_acquire. + // Caller must check that Supported() returns true before calling this + // function. + void Acquire(AHardwareBuffer* buffer) { return acquire_(buffer); } + + // Like AHardwareBuffer_release. + // Caller must check that Supported() returns true before calling this + // function. + void Release(AHardwareBuffer* buffer) { return release_(buffer); } + + // Like AHardwareBuffer_describe. + // Caller must check that Supported() returns true before calling this + // function. 
+ void Describe(AHardwareBuffer* buffer, AHardwareBuffer_Desc* desc) { + return describe_(buffer, desc); + } + + private: + void* dlopen_handle_; + int (*is_supported_)(const AHardwareBuffer_Desc* desc); + int (*allocate_)(const AHardwareBuffer_Desc* desc, AHardwareBuffer** buffer); + void (*acquire_)(AHardwareBuffer* buffer); + void (*release_)(AHardwareBuffer* buffer); + void (*describe_)(AHardwareBuffer* buffer, AHardwareBuffer_Desc* desc); + bool supported_; + + OptionalAndroidHardwareBuffer(); + OptionalAndroidHardwareBuffer(const OptionalAndroidHardwareBuffer&) = delete; + // Note that we deliberately do not call dlclose() in the destructor; doing + // so would complicate the code and would unnecessarily introduce additional + // failure scenarios. The object is a singleton and so is only destroyed when + // the process is about to exit, and the OS will automatically reclaim the + // resources on process exit anyway, so calling dlclose would only slow down + // process exit. + ~OptionalAndroidHardwareBuffer() = default; +}; + +} // namespace tflite::gpu + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_ANDROID_HARDWARE_BUFFER_H_ diff --git a/tensorflow/lite/delegates/gpu/android_hardware_buffer_test.cc b/tensorflow/lite/delegates/gpu/android_hardware_buffer_test.cc new file mode 100644 index 00000000000000..9f1c35fc5c2d73 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/android_hardware_buffer_test.cc @@ -0,0 +1,75 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
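// Usage sketch for the OptionalAndroidHardwareBuffer wrapper defined above
// (assumes an Android build where the NDK AHARDWAREBUFFER_* enum constants are
// visible; the helper name is illustrative). Supported() must be checked
// before any other call.
AHardwareBuffer* AllocateBlobAhwb(uint32_t size_bytes) {
  auto& ahwb_lib = tflite::gpu::OptionalAndroidHardwareBuffer::Instance();
  if (!ahwb_lib.Supported()) return nullptr;  // pre-API-26 device or non-Android host
  AHardwareBuffer_Desc desc = {};
  desc.width = size_bytes;
  desc.height = 1;
  desc.layers = 1;
  desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
  desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN;
  AHardwareBuffer* buffer = nullptr;
  if (!ahwb_lib.IsSupported(&desc) || ahwb_lib.Allocate(&desc, &buffer) != 0) {
    return nullptr;
  }
  return buffer;  // balance with Instance().Release(buffer) when done
}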
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/android_hardware_buffer.h" + +#include + +using tflite::gpu::OptionalAndroidHardwareBuffer; +auto Instance = OptionalAndroidHardwareBuffer::Instance; + +namespace { + +#ifndef __ANDROID__ + +TEST(OptionalAndroidHardwareBufferTest, NotSupportedOnNonAndroid) { + EXPECT_EQ(Instance().Supported(), false); +} + +#else // defined(__ANDROID__) + +TEST(OptionalAndroidHardwareBufferTest, SupportedOnAndroid) { + EXPECT_EQ(Instance().Supported(), true); +} + +TEST(OptionalAndroidHardwareBufferTest, CanAllocateAndReleaseOnAndroid) { + EXPECT_EQ(Instance().Supported(), true); + AHardwareBuffer* buffer; + AHardwareBuffer_Desc description{}; + description.width = 1600; + description.height = 1; + description.layers = 1; + description.rfu0 = 0; + description.rfu1 = 0; + description.stride = 1; + description.format = AHARDWAREBUFFER_FORMAT_BLOB; + description.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN; + EXPECT_TRUE(Instance().IsSupported(&description)); + EXPECT_EQ(Instance().Allocate(&description, &buffer), 0); + Instance().Release(buffer); +} + +TEST(OptionalAndroidHardwareBufferTest, CanAcquireAndReleaseOnAndroid) { + EXPECT_EQ(Instance().Supported(), true); + AHardwareBuffer* buffer; + AHardwareBuffer_Desc description{}; + description.width = 1600; + description.height = 1; + description.layers = 1; + description.rfu0 = 0; + description.rfu1 = 0; + description.stride = 1; + description.format = AHARDWAREBUFFER_FORMAT_BLOB; + description.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN; + EXPECT_TRUE(Instance().IsSupported(&description)); + EXPECT_EQ(Instance().Allocate(&description, &buffer), 0); + Instance().Acquire(buffer); + Instance().Release(buffer); // To match Acquire + Instance().Release(buffer); // To match Allocate +} + +#endif // defined(__ANDROID__) + +} // namespace diff --git a/tensorflow/lite/delegates/gpu/async_buffers.cc b/tensorflow/lite/delegates/gpu/async_buffers.cc index 78e201f2102e6d..3c988857506ffd 100644 --- a/tensorflow/lite/delegates/gpu/async_buffers.cc +++ b/tensorflow/lite/delegates/gpu/async_buffers.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "tensorflow/lite/delegates/gpu/android_hardware_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_errors.h" namespace { @@ -74,11 +75,9 @@ absl::Status AsyncBuffer::AllocateOpenGlBuffer() { if (!status.ok()) { // If we can't map to SSBO, clear AHWB & SSBO if (ahwb_ != nullptr) { -#if (__ANDROID__) - if (__builtin_available(android 26, *)) { - AHardwareBuffer_release(ahwb_); + if (OptionalAndroidHardwareBuffer::Instance().Supported()) { + OptionalAndroidHardwareBuffer::Instance().Release(ahwb_); } -#endif ahwb_ = nullptr; } glBufferData(GL_SHADER_STORAGE_BUFFER, bytes_, nullptr, GL_STREAM_COPY); diff --git a/tensorflow/lite/delegates/gpu/async_buffers_test.cc b/tensorflow/lite/delegates/gpu/async_buffers_test.cc index 2f51e408661358..649c41f4be6797 100644 --- a/tensorflow/lite/delegates/gpu/async_buffers_test.cc +++ b/tensorflow/lite/delegates/gpu/async_buffers_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
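// The async_buffers.cc change above swaps the compile-time __ANDROID__ /
// __builtin_available guards for a run-time Supported() check. The same
// platform-neutral release idiom as a standalone sketch (the helper name is
// illustrative):
void SafeReleaseAhwb(AHardwareBuffer*& ahwb) {
  auto& ahwb_lib = tflite::gpu::OptionalAndroidHardwareBuffer::Instance();
  if (ahwb != nullptr && ahwb_lib.Supported()) {
    ahwb_lib.Release(ahwb);
  }
  ahwb = nullptr;  // on unsupported platforms the handle should never be non-null anyway
}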
#include #include +#include "tensorflow/lite/delegates/gpu/android_hardware_buffer.h" #include "tensorflow/lite/delegates/gpu/api.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" @@ -28,38 +29,44 @@ namespace gpu { namespace { TEST(AsyncBufferTest, DuplicateTest) { - // Create tie - TensorObjectDef* tie = new TensorObjectDef(); - tie->object_def.data_type = DataType::FLOAT32; - tie->object_def.data_layout = DataLayout::BHWC; - tie->dimensions = Dimensions(2, 2, 2, 2); + if (__builtin_available(android 26, *)) { + auto Instance = OptionalAndroidHardwareBuffer::Instance; + // Create tie + TensorObjectDef* tie = new TensorObjectDef(); + tie->object_def.data_type = DataType::FLOAT32; + tie->object_def.data_layout = DataLayout::BHWC; + tie->dimensions = Dimensions(2, 2, 2, 2); - // Create AHWB - AHardwareBuffer_Desc buffDesc = {}; - buffDesc.width = 1000; - buffDesc.height = 1; - buffDesc.layers = 1; - buffDesc.format = AHARDWAREBUFFER_FORMAT_BLOB; - buffDesc.usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN | - AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | - AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER; - AHardwareBuffer* ahwb; - EXPECT_EQ(AHardwareBuffer_allocate(&buffDesc, &ahwb), 0); + // Create AHWB + AHardwareBuffer_Desc buffDesc = {}; + buffDesc.width = 1000; + buffDesc.height = 1; + buffDesc.layers = 1; + buffDesc.format = AHARDWAREBUFFER_FORMAT_BLOB; + buffDesc.usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN | + AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | + AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER; + AHardwareBuffer* ahwb; + EXPECT_TRUE(Instance().IsSupported(&buffDesc)); + EXPECT_EQ(Instance().Allocate(&buffDesc, &ahwb), 0); - // Init GL Env to properly use gl fcns - std::unique_ptr env; - EXPECT_OK(gl::EglEnvironment::NewEglEnvironment(&env)); - AsyncBuffer async_buffer1 = AsyncBuffer(*tie, ahwb); - GLuint buffer1, buffer2; - EXPECT_OK(async_buffer1.GetOpenGlBuffer(buffer1)); - EXPECT_GE(buffer1, 0); - EXPECT_OK(async_buffer1.GetOpenGlBuffer(buffer2)); - // Check that each instance of AsyncBuffer class has only one id - EXPECT_EQ(buffer1, buffer2); - AsyncBuffer async_buffer2 = AsyncBuffer(*tie, ahwb); - EXPECT_OK(async_buffer2.GetOpenGlBuffer(buffer2)); - // Check that each different instance will produce unique id - EXPECT_NE(buffer1, buffer2); + // Init GL Env to properly use gl fcns + std::unique_ptr env; + EXPECT_OK(gl::EglEnvironment::NewEglEnvironment(&env)); + AsyncBuffer async_buffer1 = AsyncBuffer(*tie, ahwb); + GLuint buffer1, buffer2; + EXPECT_OK(async_buffer1.GetOpenGlBuffer(buffer1)); + EXPECT_GE(buffer1, 0); + EXPECT_OK(async_buffer1.GetOpenGlBuffer(buffer2)); + // Check that each instance of AsyncBuffer class has only one id + EXPECT_EQ(buffer1, buffer2); + AsyncBuffer async_buffer2 = AsyncBuffer(*tie, ahwb); + EXPECT_OK(async_buffer2.GetOpenGlBuffer(buffer2)); + // Check that each different instance will produce unique id + EXPECT_NE(buffer1, buffer2); + } else { + GTEST_SKIP(); + } } } // namespace diff --git a/tensorflow/lite/delegates/gpu/build_defs.bzl b/tensorflow/lite/delegates/gpu/build_defs.bzl index e6e6fa2be3934e..cdea91aec86507 100644 --- a/tensorflow/lite/delegates/gpu/build_defs.bzl +++ b/tensorflow/lite/delegates/gpu/build_defs.bzl @@ -1,17 +1,5 @@ """Additional build options needed for the GPU Delegate.""" -# copybara:uncomment_begin(google-only) -# load("//third_party/android/ndk/platforms:grte_top.bzl", "min_supported_ndk_api") -# copybara:uncomment_end - -def nativewindow_linkopts(): - # 
copybara:uncomment_begin(google-only) - # return min_supported_ndk_api("26", ["-lnativewindow"]) - # copybara:uncomment_end - # copybara:comment_begin(oss-only) - return ["-lnativewindow"] - # copybara:comment_end - def gpu_delegate_linkopts(): """Additional link options needed when linking in the GPU Delegate.""" return select({ @@ -24,7 +12,7 @@ def gpu_delegate_linkopts(): "-lGLESv2", ], "//conditions:default": [], - }) + nativewindow_linkopts() + }) def tflite_angle_heapcheck_deps(): # copybara:uncomment_begin(google-only) @@ -40,3 +28,11 @@ def tflite_angle_heapcheck_deps(): # copybara:comment_begin(oss-only) return ["@com_google_googletest//:gtest_main"] # copybara:comment_end + +def gtest_main_no_heapcheck_deps(): + # copybara:uncomment_begin(google-only) + # return ["@com_google_googletest//:gtest_main_no_heapcheck"] + # copybara:uncomment_end + # copybara:comment_begin(oss-only) + return ["@com_google_googletest//:gtest_main"] + # copybara:comment_end diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 760f401bdc61ea..ae7c23280538d2 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -428,18 +428,9 @@ cc_library( srcs = ["opencl_wrapper.cc"], hdrs = ["opencl_wrapper.h"], linkopts = select({ - "//tensorflow:android": [ - "-ldl", # opencl_wrapper calls dlopen() - "-lm", - ], - # copybara:uncomment_begin(google-only) - # "//tools/cc_target_os:linux-google": [ - # "-ldl", - # "-rdynamic", - # ], - # copybara:uncomment_end - "//conditions:default": ["-ldl"], # opencl_wrapper calls dlopen() - }), + "//tensorflow:android": ["-lm"], + "//conditions:default": [], + }) + ["-ldl"], # opencl_wrapper calls dlopen() deps = [ "//tensorflow/lite/delegates/gpu/common:status", "@com_google_absl//absl/strings", diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD index 889b178463e5f1..23ac928214a42d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD @@ -2,6 +2,7 @@ load( "//tensorflow/core/platform:build_config_root.bzl", "tf_gpu_tests_tags", ) +load("//tensorflow/lite/delegates/gpu:build_defs.bzl", "gtest_main_no_heapcheck_deps") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,14 +19,13 @@ cc_test( "notsan", "requires-gpu-nvidia", ], + # TODO(b/279977471) Once b/279347631 is resolved, check for heap again deps = [ ":cl_test", - # TODO(b/279977471) Once b/279347631 is resolved, check for heap again - "@com_google_googletest//:gtest_main_no_heapcheck", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:add_test_util", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -42,8 +42,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:cast_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_library( @@ -76,8 +75,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:concat_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -94,8 +92,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", 
"//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:conv_constants_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", # constant buffers leak on nvidia - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -112,8 +109,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:conv_generic_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -130,8 +126,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:conv_weights_converter_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_library( @@ -173,8 +168,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:convolution_transposed_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -191,8 +185,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:convolution_transposed_3x3_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -209,8 +202,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:convolution_transposed_3x3_thin_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", # constant buffers leak on nvidia - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -227,8 +219,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:convolution_transposed_4x4_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -245,8 +236,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:convolution_transposed_thin_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", # constant buffers leak on nvidia - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -263,8 +253,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:cumsum_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -281,8 +270,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:depthwise_conv_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -300,8 +288,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:depthwise_conv_3x3_stride_h2_test_util", "//tensorflow/lite/delegates/gpu/common/tasks:depthwise_conv_3x3_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( 
@@ -318,8 +305,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:elementwise_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -341,8 +327,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/tasks:fully_connected", "//tensorflow/lite/delegates/gpu/common/tasks:fully_connected_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -359,8 +344,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:gather_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -377,8 +361,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:lstm_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -414,8 +397,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:max_unpooling_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -432,8 +414,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:mean_stddev_normalization_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -450,8 +431,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:one_hot_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -468,8 +448,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:padding_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -486,8 +465,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:pooling_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -504,8 +482,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:prelu_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -522,8 +499,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:quantize_and_dequantize_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -540,8 +516,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", 
"//tensorflow/lite/delegates/gpu/common/tasks:reduce_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -558,8 +533,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:relu_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -576,8 +550,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:resampler_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -594,8 +567,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:reshape_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -612,8 +584,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:reshape_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -630,8 +601,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:select_v2_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -648,8 +618,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:softmax_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -666,8 +635,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:softmax_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -684,8 +652,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:space_to_depth_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -702,8 +669,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:split_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -720,8 +686,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:strided_slice_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -738,8 +703,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:tile_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -756,8 +720,7 
@@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:transpose_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -774,8 +737,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:resize_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -792,8 +754,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:winograd_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) test_suite( diff --git a/tensorflow/lite/delegates/gpu/cl/testing/BUILD b/tensorflow/lite/delegates/gpu/cl/testing/BUILD index 75e36c0c9ca877..e333bb6daf5628 100644 --- a/tensorflow/lite/delegates/gpu/cl/testing/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/testing/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/lite/delegates/gpu:build_defs.bzl", "gtest_main_no_heapcheck_deps") + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], @@ -35,8 +37,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/cl/kernels:cl_test", "//tensorflow/lite/delegates/gpu/common:gpu_model_test_util", "//tensorflow/lite/delegates/gpu/common:status", - "@com_google_googletest//:gtest_main_no_heapcheck", # constant buffers leak on nvidia - ], + ] + gtest_main_no_heapcheck_deps(), # constant buffers leak on nvidia ) cc_binary( diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index 7f50c878263cd9..0a55ad05a76968 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -1,6 +1,6 @@ -load("//tensorflow/core/platform:build_config.bzl", "tf_platform_alias") load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("//tensorflow:tensorflow.bzl", "workspace_root") +load("//tensorflow/core/platform:build_config.bzl", "tf_platform_alias") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -299,6 +299,16 @@ cc_library( ], ) +cc_test( + name = "model_builder_helper_test", + srcs = ["model_builder_helper_test.cc"], + deps = [ + ":model_builder_helper", + "//tensorflow/lite/core/c:private_common", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "model_hints", hdrs = ["model_hints.h"], diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc index 8f41a9ef1a715a..30489a1721f9d8 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc +++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc @@ -39,6 +39,7 @@ GpuVendor GetGpuVendor(const std::string& gpu_description) { {"nvidia", GpuVendor::kNvidia}, {"amd", GpuVendor::kAMD}, {"radeon", GpuVendor::kAMD}, + {"xclipse", GpuVendor::kAMD}, {"power", GpuVendor::kPowerVR}, }; for (const auto& v : kMapping) { @@ -625,15 +626,6 @@ void GetGpuInfoFromDeviceDescription(const std::string& gpu_description, absl::AsciiStrToLower(&lowered); gpu_info->vendor = GetGpuVendor(lowered); - // Because clvk is an OpenCL layer on top of vulkan, it does not react to CL - // optimisation as native CL implementation does. 
- // AMD is particularly affected, thus let's manage it differently to get the - // best performances out of it. - if (gpu_info->IsApiOpenCl() && gpu_info->opencl_info.IsCLVK() && - gpu_info->IsAMD()) { - gpu_info->vendor = GpuVendor::kUnknown; - } - if (gpu_info->IsAdreno()) { gpu_info->adreno_info = AdrenoInfo(lowered); } else if (gpu_info->IsApple()) { diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc index b498916bfed447..66224878c617ae 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc @@ -16,10 +16,9 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" #include -#include -#include -#include +#include +#include #include #include #include @@ -257,34 +256,32 @@ void ConvertFloat16ToFloat32(size_t num_elements, const uint16_t* src, } template <> -absl::Status CreateVectorCopyData(const TfLiteTensor& tensor, - float* tensor_data) { - switch (tensor.type) { +absl::Status CreateVectorCopyData(const TfLiteTensor& src, float* dst) { + switch (src.type) { case kTfLiteFloat32: - std::memcpy(tensor_data, tensor.data.f, tensor.bytes); - break; + std::memcpy(dst, src.data.f, src.bytes); + return absl::OkStatus(); case kTfLiteFloat16: - ConvertFloat16ToFloat32( - NumElements(&tensor), - reinterpret_cast(tensor.data.f16), tensor_data); - break; + ConvertFloat16ToFloat32(NumElements(&src), + reinterpret_cast(src.data.f16), + dst); + return absl::OkStatus(); case kTfLiteInt8: - DequantizeConstantTensor(tensor, tensor.data.int8, tensor_data); - break; + DequantizeConstantTensor(src, src.data.int8, dst); + return absl::OkStatus(); case kTfLiteUInt8: - DequantizeConstantTensor(tensor, tensor.data.uint8, tensor_data); - break; + DequantizeConstantTensor(src, src.data.uint8, dst); + return absl::OkStatus(); case kTfLiteInt32: - DequantizeConstantTensor(tensor, tensor.data.i32, tensor_data); - break; + DequantizeConstantTensor(src, src.data.i32, dst); + return absl::OkStatus(); default: return absl::InvalidArgumentError( "Unsupported data type for float32 tensor"); } - return absl::OkStatus(); } -const std::string GetDimensionString(const TfLiteIntArray* dimensions) { +std::string GetDimensionString(const TfLiteIntArray* dimensions) { return absl::StrJoin(TfLiteIntArrayView(dimensions), "x"); } diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h index 88a4576d45d9bf..14384ce5be9a1c 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h @@ -17,10 +17,9 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_ #include -#include -#include -#include +#include +#include #include "absl/strings/str_cat.h" #include "tensorflow/lite/core/c/builtin_op_data.h" @@ -33,6 +32,7 @@ limitations under the License. 
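// Caller-side sketch of the CreateVectorCopyData<float> specialization
// rewritten above: fp16, int8, uint8 and int32 constant tensors are converted
// (dequantized where applicable) into a float32 buffer in one call. The
// wrapper function below is illustrative; it assumes the GPU delegate's
// RETURN_IF_ERROR macro and kernel_util's NumElements() are in scope.
absl::Status ReadConstantAsFloat(const TfLiteTensor& tensor,
                                 std::vector<float>* out) {
  out->resize(NumElements(&tensor));
  RETURN_IF_ERROR(CreateVectorCopyData<float>(tensor, out->data()));
  return absl::OkStatus();
}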
#include "tensorflow/lite/kernels/internal/reference/dequantize.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace gpu { @@ -100,19 +100,94 @@ inline void DequantizeConstantTensor(const TfLiteTensor& tensor, } template -absl::Status CreateVectorCopyData(const TfLiteTensor& tensor, T* tensor_data) { - if (tensor.bytes % sizeof(T) != 0) { +absl::Status CreateVectorCopyData(const TfLiteTensor& src, T* dst) { + if (src.bytes % sizeof(T) != 0) { return absl::InvalidArgumentError( - absl::StrCat("Input data size ", tensor.bytes, + absl::StrCat("Input data size ", src.bytes, " is not aligned to expected type: ", sizeof(T))); } - std::memcpy(tensor_data, tensor.data.uint8, tensor.bytes); - return absl::OkStatus(); + if (const int n = tflite::NumElements(&src); n * sizeof(T) == src.bytes) { + std::memcpy(dst, src.data.raw_const, src.bytes); + return absl::OkStatus(); + } else { + switch (src.type) { + case kTfLiteNoType: + return absl::InvalidArgumentError("src has no type."); + case kTfLiteFloat32: + for (int i = 0; i < n; ++i) { + dst[i] = tflite::GetTensorData(&src)[i]; + } + return absl::OkStatus(); + case kTfLiteInt32: + for (int i = 0; i < n; ++i) { + dst[i] = tflite::GetTensorData(&src)[i]; + } + return absl::OkStatus(); + case kTfLiteUInt8: + for (int i = 0; i < n; ++i) { + dst[i] = tflite::GetTensorData(&src)[i]; + } + return absl::OkStatus(); + case kTfLiteInt64: + for (int i = 0; i < n; ++i) { + dst[i] = tflite::GetTensorData(&src)[i]; + } + return absl::OkStatus(); + case kTfLiteString: + return absl::UnimplementedError("src can't be string."); + case kTfLiteBool: + for (int i = 0; i < n; ++i) { + dst[i] = tflite::GetTensorData(&src)[i]; + } + return absl::OkStatus(); + case kTfLiteInt16: + for (int i = 0; i < n; ++i) { + dst[i] = tflite::GetTensorData(&src)[i]; + } + return absl::OkStatus(); + case kTfLiteComplex64: + return absl::UnimplementedError("src can't be complex64."); + case kTfLiteInt8: + for (int i = 0; i < n; ++i) { + dst[i] = tflite::GetTensorData(&src)[i]; + } + return absl::OkStatus(); + case kTfLiteFloat16: + return absl::UnimplementedError("src can't be float16."); + case kTfLiteFloat64: + for (int i = 0; i < n; ++i) { + dst[i] = tflite::GetTensorData(&src)[i]; + } + return absl::OkStatus(); + case kTfLiteComplex128: + return absl::UnimplementedError("src can't be complex128."); + case kTfLiteUInt64: + for (int i = 0; i < n; ++i) { + dst[i] = tflite::GetTensorData(&src)[i]; + } + return absl::OkStatus(); + case kTfLiteResource: + return absl::UnimplementedError("src can't be resource."); + case kTfLiteVariant: + return absl::UnimplementedError("src can't be variant."); + case kTfLiteUInt32: + for (int i = 0; i < n; ++i) { + dst[i] = tflite::GetTensorData(&src)[i]; + } + return absl::OkStatus(); + case kTfLiteUInt16: + for (int i = 0; i < n; ++i) { + dst[i] = tflite::GetTensorData(&src)[i]; + } + return absl::OkStatus(); + case kTfLiteInt4: + return absl::UnimplementedError("src can't be int4."); + } + } } template <> -absl::Status CreateVectorCopyData(const TfLiteTensor& tensor, - float* tensor_data); +absl::Status CreateVectorCopyData(const TfLiteTensor& src, float* dst); absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape); diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_helper_test.cc new file mode 100644 index 
00000000000000..f13bc539785467 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper_test.cc @@ -0,0 +1,48 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" + +#include + +#include +#include +#include "tensorflow/lite/core/c/common.h" + +namespace tflite { +namespace gpu { +namespace { + +using ::testing::ElementsAre; + +TEST(ModelBuilderHelperTest, CreateVectorCopyDataDifferentSize) { + TfLiteTensor tflite_tensor; + tflite_tensor.type = kTfLiteInt32; + int32_t src_data[4] = {1, 2, 3, 4}; + tflite_tensor.data.i32 = src_data; + tflite_tensor.dims = TfLiteIntArrayCreate(1); + tflite_tensor.dims->data[0] = sizeof(src_data) / sizeof(src_data[0]); + tflite_tensor.bytes = sizeof(src_data); + + int16_t dst[4]; + ASSERT_OK(CreateVectorCopyData(tflite_tensor, dst)); + EXPECT_THAT(dst, ElementsAre(1, 2, 3, 4)); + + TfLiteIntArrayFree(tflite_tensor.dims); +} + +} // namespace +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc b/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc index 72e54dd21c94f5..edcdff50dbc991 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc @@ -886,7 +886,8 @@ std::string ConvGeneric::GenerateConv(const GpuInfo& gpu_info, std::to_string(s * 4 + ch + shared_offset); std::string w_val; if (conv_params.AreWeightsBuffer()) { - if (gpu_info.SupportsPointersInKernels()) { + if (need_local_mem || + gpu_info.SupportsPointersInKernels()) { w_val = "weights_cache[" + weight_id + "]"; } else { w_val = "args.weights.Read(filters_offset + " + @@ -926,7 +927,7 @@ std::string ConvGeneric::GenerateConv(const GpuInfo& gpu_info, std::string weight_id = std::to_string(s * 4 + i + shared_offset); if (conv_params.AreWeightsBuffer()) { - if (gpu_info.SupportsPointersInKernels()) { + if (need_local_mem || gpu_info.SupportsPointersInKernels()) { F[i] = "weights_cache[" + weight_id + "]"; } else { F[i] = @@ -1113,7 +1114,7 @@ std::string ConvGeneric::GenerateConv(const GpuInfo& gpu_info, c += " if (DST_S + " + sind + " >= args.dst_tensor.Slices()) return;\n"; c += " {\n"; if (conv_params.AreWeightsBuffer() && - gpu_info.SupportsPointersInKernels()) { + (need_local_mem || gpu_info.SupportsPointersInKernels())) { c += " FLT4 bias_val = TO_FLT4(weights_cache[" + sind + "]);\n"; } else { c += " FLT4 bias_val = args.biases.Read(DST_S + " + sind + ");\n"; @@ -1748,8 +1749,7 @@ ConvGeneric::ConvParams ConvGeneric::GuessBestParams( conv_params.fixed_work_group_size = false; conv_params.src_depth_loop_size = 1; conv_params.weights_upload_type = WeightsUploadType::TEXTURES_MEM_X4; - } else if (gpu_info.IsIntel() || - (gpu_info.IsApiOpenCl() && gpu_info.opencl_info.IsCLVK())) { + } else if (gpu_info.IsIntel()) { if 
(different_weights_for_height) { work_group_size_ = int3(16, 1, 1); work_group_launch_order_ = int3(0, 1, 2); diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/BUILD b/tensorflow/lite/delegates/gpu/common/tasks/special/BUILD index 74af3502abc116..86ab973e974763 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/special/BUILD +++ b/tensorflow/lite/delegates/gpu/common/tasks/special/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/lite/delegates/gpu:build_defs.bzl", "gtest_main_no_heapcheck_deps") + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], @@ -26,18 +28,17 @@ cc_test( "notsan", "requires-gpu-nvidia", ], + # TODO(b/279977471) Once b/279347631 is resolved, check for heap again deps = [ ":conv_pointwise", - # TODO(b/279977471) Once b/279347631 is resolved, check for heap again - "@com_google_googletest//:gtest_main_no_heapcheck", "//tensorflow/lite/delegates/gpu/cl/kernels:cl_test", - "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", - "//tensorflow/lite/delegates/gpu/common/task:testing_util", "//tensorflow/lite/delegates/gpu/common:precision", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", - ], + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:testing_util", + ] + gtest_main_no_heapcheck_deps(), ) cc_library( diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/conv_pointwise_test.cc b/tensorflow/lite/delegates/gpu/common/tasks/special/conv_pointwise_test.cc index e77a488587df78..0af40dfaf8f07a 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/special/conv_pointwise_test.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/special/conv_pointwise_test.cc @@ -59,11 +59,11 @@ TEST_F(OpenCLOperationTest, SliceMulMeanConcat) { op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; GPUOperation operation = CreateConvPointwise(op_def, op_attr); - EXPECT_OK(env->ExecuteGPUOperation( + ASSERT_OK(env->ExecuteGPUOperation( {src_tensor, weights_tensor}, std::make_unique(std::move(operation)), BHWC(1, 2, 1, 2), &dst_tensor)); - EXPECT_OK(PointWiseNear({5.5f, 5.5f, 8.5f, 8.5f}, dst_tensor.data, eps)); + ASSERT_OK(PointWiseNear({5.5f, 5.5f, 8.5f, 8.5f}, dst_tensor.data, eps)); } } } @@ -93,11 +93,11 @@ TEST_F(OpenCLOperationTest, SliceMulSumConcat) { op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; GPUOperation operation = CreateConvPointwise(op_def, op_attr); - EXPECT_OK(env->ExecuteGPUOperation( + ASSERT_OK(env->ExecuteGPUOperation( {src_tensor, weights_tensor}, std::make_unique(std::move(operation)), BHWC(1, 2, 1, 2), &dst_tensor)); - EXPECT_OK( + ASSERT_OK( PointWiseNear({11.0f, 11.0f, 17.0f, 17.0f}, dst_tensor.data, eps)); } } diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index c7232f69078de5..00c58d37e6f37b 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -46,6 +46,7 @@ limitations under the License. 
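// On the EXPECT_OK -> ASSERT_OK switch in conv_pointwise_test.cc above: the
// PointWiseNear() comparison consumes dst_tensor.data produced by
// ExecuteGPUOperation(), so a failed execution should abort the test body
// instead of continuing into a comparison against an empty tensor. Minimal
// illustration (RunOp is a hypothetical helper returning absl::Status):
TEST(StatusAssertions, FatalAssertStopsTheTestBody) {
  std::vector<float> out;
  ASSERT_OK(RunOp(&out));            // fatal: stops here if the op failed
  EXPECT_NEAR(out[0], 5.5f, 1e-6f);  // never reached with an empty `out`
}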
#endif #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/delegates/gpu/android_hardware_buffer.h" #include "tensorflow/lite/delegates/gpu/api.h" #include "tensorflow/lite/delegates/gpu/cl/api.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" @@ -105,17 +106,9 @@ using tflite::delegates::utils::WriteSyncAttrs; } while (false) // This idiom allows selecting alternate code paths depending on whether or not -// AHWB is available. However, it's still necessary to directly guard calls to -// AHardwareBuffer_* functions with "if (__builtin_available(android 26, *))" to -// avoid compiler errors. -#define TFLITE_AHWB_AVAILABLE() \ - [] { \ - if (__builtin_available(android 26, *)) { \ - return true; \ - } else { \ - return false; \ - } \ - }() +// AHWB is available. +#define TFLITE_AHWB_AVAILABLE() \ + ::tflite::gpu::OptionalAndroidHardwareBuffer::Instance().Supported() namespace tflite { namespace gpu { @@ -772,24 +765,24 @@ class DelegateAsyncKernel : public BackendAsyncKernelInterface { using UniquePtrAHardwareBuffer = std::unique_ptr; static UniquePtrAHardwareBuffer Acquire(AHardwareBuffer* ahwb) { - if (__builtin_available(android 26, *)) { - AHardwareBuffer_acquire(ahwb); + if (OptionalAndroidHardwareBuffer::Instance().Supported()) { + OptionalAndroidHardwareBuffer::Instance().Acquire(ahwb); + return UniquePtrAHardwareBuffer(ahwb, [](AHardwareBuffer* b) { + OptionalAndroidHardwareBuffer::Instance().Release(b); + }); } else { TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "attempting AHardwareBuffer_acquire on a device without " "AHardwareBuffer support"); + return {nullptr, [](AHardwareBuffer*) {}}; } - return UniquePtrAHardwareBuffer(ahwb, [](AHardwareBuffer* b) { - if (__builtin_available(android 26, *)) { - AHardwareBuffer_release(b); - } - }); } static AHardwareBuffer_Desc Describe( const UniquePtrAHardwareBuffer& uptr_ahwb) { AHardwareBuffer_Desc desc_ahwb = {}; - if (__builtin_available(android 26, *)) { - AHardwareBuffer_describe(uptr_ahwb.get(), &desc_ahwb); + if (OptionalAndroidHardwareBuffer::Instance().Supported()) { + OptionalAndroidHardwareBuffer::Instance().Describe(uptr_ahwb.get(), + &desc_ahwb); } else { TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "attempting AHardwareBuffer_describe on a device without " diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD index 8571ff7f04156c..0c555a4e6b9a12 100644 --- a/tensorflow/lite/delegates/gpu/metal/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/BUILD @@ -90,7 +90,7 @@ objc_library( ios_unit_test( name = "common_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -144,7 +144,6 @@ objc_library( copts = DEFAULT_COPTS + [ "-ObjC++", ], - features = ["-layering_check"], sdk_frameworks = ["Metal"], deps = [ ":compute_task", @@ -268,7 +267,7 @@ ios_application( "iphone", ], infoplists = ["Info.plist"], - minimum_os_version = "11.4", + minimum_os_version = "12.0", provisioning_profile = "//tensorflow/lite/delegates/gpu/metal:provisioning_profile.mobileprovision", tags = tf_gpu_tests_tags() + [ "local", @@ -298,7 +297,7 @@ objc_library( ios_unit_test( name = "ComponentsTests", - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + ["notap"], test_host = ":TestApplication", diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD index 
72206b7678b140..06c295646eec6e 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD @@ -30,7 +30,7 @@ objc_library( ios_unit_test( name = "add_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -54,7 +54,7 @@ objc_library( ios_unit_test( name = "cast_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -77,7 +77,7 @@ objc_library( ios_unit_test( name = "concat_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -111,7 +111,7 @@ objc_library( ios_unit_test( name = "conv_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -134,7 +134,7 @@ objc_library( ios_unit_test( name = "conv_weights_converter_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -158,7 +158,7 @@ objc_library( ios_unit_test( name = "cumsum_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -183,7 +183,7 @@ objc_library( ios_unit_test( name = "depthwise_conv_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -207,7 +207,7 @@ objc_library( ios_unit_test( name = "elementwise_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -230,7 +230,7 @@ objc_library( ios_unit_test( name = "fully_connected_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -253,7 +253,7 @@ objc_library( ios_unit_test( name = "gather_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -276,7 +276,7 @@ objc_library( ios_unit_test( name = "lstm_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -299,7 +299,7 @@ objc_library( ios_unit_test( name = "max_unpooling_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -322,7 +322,7 @@ objc_library( ios_unit_test( name = "mean_stddev_normalization_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -346,7 +346,7 @@ objc_library( ios_unit_test( name = "one_hot_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -369,7 +369,7 @@ objc_library( ios_unit_test( name = "padding_test", testonly = 1, - 
minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -392,7 +392,7 @@ objc_library( ios_unit_test( name = "pooling_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -415,7 +415,7 @@ objc_library( ios_unit_test( name = "prelu_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -443,7 +443,7 @@ objc_library( ios_unit_test( name = "quantize_and_dequantize_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -466,7 +466,7 @@ objc_library( ios_unit_test( name = "reduce_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -489,7 +489,7 @@ objc_library( ios_unit_test( name = "relu_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -514,7 +514,7 @@ objc_library( ios_unit_test( name = "resampler_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = [ "no_mac", # TODO(b/183905399) @@ -539,7 +539,7 @@ objc_library( ios_unit_test( name = "resize_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -562,7 +562,7 @@ objc_library( ios_unit_test( name = "reshape_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -585,7 +585,7 @@ objc_library( ios_unit_test( name = "select_v2_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -608,7 +608,7 @@ objc_library( ios_unit_test( name = "slice_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -631,7 +631,7 @@ objc_library( ios_unit_test( name = "softmax_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -655,7 +655,7 @@ objc_library( ios_unit_test( name = "space_to_depth_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -678,7 +678,7 @@ objc_library( ios_unit_test( name = "split_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -701,7 +701,7 @@ objc_library( ios_unit_test( name = "tile_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -728,7 +728,7 @@ objc_library( ios_unit_test( name = "transpose_conv_test", testonly = 1, - minimum_os_version = 
"11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -751,7 +751,7 @@ objc_library( ios_unit_test( name = "transpose_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", @@ -804,7 +804,7 @@ objc_library( ios_unit_test( name = "winograd_test", testonly = 1, - minimum_os_version = "11.4", + minimum_os_version = "12.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = tf_gpu_tests_tags() + [ "notap", diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index affa0600b1ed6d..ad5755f6c990a3 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -27,6 +27,12 @@ config_setting( define_values = {"xnnpack_force_float_precision": "fp16"}, ) +# Force XNNPACK to use all operators in the delegate. +config_setting( + name = "xnnpack_use_latest_ops_explicit", + define_values = {"xnnpack_use_latest_ops": "true"}, +) + # Enable offloading of quantized 8-bit signed operators to XNNPACK delegate config_setting( name = "tflite_with_xnnpack_qs8_explicit_true", @@ -214,6 +220,9 @@ cc_library( copts = tflite_copts() + select({ ":xnnpack_force_float_precision_explicit_fp16": ["-DXNNPACK_DELEGATE_FORCE_PRECISION_FP16=1"], "//conditions:default": [], + }) + select({ + ":xnnpack_use_latest_ops_explicit": ["-DXNNPACK_DELEGATE_USE_LATEST_OPS=1"], + "//conditions:default": [], }), linkstatic = True, deps = [ diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index ca7bdfc88804ed..498c7b0e5b7fda 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -532,8 +532,12 @@ class Delegate { } bool enable_latest_operators() const { +#ifdef XNNPACK_DELEGATE_USE_LATEST_OPS + return true; +#else return (options_.flags & TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS) != 0; +#endif } bool support_variable_ops() const { @@ -3688,8 +3692,9 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( logging_context, filter_tensor, node->inputs->data[1], node_index)); } else { - TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type( - delegate, logging_context, filter_tensor, node->inputs->data[1], + TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQCInt8Type( + delegate, logging_context, filter_tensor, + /*expected_quantized_dimension=*/0, node->inputs->data[1], node_index)); if (quasi_static_tensors.count(node->inputs->data[1]) == 0) { TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( diff --git a/tensorflow/lite/examples/label_image/label_image.cc b/tensorflow/lite/examples/label_image/label_image.cc index 42e056796ef14f..803a2a89c931ef 100644 --- a/tensorflow/lite/examples/label_image/label_image.cc +++ b/tensorflow/lite/examples/label_image/label_image.cc @@ -263,6 +263,10 @@ void RunInference(Settings* settings, LOG(INFO) << "number of outputs: " << outputs.size(); } + auto profiler = std::make_unique( + settings->max_profiling_buffer_entries); + interpreter->SetProfiler(profiler.get()); + auto delegates = delegate_providers.CreateAllDelegates(); for (auto& delegate : delegates) { const auto delegate_name = delegate.provider->GetName(); @@ -311,9 +315,6 @@ void RunInference(Settings* settings, << interpreter->tensor(input)->type << " yet"; exit(-1); } - auto profiler = std::make_unique( - 
settings->max_profiling_buffer_entries); - interpreter->SetProfiler(profiler.get()); if (settings->profiling) profiler->StartProfiling(); for (int i = 0; i < settings->number_of_warmup_runs; i++) { diff --git a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin index 8108897c68e54b..417b66385ccfe7 100644 Binary files a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin and b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin differ diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h b/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h index e62b599d7e5294..011e492048f352 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h @@ -15,6 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ +/// For documentation, see +/// third_party/tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h + #include "tensorflow/lite/core/experimental/acceleration/mini_benchmark/c/c_api.h" // IWYU pragma: export #endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_C_C_API_H_ diff --git a/tensorflow/lite/experimental/microfrontend/BUILD b/tensorflow/lite/experimental/microfrontend/BUILD index 1fb94ff67d2dea..e1c4f30baa7ffd 100644 --- a/tensorflow/lite/experimental/microfrontend/BUILD +++ b/tensorflow/lite/experimental/microfrontend/BUILD @@ -118,8 +118,8 @@ tf_custom_op_py_strict_library( srcs_version = "PY3", deps = [ ":audio_microfrontend_op", - "//tensorflow/python/framework", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:load_library", "//tensorflow/python/framework:ops", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", diff --git a/tensorflow/lite/g3doc/microcontrollers/build_convert.md b/tensorflow/lite/g3doc/microcontrollers/build_convert.md index 402fc1eb26bd54..6c605d6a38a549 100644 --- a/tensorflow/lite/g3doc/microcontrollers/build_convert.md +++ b/tensorflow/lite/g3doc/microcontrollers/build_convert.md @@ -10,9 +10,8 @@ microcontrollers. It also outlines the supported operations and gives some guidance on designing and training a model to fit in limited memory. For an end-to-end, runnable example of building and converting a model, see the -following Colab which is part of the *Hello World* example: - -train_hello_world_model.ipynb +[Hello World](https://github.com/tensorflow/tflite-micro/tree/main/tensorflow/lite/micro/examples/hello_world#hello-world-example) +example. ## Model conversion @@ -54,7 +53,7 @@ important to change the array declaration to `const` for better memory efficiency on embedded platforms. For an example of how to include and use a model in your program, see -[`evaluate_test.cc`](https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/hello_world/evaluate_test.cc) +[`hello_world_test.cc`](https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/hello_world/hello_world_test.cc) in the *Hello World* example. 
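For the xnnpack_delegate.cc change above: when the new --define=xnnpack_use_latest_ops=true setting is active, the added -DXNNPACK_DELEGATE_USE_LATEST_OPS=1 copt makes enable_latest_operators() return true unconditionally; otherwise the behavior stays opt-in per delegate instance. A minimal sketch of that runtime opt-in, using the public XNNPACK delegate options API and the flag named in the hunk (the wrapper function itself is illustrative only):

#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"

// Creates an XNNPACK delegate with the latest-operators flag set at runtime.
// Under -DXNNPACK_DELEGATE_USE_LATEST_OPS=1 this flag becomes redundant.
TfLiteDelegate* MakeXnnpackDelegateWithLatestOps() {
  TfLiteXNNPackDelegateOptions options = TfLiteXNNPackDelegateOptionsDefault();
  options.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
  // The caller releases the delegate later with TfLiteXNNPackDelegateDelete().
  return TfLiteXNNPackDelegateCreate(&options);
}
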
## Model architecture and training diff --git a/tensorflow/lite/interpreter_builder.h b/tensorflow/lite/interpreter_builder.h index 01dfefe8a43ed8..346e08ed7cea22 100644 --- a/tensorflow/lite/interpreter_builder.h +++ b/tensorflow/lite/interpreter_builder.h @@ -17,7 +17,7 @@ limitations under the License. /// For documentation, see third_party/tensorflow/lite/core/interpreter_builder.h. -#include "tensorflow/lite/core/interpreter_builder.h" +#include "tensorflow/lite/core/interpreter_builder.h" // IWYU pragma: export namespace tflite { using InterpreterBuilder = ::tflite::impl::InterpreterBuilder; diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index 3c0f0265236d5b..3f9fe7fea2a364 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -465,6 +465,14 @@ java_library( ], ) +java_library_with_tflite( + name = "test_init", + testonly = True, + srcs = [ + "src/test/java/org/tensorflow/lite/TestInit.java", + ], +) + #----------------------------------------------------------------------------- # java_library targets that also include native code dependencies. @@ -516,7 +524,6 @@ java_test_with_tflite( size = "small", srcs = [ "src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java", - "src/test/java/org/tensorflow/lite/TestInit.java", ], javacopts = JAVACOPTS, # We want to ensure that every test case in the test also verifies that the @@ -532,6 +539,9 @@ java_test_with_tflite( "v1only", ], test_class = "org.tensorflow.lite.TensorFlowLiteTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", ], @@ -597,7 +607,6 @@ java_test_with_tflite( size = "small", srcs = [ "src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java", - "src/test/java/org/tensorflow/lite/TestInit.java", ], data = [ # The files named as .bin reshape the incoming tensor from (2, 8, 8, 3) to (2, 4, 4, 12). 
@@ -613,6 +622,9 @@ java_test_with_tflite( ], javacopts = JAVACOPTS, test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", ], @@ -631,7 +643,6 @@ java_test_with_tflite( srcs = [ "src/test/java/org/tensorflow/lite/InterpreterTest.java", "src/test/java/org/tensorflow/lite/SupportedFeatures.java", - "src/test/java/org/tensorflow/lite/TestInit.java", "src/test/java/org/tensorflow/lite/TestUtils.java", ], data = [ @@ -646,6 +657,9 @@ java_test_with_tflite( ], javacopts = JAVACOPTS, test_class = "org.tensorflow.lite.InterpreterTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", ], @@ -663,7 +677,6 @@ java_test_with_tflite( srcs = [ "src/test/java/org/tensorflow/lite/InterpreterApiTest.java", "src/test/java/org/tensorflow/lite/SupportedFeatures.java", - "src/test/java/org/tensorflow/lite/TestInit.java", "src/test/java/org/tensorflow/lite/TestUtils.java", ], data = [ @@ -677,6 +690,9 @@ java_test_with_tflite( ], javacopts = JAVACOPTS, test_class = "org.tensorflow.lite.InterpreterApiTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_stable_test_jni.so", ], @@ -695,7 +711,6 @@ java_test_with_tflite( srcs = [ "src/test/java/org/tensorflow/lite/InterpreterApiNoRuntimeTest.java", "src/test/java/org/tensorflow/lite/SupportedFeatures.java", - "src/test/java/org/tensorflow/lite/TestInit.java", "src/test/java/org/tensorflow/lite/TestUtils.java", ], data = [ @@ -703,6 +718,9 @@ java_test_with_tflite( ], javacopts = JAVACOPTS, test_class = "org.tensorflow.lite.InterpreterApiNoRuntimeTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_stable_test_jni.so", ], @@ -720,7 +738,6 @@ java_test_with_tflite( srcs = [ "src/test/java/org/tensorflow/lite/NnApiDelegateNativeTest.java", "src/test/java/org/tensorflow/lite/SupportedFeatures.java", - "src/test/java/org/tensorflow/lite/TestInit.java", "src/test/java/org/tensorflow/lite/TestUtils.java", ], data = [ @@ -728,6 +745,9 @@ java_test_with_tflite( ], tags = ["no_mac"], test_class = "org.tensorflow.lite.NnApiDelegateNativeTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", ], @@ -745,7 +765,6 @@ java_test_with_tflite( size = "small", srcs = [ "src/test/java/org/tensorflow/lite/SupportedFeatures.java", - "src/test/java/org/tensorflow/lite/TestInit.java", "src/test/java/org/tensorflow/lite/TestUtils.java", "src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java", ], @@ -755,6 +774,9 @@ java_test_with_tflite( javacopts = JAVACOPTS, tags = ["no_mac"], test_class = "org.tensorflow.lite.nnapi.NnApiDelegateTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", ], @@ -797,7 +819,6 @@ java_test_with_tflite( size = "small", srcs = [ "src/test/java/org/tensorflow/lite/TensorTest.java", - "src/test/java/org/tensorflow/lite/TestInit.java", ], data = [ "src/testdata/add.bin", @@ -808,6 +829,9 @@ java_test_with_tflite( ], javacopts = JAVACOPTS, test_class = "org.tensorflow.lite.TensorTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ 
"//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", ], diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java index cf1371fb9dabba..33a5d41a9b6b03 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java @@ -32,7 +32,7 @@ public final class TensorFlowLite { // will discard those, and avoid logging messages with parameters (call String.format instead), // since the default Java log handler on Android only logs the raw message string and doesn't // apply the parameters. - private static final Logger logger = Logger.getLogger(InterpreterApi.class.getName()); + private static final Logger logger = Logger.getLogger(TensorFlowLite.class.getName()); private static final String[][] TFLITE_RUNTIME_LIBNAMES = new String[][] { diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TestInit.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TestInit.java index 3c46bc09ba3ee0..32b4ccf216c613 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TestInit.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TestInit.java @@ -14,8 +14,11 @@ ==============================================================================*/ package org.tensorflow.lite; +import java.util.logging.Logger; + /** Utilities for initializing TF Lite for tests. */ public final class TestInit { + private static final Logger logger = Logger.getLogger(TestInit.class.getName()); private TestInit() {} @@ -29,8 +32,15 @@ public static void init() { if (!initialized) { try { System.loadLibrary("tensorflowlite_test_jni"); + logger.info("Loaded native library for tests: tensorflowlite_test_jni"); } catch (UnsatisfiedLinkError e) { - System.loadLibrary("tensorflowlite_stable_test_jni"); + logger.info("Didn't load native library for tests: tensorflowlite_test_jni"); + try { + System.loadLibrary("tensorflowlite_stable_test_jni"); + logger.info("Loaded native library for tests: tensorflowlite_stable_test_jni"); + } catch (UnsatisfiedLinkError e2) { + logger.info("Didn't load native library for tests: tensorflowlite_stable_test_jni"); + } } initTfLiteForTest(); initialized = true; diff --git a/tensorflow/lite/java/src/test/native/BUILD b/tensorflow/lite/java/src/test/native/BUILD index 2c32fc618331a8..db20aafcd2d0f8 100644 --- a/tensorflow/lite/java/src/test/native/BUILD +++ b/tensorflow/lite/java/src/test/native/BUILD @@ -30,10 +30,9 @@ cc_library_with_tflite( "interpreter_test_jni.cc", "nnapi_delegate_test_jni.cc", "supported_features_jni.cc", - "test_init_jni.cc", ], tflite_deps = [ - "//tensorflow/lite/c:test_util", + ":test_init_jni", "//tensorflow/lite/delegates/nnapi/java/src/main/native", "//tensorflow/lite/java/src/main/native", "//tensorflow/lite/java/src/main/native:jni_utils", @@ -51,6 +50,22 @@ cc_library_with_tflite( alwayslink = 1, ) +cc_library_with_tflite( + name = "test_init_jni", + testonly = 1, + srcs = [ + "test_init_jni.cc", + ], + tflite_deps = [ + "//tensorflow/lite/java/src/main/native:jni_utils", + "//tensorflow/lite/c:test_util", + ], + deps = [ + "//tensorflow/lite/java/jni", + ], + alwayslink = 1, +) + # Same as "native", but excluding dependencies on experimental features. 
cc_library_with_tflite( name = "native_stable", diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index bf2d189075d294..a0eb5ce425fc2a 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -301,7 +301,7 @@ cc_library( visibility = ["//visibility:private"], deps = [ ":op_macros", - "//tensorflow/lite:arena_planner", + "//tensorflow/lite:util", "//tensorflow/lite/core/c:common", "//tensorflow/lite/kernels/internal:optimized_eigen", ], @@ -767,6 +767,7 @@ BUILTIN_KERNEL_SRCS = [ "stablehlo_gather.cc", "stablehlo_add.cc", "stablehlo_multiply.cc", + "stablehlo_pad.cc", "stablehlo_reduce_window.cc", "stablehlo_min_max.cc", "stablehlo_scatter.cc", @@ -1338,6 +1339,28 @@ cc_test( ], ) +cc_test( + name = "stablehlo_pad_test", + srcs = ["stablehlo_pad_test.cc"], + tags = ["tflite_nnapi"], + deps = [ + ":stablehlo_reduce_window_test_util", + ":test_main", + ":test_util", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/c:common", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/random", + "@com_google_absl//absl/random:bit_gen_ref", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "stablehlo_reduce_window_test_util", hdrs = ["stablehlo_reduce_window_test_util.h"], @@ -1358,7 +1381,6 @@ cc_test( cc_test( name = "stablehlo_reduce_window_test", - size = "small", srcs = ["stablehlo_reduce_window_test.cc"], tags = ["tflite_nnapi"], deps = [ diff --git a/tensorflow/lite/kernels/builtin_op_kernels.h b/tensorflow/lite/kernels/builtin_op_kernels.h index 54c4ccdd48838f..7b1a0975b0e26a 100644 --- a/tensorflow/lite/kernels/builtin_op_kernels.h +++ b/tensorflow/lite/kernels/builtin_op_kernels.h @@ -15,6 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_BUILTIN_OP_KERNELS_H_ #define TENSORFLOW_LITE_KERNELS_BUILTIN_OP_KERNELS_H_ +/// For documentation, see +/// third_party/tensorflow/lite/core/kernels/builtin_op_kernels.h + #include "tensorflow/lite/core/kernels/builtin_op_kernels.h" namespace tflite { diff --git a/tensorflow/lite/kernels/cpu_backend_gemm.h b/tensorflow/lite/kernels/cpu_backend_gemm.h index 13374c41958ef5..af91b0a6de7336 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm.h @@ -176,7 +176,7 @@ template void Gemm(const MatrixParams& lhs_params, const int8_t* lhs_data, const MatrixParams& rhs_params, const int16_t* rhs_data, const MatrixParams& dst_params, int16_t* dst_data, - const GemmParams& params, + const GemmParams& params, CpuBackendContext* context) { ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm"); ValidateParams(lhs_params, rhs_params, dst_params, params); @@ -187,7 +187,7 @@ void Gemm(const MatrixParams& lhs_params, const int8_t* lhs_data, // Currently, only Ruy backend supports 16x8 quant gemm so we use ruy // only. - detail::GemmImplUsingRuy::Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, diff --git a/tensorflow/lite/kernels/eigen_support.cc b/tensorflow/lite/kernels/eigen_support.cc index 0dc977e876cfbf..22cf62d36d14e5 100644 --- a/tensorflow/lite/kernels/eigen_support.cc +++ b/tensorflow/lite/kernels/eigen_support.cc @@ -18,11 +18,14 @@ limitations under the License. 
#include #include -#include "tensorflow/lite/arena_planner.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" #include "tensorflow/lite/kernels/op_macros.h" +#ifndef EIGEN_DONT_ALIGN +#include "tensorflow/lite/util.h" +#endif // EIGEN_DONT_ALIGN + namespace tflite { namespace eigen_support { namespace { @@ -38,12 +41,11 @@ int GetNumThreads(int num_threads) { #ifndef EIGEN_DONT_ALIGN // Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on -// hardware architecture and build configurations. -// If the static assertion fails, try to increase `kDefaultTensorAlignment` to -// in `arena_planner.h` to 32 or 64. +// hardware architecture and build configurations. If the static assertion +// fails, try to increase `kDefaultTensorAlignment` in `util.h` to 32 or 64. static_assert( kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0, - "kDefaultArenaAlignment doesn't comply with Eigen alignment requirement."); + "kDefaultTensorAlignment doesn't comply with Eigen alignment requirement."); #endif // EIGEN_DONT_ALIGN // Helper routine for updating the global Eigen thread count used for OpenMP. diff --git a/tensorflow/lite/kernels/gather_nd.cc b/tensorflow/lite/kernels/gather_nd.cc index 20224c01d86a89..10a62047375673 100644 --- a/tensorflow/lite/kernels/gather_nd.cc +++ b/tensorflow/lite/kernels/gather_nd.cc @@ -50,6 +50,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: case kTfLiteInt32: case kTfLiteString: + case kTfLiteBool: break; default: TF_LITE_KERNEL_LOG(context, @@ -157,6 +158,9 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params, case kTfLiteString: status = GatherNdString(params, indices, output); break; + case kTfLiteBool: + status = GatherNd(params, indices, output); + break; default: TF_LITE_KERNEL_LOG(context, "Params type '%s' are not supported by gather_nd.", diff --git a/tensorflow/lite/kernels/internal/optimized/neon_check.h b/tensorflow/lite/kernels/internal/optimized/neon_check.h index bbf745ce1d12c7..8fdaeef44598d0 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_check.h +++ b/tensorflow/lite/kernels/internal/optimized/neon_check.h @@ -17,12 +17,12 @@ limitations under the License. #if defined(__ARM_NEON__) || defined(__ARM_NEON) #define USE_NEON -#include +#include // IWYU pragma: export #endif #if defined __GNUC__ && defined __SSE4_1__ && !defined TF_LITE_DISABLE_X86_NEON #define USE_NEON -#include "NEON_2_SSE.h" +#include "NEON_2_SSE.h" // IWYU pragma: export #endif // NEON_OR_PORTABLE(SomeFunc, args) calls NeonSomeFunc(args) if USE_NEON is diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 5609719398fcee..3299f610697bbf 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -49,8 +49,8 @@ limitations under the License. 
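The gather_nd.cc change above adds kTfLiteBool to the supported params types. A toy reference sketch of the semantics this enables, restricted to a 2-D bool params tensor indexed by row; the helper name and the 2-D simplification are illustrative, not the kernel's actual implementation.

#include <cstdint>
#include <vector>

// Gathers whole rows of a row-major (rows x cols) bool matrix: each entry of
// row_indices selects one row of params, and the rows are copied in order.
std::vector<bool> GatherBoolRows(const std::vector<bool>& params, int cols,
                                 const std::vector<int32_t>& row_indices) {
  std::vector<bool> output;
  output.reserve(row_indices.size() * cols);
  for (int32_t row : row_indices) {
    for (int c = 0; c < cols; ++c) {
      output.push_back(params[row * cols + c]);
    }
  }
  return output;
}
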
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" #include "tensorflow/lite/kernels/internal/cppmath.h" -#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h" +#include "tensorflow/lite/kernels/internal/optimized/neon_check.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops_utils.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" @@ -310,18 +310,18 @@ inline void FullyConnected( inline void FullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, CpuBackendContext* cpu_backend_context) { + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + uint8_t* output_data, CpuBackendContext* cpu_backend_context) { ruy::profiler::ScopeLabel label("FullyConnected/8bit"); - const int32 input_offset = params.input_offset; - const int32 filter_offset = params.weights_offset; - const int32 output_offset = params.output_offset; - const int32 output_multiplier = params.output_multiplier; + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_multiplier = params.output_multiplier; const int output_shift = params.output_shift; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); // TODO(b/62193649): This really should be: @@ -341,26 +341,26 @@ inline void FullyConnected( TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows); } - cpu_backend_gemm::MatrixParams lhs_params; + cpu_backend_gemm::MatrixParams lhs_params; lhs_params.rows = filter_rows; lhs_params.cols = filter_cols; lhs_params.order = cpu_backend_gemm::Order::kRowMajor; lhs_params.zero_point = -filter_offset; lhs_params.cache_policy = cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable); - cpu_backend_gemm::MatrixParams rhs_params; + cpu_backend_gemm::MatrixParams rhs_params; rhs_params.rows = filter_cols; rhs_params.cols = batches; rhs_params.order = cpu_backend_gemm::Order::kColMajor; rhs_params.zero_point = -input_offset; rhs_params.cache_policy = cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable); - cpu_backend_gemm::MatrixParams dst_params; + cpu_backend_gemm::MatrixParams dst_params; dst_params.rows = filter_rows; dst_params.cols = batches; dst_params.order = cpu_backend_gemm::Order::kColMajor; dst_params.zero_point = output_offset; - cpu_backend_gemm::GemmParams gemm_params; + cpu_backend_gemm::GemmParams gemm_params; gemm_params.bias = bias_data; gemm_params.clamp_min = output_activation_min; gemm_params.clamp_max = output_activation_max; @@ -373,18 +373,18 @@ inline void FullyConnected( inline void FullyConnected( const FullyConnectedParams& params, const 
RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data_int32, const RuntimeShape& output_shape, - int16* output_data, CpuBackendContext* cpu_backend_context) { + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data_int32, const RuntimeShape& output_shape, + int16_t* output_data, CpuBackendContext* cpu_backend_context) { ruy::profiler::ScopeLabel label("FullyConnected/Uint8Int16"); - const int32 input_offset = params.input_offset; - const int32 filter_offset = params.weights_offset; - const int32 output_offset = params.output_offset; - const int32 output_multiplier = params.output_multiplier; + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_multiplier = params.output_multiplier; const int output_shift = params.output_shift; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_LE(output_activation_min, output_activation_max); TFLITE_DCHECK_EQ(output_offset, 0); TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); @@ -402,26 +402,26 @@ inline void FullyConnected( output_shape, output_dim_count - 1); const int accum_depth = filter_shape.Dims(filter_dim_count - 1); - cpu_backend_gemm::MatrixParams lhs_params; + cpu_backend_gemm::MatrixParams lhs_params; lhs_params.rows = output_depth; lhs_params.cols = accum_depth; lhs_params.order = cpu_backend_gemm::Order::kRowMajor; lhs_params.zero_point = -filter_offset; lhs_params.cache_policy = cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable); - cpu_backend_gemm::MatrixParams rhs_params; + cpu_backend_gemm::MatrixParams rhs_params; rhs_params.rows = accum_depth; rhs_params.cols = batches; rhs_params.order = cpu_backend_gemm::Order::kColMajor; rhs_params.zero_point = -input_offset; rhs_params.cache_policy = cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable); - cpu_backend_gemm::MatrixParams dst_params; + cpu_backend_gemm::MatrixParams dst_params; dst_params.rows = output_depth; dst_params.cols = batches; dst_params.order = cpu_backend_gemm::Order::kColMajor; dst_params.zero_point = 0; - cpu_backend_gemm::GemmParams gemm_params; + cpu_backend_gemm::GemmParams gemm_params; gemm_params.bias = bias_data_int32; gemm_params.clamp_min = output_activation_min; gemm_params.clamp_max = output_activation_max; @@ -438,12 +438,12 @@ inline void FullyConnected( // as the 'task' for worker threads to run (multi-threaded case, see // ShuffledFullyConnectedWorkerTask below). 
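The FullyConnected hunks above (and the Conv hunks further down) only migrate the MatrixParams/GemmParams declarations to explicit std::int*_t scalar types; the calling convention itself is unchanged. A minimal float sketch of that convention, assuming the templated cpu_backend_gemm::Gemm entry point these kernels already use; the wrapper function is illustrative only.

#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"

// Multiplies a row-major LHS (rows x depth) by a column-major RHS (depth x
// cols) into a column-major DST, with no bias and no activation clamping.
void GemmFloat(const float* lhs, const float* rhs, float* dst, int rows,
               int depth, int cols, tflite::CpuBackendContext* context) {
  namespace gemm = tflite::cpu_backend_gemm;
  gemm::MatrixParams<float> lhs_params;
  lhs_params.order = gemm::Order::kRowMajor;
  lhs_params.rows = rows;
  lhs_params.cols = depth;
  gemm::MatrixParams<float> rhs_params;
  rhs_params.order = gemm::Order::kColMajor;
  rhs_params.rows = depth;
  rhs_params.cols = cols;
  gemm::MatrixParams<float> dst_params;
  dst_params.order = gemm::Order::kColMajor;
  dst_params.rows = rows;
  dst_params.cols = cols;
  gemm::GemmParams<float, float> gemm_params;  // Defaults: no bias, no clamp.
  gemm::Gemm(lhs_params, lhs, rhs_params, rhs, dst_params, dst, gemm_params,
             context);
}
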
inline void ShuffledFullyConnectedWorkerImpl( - const uint8* shuffled_input_workspace_data, - const int8* shuffled_weights_data, int batches, int output_depth, - int output_stride, int accum_depth, const int32* bias_data, - int32 output_multiplier, int output_shift, int16* output_data) { + const uint8_t* shuffled_input_workspace_data, + const int8_t* shuffled_weights_data, int batches, int output_depth, + int output_stride, int accum_depth, const int32_t* bias_data, + int32_t output_multiplier, int output_shift, int16_t* output_data) { #if defined USE_NEON - const int8* shuffled_weights_ptr = shuffled_weights_data; + const int8_t* shuffled_weights_ptr = shuffled_weights_data; if (batches == 1) { const int right_shift = output_shift > 0 ? 0 : -output_shift; const int left_shift = output_shift > 0 ? output_shift : 0; @@ -515,8 +515,8 @@ inline void ShuffledFullyConnectedWorkerImpl( const int right_shift = output_shift > 0 ? 0 : -output_shift; const int left_shift = output_shift > 0 ? output_shift : 0; for (int c = 0; c < output_depth; c += 4) { - const int8* shuffled_input_ptr = - reinterpret_cast(shuffled_input_workspace_data); + const int8_t* shuffled_input_ptr = + reinterpret_cast(shuffled_input_workspace_data); // Accumulation loop. int32x4_t row_accum00 = vdupq_n_s32(0); int32x4_t row_accum10 = vdupq_n_s32(0); @@ -613,26 +613,26 @@ inline void ShuffledFullyConnectedWorkerImpl( } #else if (batches == 1) { - int16* output_ptr = output_data; + int16_t* output_ptr = output_data; // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd) - // so that just reinterpreting them as int8 values is equivalent to + // so that just reinterpreting them as int8_t values is equivalent to // subtracting 128 from them, thus implementing for free the subtraction of // the zero_point value 128. - const int8* shuffled_weights_ptr = - reinterpret_cast(shuffled_weights_data); + const int8_t* shuffled_weights_ptr = + reinterpret_cast(shuffled_weights_data); // Likewise, we preshuffled and pre-xored the input data above. - const int8* shuffled_input_data = - reinterpret_cast(shuffled_input_workspace_data); + const int8_t* shuffled_input_data = + reinterpret_cast(shuffled_input_workspace_data); for (int c = 0; c < output_depth; c += 4) { // Internal accumulation. // Initialize accumulator with the bias-value. - int32 accum[4] = {0}; + int32_t accum[4] = {0}; // Accumulation loop. for (int d = 0; d < accum_depth; d += 16) { for (int i = 0; i < 4; i++) { for (int j = 0; j < 16; j++) { - int8 input_val = shuffled_input_data[d + j]; - int8 weights_val = *shuffled_weights_ptr++; + int8_t input_val = shuffled_input_data[d + j]; + int8_t weights_val = *shuffled_weights_ptr++; accum[i] += weights_val * input_val; } } @@ -640,35 +640,35 @@ inline void ShuffledFullyConnectedWorkerImpl( for (int i = 0; i < 4; i++) { // Add bias value int acc = accum[i] + bias_data[c + i]; - // Down-scale the final int32 accumulator to the scale used by our + // Down-scale the final int32_t accumulator to the scale used by our // (16-bit, typically 3 integer bits) fixed-point format. The quantized // multiplier and shift here have been pre-computed offline // (e.g. by toco). acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift); - // Saturate, cast to int16, and store to output array. + // Saturate, cast to int16_t, and store to output array. 
acc = std::max(acc, -32768); acc = std::min(acc, 32767); output_ptr[c + i] = acc; } } } else if (batches == 4) { - int16* output_ptr = output_data; + int16_t* output_ptr = output_data; // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd) - // so that just reinterpreting them as int8 values is equivalent to + // so that just reinterpreting them as int8_t values is equivalent to // subtracting 128 from them, thus implementing for free the subtraction of // the zero_point value 128. - const int8* shuffled_weights_ptr = - reinterpret_cast(shuffled_weights_data); + const int8_t* shuffled_weights_ptr = + reinterpret_cast(shuffled_weights_data); // Likewise, we preshuffled and pre-xored the input data above. - const int8* shuffled_input_data = - reinterpret_cast(shuffled_input_workspace_data); + const int8_t* shuffled_input_data = + reinterpret_cast(shuffled_input_workspace_data); for (int c = 0; c < output_depth; c += 4) { - const int8* shuffled_input_ptr = shuffled_input_data; + const int8_t* shuffled_input_ptr = shuffled_input_data; // Accumulation loop. // Internal accumulation. // Initialize accumulator with the bias-value. - int32 accum[4][4]; + int32_t accum[4][4]; for (int i = 0; i < 4; i++) { for (int b = 0; b < 4; b++) { accum[i][b] = 0; @@ -678,8 +678,8 @@ inline void ShuffledFullyConnectedWorkerImpl( for (int i = 0; i < 4; i++) { for (int b = 0; b < 4; b++) { for (int j = 0; j < 16; j++) { - int8 input_val = shuffled_input_ptr[16 * b + j]; - int8 weights_val = shuffled_weights_ptr[16 * i + j]; + int8_t input_val = shuffled_input_ptr[16 * b + j]; + int8_t weights_val = shuffled_weights_ptr[16 * i + j]; accum[i][b] += weights_val * input_val; } } @@ -691,13 +691,13 @@ inline void ShuffledFullyConnectedWorkerImpl( for (int b = 0; b < 4; b++) { // Add bias value int acc = accum[i][b] + bias_data[c + i]; - // Down-scale the final int32 accumulator to the scale used by our + // Down-scale the final int32_t accumulator to the scale used by our // (16-bit, typically 3 integer bits) fixed-point format. The // quantized multiplier and shift here have been pre-computed offline // (e.g. by toco). acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift); - // Saturate, cast to int16, and store to output array. + // Saturate, cast to int16_t, and store to output array. acc = std::max(acc, -32768); acc = std::min(acc, 32767); output_ptr[b * output_stride + c + i] = acc; @@ -714,13 +714,13 @@ inline void ShuffledFullyConnectedWorkerImpl( // Wraps ShuffledFullyConnectedWorkerImpl into a Task class // to allow using gemmlowp's threadpool. 
struct ShuffledFullyConnectedWorkerTask : cpu_backend_threadpool::Task { - ShuffledFullyConnectedWorkerTask(const uint8* input_data, - const int8* shuffled_weights_data, + ShuffledFullyConnectedWorkerTask(const uint8_t* input_data, + const int8_t* shuffled_weights_data, int batches, int output_depth, int output_stride, int accum_depth, - const int32* bias_data, - int32 output_multiplier, int output_shift, - int16* output_data) + const int32_t* bias_data, + int32_t output_multiplier, int output_shift, + int16_t* output_data) : input_data_(input_data), shuffled_weights_data_(shuffled_weights_data), batches_(batches), @@ -739,30 +739,30 @@ struct ShuffledFullyConnectedWorkerTask : cpu_backend_threadpool::Task { output_shift_, output_data_); } - const uint8* input_data_; - const int8* shuffled_weights_data_; + const uint8_t* input_data_; + const int8_t* shuffled_weights_data_; int batches_; int output_depth_; int output_stride_; int accum_depth_; - const int32* bias_data_; - int32 output_multiplier_; + const int32_t* bias_data_; + int32_t output_multiplier_; int output_shift_; - int16* output_data_; + int16_t* output_data_; }; inline void ShuffledFullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& weights_shape, - const uint8* shuffled_weights_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - int16* output_data, uint8* shuffled_input_workspace_data, + const uint8_t* input_data, const RuntimeShape& weights_shape, + const uint8_t* shuffled_weights_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + int16_t* output_data, uint8_t* shuffled_input_workspace_data, CpuBackendContext* cpu_backend_context) { ruy::profiler::ScopeLabel label("ShuffledFullyConnected/8bit"); - const int32 output_multiplier = params.output_multiplier; + const int32_t output_multiplier = params.output_multiplier; const int output_shift = params.output_shift; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_EQ(output_activation_min, -32768); TFLITE_DCHECK_EQ(output_activation_max, 32767); TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1); @@ -782,11 +782,11 @@ inline void ShuffledFullyConnected( TFLITE_DCHECK((accum_depth % 16) == 0); TFLITE_DCHECK((output_depth % 4) == 0); // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd) - // so that just reinterpreting them as int8 values is equivalent to + // so that just reinterpreting them as int8_t values is equivalent to // subtracting 128 from them, thus implementing for free the subtraction of // the zero_point value 128. 
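The sign-bit remark in the comment above can be checked directly: XOR-ing a uint8_t with 0x80 and reinterpreting the bits as int8_t yields exactly value - 128, so the zero_point subtraction really does come for free. A standalone demonstration, illustrative and not part of the kernel:

#include <cstdint>
#include <cstdio>

int main() {
  for (int v = 0; v < 256; ++v) {
    const uint8_t flipped = static_cast<uint8_t>(v) ^ 0x80;
    const int8_t reinterpreted = static_cast<int8_t>(flipped);
    if (reinterpreted != v - 128) {
      std::printf("mismatch at %d\n", v);
      return 1;
    }
  }
  std::printf("uint8 ^ 0x80, read as int8, equals value - 128 for all 256 values\n");
  return 0;
}
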
- const int8* int8_shuffled_weights_data = - reinterpret_cast(shuffled_weights_data); + const int8_t* int8_shuffled_weights_data = + reinterpret_cast(shuffled_weights_data); // Shuffling and xoring of input activations into the workspace buffer if (batches == 1) { @@ -803,12 +803,12 @@ inline void ShuffledFullyConnected( } #endif } else if (batches == 4) { - uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data; + uint8_t* shuffled_input_workspace_ptr = shuffled_input_workspace_data; int c = 0; #ifdef USE_NEON const uint8x16_t signbit = vdupq_n_u8(0x80); for (c = 0; c < accum_depth; c += 16) { - const uint8* src_data_ptr = input_data + c; + const uint8_t* src_data_ptr = input_data + c; uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth); uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth); uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth); @@ -826,13 +826,13 @@ inline void ShuffledFullyConnected( #else for (c = 0; c < accum_depth; c += 16) { for (int b = 0; b < 4; b++) { - const uint8* src_data_ptr = input_data + b * accum_depth + c; + const uint8_t* src_data_ptr = input_data + b * accum_depth + c; for (int j = 0; j < 16; j++) { - uint8 src_val = *src_data_ptr++; + uint8_t src_val = *src_data_ptr++; // Flip the sign bit, so that the kernel will only need to - // reinterpret these uint8 values as int8, getting for free the + // reinterpret these uint8_t values as int8_t, getting for free the // subtraction of the zero_point value 128. - uint8 dst_val = src_val ^ 0x80; + uint8_t dst_val = src_val ^ 0x80; *shuffled_input_workspace_ptr++ = dst_val; } } @@ -930,7 +930,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, ruy::profiler::ScopeLabel label("Conv"); // NB: the float 0.0f value is represented by all zero bytes. 
- const uint8 float_zero_byte = 0x00; + const uint8_t float_zero_byte = 0x00; const float* gemm_input_data = nullptr; const RuntimeShape* gemm_input_shape = nullptr; const int filter_width = filter_shape.Dims(2); @@ -1117,7 +1117,7 @@ inline void HybridConvPerChannel( TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); - const int8* gemm_input_data = nullptr; + const int8_t* gemm_input_data = nullptr; const RuntimeShape* gemm_input_shape = nullptr; const int filter_width = filter_shape.Dims(2); const int filter_height = filter_shape.Dims(1); @@ -1168,17 +1168,17 @@ inline void HybridConvPerChannel( } } - cpu_backend_gemm::MatrixParams lhs_params; + cpu_backend_gemm::MatrixParams lhs_params; lhs_params.rows = filter_rows; lhs_params.cols = filter_cols; lhs_params.order = cpu_backend_gemm::Order::kRowMajor; - cpu_backend_gemm::MatrixParams rhs_params; + cpu_backend_gemm::MatrixParams rhs_params; rhs_params.order = cpu_backend_gemm::Order::kColMajor; rhs_params.rows = gemm_input_rows; rhs_params.cols = gemm_input_cols; - cpu_backend_gemm::MatrixParams dst_params; + cpu_backend_gemm::MatrixParams dst_params; dst_params.order = cpu_backend_gemm::Order::kColMajor; dst_params.rows = output_rows; dst_params.cols = output_cols; @@ -1210,29 +1210,29 @@ inline void HybridConvPerChannel( } inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, const RuntimeShape& im2col_shape, - uint8* im2col_data, CpuBackendContext* cpu_backend_context) { + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + uint8_t* output_data, const RuntimeShape& im2col_shape, + uint8_t* im2col_data, CpuBackendContext* cpu_backend_context) { ruy::profiler::ScopeLabel label("Conv/8bit"); const int stride_width = params.stride_width; const int stride_height = params.stride_height; const int dilation_width_factor = params.dilation_width_factor; const int dilation_height_factor = params.dilation_height_factor; - const int32 input_offset = params.input_offset; - const int32 filter_offset = params.weights_offset; - const int32 output_offset = params.output_offset; - const int32 output_multiplier = params.output_multiplier; + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_multiplier = params.output_multiplier; const int output_shift = params.output_shift; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); - const uint8* gemm_input_data = nullptr; + const uint8_t* gemm_input_data = nullptr; const RuntimeShape* gemm_input_shape = nullptr; const int filter_width = filter_shape.Dims(2); const int filter_height = filter_shape.Dims(1); @@ -1287,22 +1287,22 @@ inline void Conv(const ConvParams& params, const 
RuntimeShape& input_shape, TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows); TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows); - cpu_backend_gemm::MatrixParams lhs_params; + cpu_backend_gemm::MatrixParams lhs_params; lhs_params.rows = filter_rows; lhs_params.cols = filter_cols; lhs_params.order = cpu_backend_gemm::Order::kRowMajor; lhs_params.zero_point = -filter_offset; - cpu_backend_gemm::MatrixParams rhs_params; + cpu_backend_gemm::MatrixParams rhs_params; rhs_params.rows = gemm_input_rows; rhs_params.cols = gemm_input_cols; rhs_params.order = cpu_backend_gemm::Order::kColMajor; rhs_params.zero_point = -input_offset; - cpu_backend_gemm::MatrixParams dst_params; + cpu_backend_gemm::MatrixParams dst_params; dst_params.rows = output_rows; dst_params.cols = output_cols; dst_params.order = cpu_backend_gemm::Order::kColMajor; dst_params.zero_point = output_offset; - cpu_backend_gemm::GemmParams gemm_params; + cpu_backend_gemm::GemmParams gemm_params; gemm_params.bias = bias_data; gemm_params.clamp_min = output_activation_min; gemm_params.clamp_max = output_activation_max; @@ -1433,37 +1433,37 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params, inline void L2Normalization(const tflite::L2NormalizationParams& op_params, const RuntimeShape& input_shape, - const uint8* input_data, + const uint8_t* input_data, const RuntimeShape& output_shape, - uint8* output_data) { + uint8_t* output_data) { ruy::profiler::ScopeLabel label("L2Normalization/8bit"); const int trailing_dim = input_shape.DimensionsCount() - 1; const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); - const int32 input_zero_point = op_params.input_zero_point; + const int32_t input_zero_point = op_params.input_zero_point; for (int i = 0; i < outer_size; ++i) { - int32 square_l2_norm = 0; + int32_t square_l2_norm = 0; for (int c = 0; c < depth; c++) { // Note that input_data advances by depth in the second pass below. - int32 diff = input_data[c] - input_zero_point; + int32_t diff = input_data[c] - input_zero_point; square_l2_norm += diff * diff; } // TODO(b/29395854): add clamping to TOCO and TF Lite kernel // for all zero tensors in the input_data - int32 inv_l2norm_multiplier; + int32_t inv_l2norm_multiplier; int inv_l2norm_shift; GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift, &inv_l2norm_multiplier, &inv_l2norm_shift); for (int c = 0; c < depth; c++) { - int32 diff = *input_data - input_zero_point; - int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( + int32_t diff = *input_data - input_zero_point; + int32_t rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); - int32 unclamped_output_val = 128 + rescaled_diff; - int32 output_val = std::min(255, std::max(0, unclamped_output_val)); - *output_data = static_cast(output_val); + int32_t unclamped_output_val = 128 + rescaled_diff; + int32_t output_val = std::min(255, std::max(0, unclamped_output_val)); + *output_data = static_cast(output_val); ++input_data; ++output_data; } @@ -1534,8 +1534,8 @@ inline void Add(const ArithmeticParams& params, // Element-wise add that can often be used for inner loop of broadcast add as // well as the non-broadcast add. 
inline void AddElementwise(int size, const ArithmeticParams& params, - const uint8* input1_data, const uint8* input2_data, - uint8* output_data) { + const uint8_t* input1_data, + const uint8_t* input2_data, uint8_t* output_data) { ruy::profiler::ScopeLabel label("AddElementwise/8bit"); int i = 0; TFLITE_DCHECK_GT(params.input1_offset, -256); @@ -1600,25 +1600,25 @@ inline void AddElementwise(int size, const ArithmeticParams& params, #endif // NEON for (; i < size; ++i) { - const int32 input1_val = params.input1_offset + input1_data[i]; - const int32 input2_val = params.input2_offset + input2_data[i]; - const int32 shifted_input1_val = input1_val * (1 << params.left_shift); - const int32 shifted_input2_val = input2_val * (1 << params.left_shift); - const int32 scaled_input1_val = + const int32_t input1_val = params.input1_offset + input1_data[i]; + const int32_t input2_val = params.input2_offset + input2_data[i]; + const int32_t shifted_input1_val = input1_val * (1 << params.left_shift); + const int32_t shifted_input2_val = input2_val * (1 << params.left_shift); + const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( shifted_input1_val, params.input1_multiplier, params.input1_shift); - const int32 scaled_input2_val = + const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( shifted_input2_val, params.input2_multiplier, params.input2_shift); - const int32 raw_sum = scaled_input1_val + scaled_input2_val; - const int32 raw_output = + const int32_t raw_sum = scaled_input1_val + scaled_input2_val; + const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp( raw_sum, params.output_multiplier, params.output_shift) + params.output_offset; - const int32 clamped_output = + const int32_t clamped_output = std::min(params.quantized_activation_max, std::max(params.quantized_activation_min, raw_output)); - output_data[i] = static_cast(clamped_output); + output_data[i] = static_cast(clamped_output); } } @@ -1626,8 +1626,8 @@ inline void AddElementwise(int size, const ArithmeticParams& params, // broadcast add, so that, for example, scalar-broadcast with batch will still // be fast. inline void AddScalarBroadcast(int size, const ArithmeticParams& params, - uint8 input1_data, const uint8* input2_data, - uint8* output_data) { + uint8_t input1_data, const uint8_t* input2_data, + uint8_t* output_data) { using gemmlowp::RoundingDivideByPOT; ruy::profiler::ScopeLabel label("AddScalarBroadcast/8bit"); @@ -1699,28 +1699,28 @@ inline void AddScalarBroadcast(int size, const ArithmeticParams& params, if (i < size) { // Process broadcast scalar. 
- const int32 input1_val = params.input1_offset + input1_data; - const int32 shifted_input1_val = input1_val * (1 << params.left_shift); - const int32 scaled_input1_val = + const int32_t input1_val = params.input1_offset + input1_data; + const int32_t shifted_input1_val = input1_val * (1 << params.left_shift); + const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( shifted_input1_val, params.input1_multiplier, params.input1_shift); for (; i < size; ++i) { - const int32 input2_val = params.input2_offset + input2_data[i]; - const int32 shifted_input2_val = input2_val * (1 << params.left_shift); - const int32 scaled_input2_val = + const int32_t input2_val = params.input2_offset + input2_data[i]; + const int32_t shifted_input2_val = input2_val * (1 << params.left_shift); + const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( shifted_input2_val, params.input2_multiplier, params.input2_shift); - const int32 raw_sum = scaled_input1_val + scaled_input2_val; - const int32 raw_output = + const int32_t raw_sum = scaled_input1_val + scaled_input2_val; + const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp( raw_sum, params.output_multiplier, params.output_shift) + params.output_offset; - const int32 clamped_output = + const int32_t clamped_output = std::min(params.quantized_activation_max, std::max(params.quantized_activation_min, raw_output)); - output_data[i] = static_cast(clamped_output); + output_data[i] = static_cast(clamped_output); } } } @@ -1759,9 +1759,9 @@ inline void AddScalarBroadcast(int size, const ArithmeticParams& params, } inline void Add(const ArithmeticParams& params, - const RuntimeShape& input1_shape, const uint8* input1_data, - const RuntimeShape& input2_shape, const uint8* input2_data, - const RuntimeShape& output_shape, uint8* output_data) { + const RuntimeShape& input1_shape, const uint8_t* input1_data, + const RuntimeShape& input2_shape, const uint8_t* input2_data, + const RuntimeShape& output_shape, uint8_t* output_data) { TFLITE_DCHECK_LE(params.quantized_activation_min, params.quantized_activation_max); ruy::profiler::ScopeLabel label("Add/8bit"); @@ -1776,9 +1776,9 @@ inline void Add(const ArithmeticParams& params, } inline void Add(const ArithmeticParams& params, - const RuntimeShape& input1_shape, const int16* input1_data, - const RuntimeShape& input2_shape, const int16* input2_data, - const RuntimeShape& output_shape, int16* output_data) { + const RuntimeShape& input1_shape, const int16_t* input1_data, + const RuntimeShape& input2_shape, const int16_t* input2_data, + const RuntimeShape& output_shape, int16_t* output_data) { ruy::profiler::ScopeLabel label("Add/Int16"); TFLITE_DCHECK_LE(params.quantized_activation_min, params.quantized_activation_max); @@ -1786,14 +1786,15 @@ inline void Add(const ArithmeticParams& params, const int input1_shift = params.input1_shift; const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape); - const int16 output_activation_min = params.quantized_activation_min; - const int16 output_activation_max = params.quantized_activation_max; + const int16_t output_activation_min = params.quantized_activation_min; + const int16_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0); TFLITE_DCHECK_LE(input1_shift, 0); TFLITE_DCHECK_LE(params.input2_shift, 0); - const int16* not_shift_input = input1_shift == 0 ? 
input1_data : input2_data; - const int16* shift_input = input1_shift == 0 ? input2_data : input1_data; + const int16_t* not_shift_input = + input1_shift == 0 ? input1_data : input2_data; + const int16_t* shift_input = input1_shift == 0 ? input2_data : input1_data; const int input_right_shift = input1_shift == 0 ? -params.input2_shift : -input1_shift; @@ -1805,8 +1806,8 @@ inline void Add(const ArithmeticParams& params, F0 scaled_input = F0::FromRaw( gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift)); F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled); - const int16 raw_output = result.raw(); - const int16 clamped_output = std::min( + const int16_t raw_output = result.raw(); + const int16_t clamped_output = std::min( output_activation_max, std::max(output_activation_min, raw_output)); output_data[i] = clamped_output; } @@ -1867,11 +1868,11 @@ inline void BroadcastAddDispatch( inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params, const RuntimeShape& unswitched_input1_shape, - const uint8* unswitched_input1_data, + const uint8_t* unswitched_input1_data, const RuntimeShape& unswitched_input2_shape, - const uint8* unswitched_input2_data, + const uint8_t* unswitched_input2_data, const RuntimeShape& output_shape, - uint8* output_data) { + uint8_t* output_data) { BroadcastAddDispatch(unswitched_params, unswitched_input1_shape, unswitched_input1_data, unswitched_input2_shape, unswitched_input2_data, output_shape, output_data); @@ -1946,6 +1947,63 @@ inline void MulElementwise(int size, const ArithmeticParams& params, } } +inline void MulElementwise(int32_t n, const ArithmeticParams& params, + const int32_t* __restrict lhs, + const int32_t* __restrict rhs, + int32_t* __restrict out) { + const int32_t activation_min_val = params.quantized_activation_min; + const int32_t activation_max_val = params.quantized_activation_max; + + int32_t i = 0; + +#ifdef USE_NEON + const int32x4_t activation_min = vdupq_n_s32(activation_min_val); + const int32x4_t activation_max = vdupq_n_s32(activation_max_val); + + // Ewise Mul 16 elements at a time using 4 4-wide vector registers per loop. + for (; i <= n - 16; i += 16) { + // Load. + const int32x4_t lhs_reg = vld1q_s32(lhs + i); + const int32x4_t lhs_reg2 = vld1q_s32(lhs + i + 4); + const int32x4_t lhs_reg3 = vld1q_s32(lhs + i + 8); + const int32x4_t lhs_reg4 = vld1q_s32(lhs + i + 12); + + const int32x4_t rhs_reg = vld1q_s32(rhs + i); + const int32x4_t rhs_reg2 = vld1q_s32(rhs + i + 4); + const int32x4_t rhs_reg3 = vld1q_s32(rhs + i + 8); + const int32x4_t rhs_reg4 = vld1q_s32(rhs + i + 12); + + // Multiply. + const int32x4_t mul_reg = vmulq_s32(lhs_reg, rhs_reg); + const int32x4_t mul_reg2 = vmulq_s32(lhs_reg2, rhs_reg2); + const int32x4_t mul_reg3 = vmulq_s32(lhs_reg3, rhs_reg3); + const int32x4_t mul_reg4 = vmulq_s32(lhs_reg4, rhs_reg4); + + // Apply activation. + const int32x4_t max_reg = vminq_s32(activation_max, mul_reg); + const int32x4_t max_reg2 = vminq_s32(activation_max, mul_reg2); + const int32x4_t max_reg3 = vminq_s32(activation_max, mul_reg3); + const int32x4_t max_reg4 = vminq_s32(activation_max, mul_reg4); + const int32x4_t min_reg = vmaxq_s32(activation_min, max_reg); + const int32x4_t min_reg2 = vmaxq_s32(activation_min, max_reg2); + const int32x4_t min_reg3 = vmaxq_s32(activation_min, max_reg3); + const int32x4_t min_reg4 = vmaxq_s32(activation_min, max_reg4); + + // Store. 
+ vst1q_s32(out + i, min_reg); + vst1q_s32(out + i + 4, min_reg2); + vst1q_s32(out + i + 8, min_reg3); + vst1q_s32(out + i + 12, min_reg4); + } +#endif + + // This will handle leftovers when n is not aligned to 4 elements. + for (; i < n; ++i) { + out[i] = ActivationFunctionWithMinMax(lhs[i] * rhs[i], activation_min_val, + activation_max_val); + } +} + inline void Mul(const ArithmeticParams& params, const RuntimeShape& input1_shape, const float* input1_data, const RuntimeShape& input2_shape, const float* input2_data, @@ -1958,30 +2016,25 @@ inline void Mul(const ArithmeticParams& params, } inline void Mul(const ArithmeticParams& params, - const RuntimeShape& input1_shape, const int32* input1_data, - const RuntimeShape& input2_shape, const int32* input2_data, - const RuntimeShape& output_shape, int32* output_data) { - ruy::profiler::ScopeLabel label("Mul/int32/activation"); + const RuntimeShape& input1_shape, const int32_t* input1_data, + const RuntimeShape& input2_shape, const int32_t* input2_data, + const RuntimeShape& output_shape, int32_t* output_data) { + ruy::profiler::ScopeLabel label("Mul/int32_t/activation"); const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape); - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; - for (int i = 0; i < flat_size; ++i) { - output_data[i] = ActivationFunctionWithMinMax( - input1_data[i] * input2_data[i], output_activation_min, - output_activation_max); - } + + MulElementwise(flat_size, params, input1_data, input2_data, output_data); } inline void MulNoActivation(const ArithmeticParams& params, const RuntimeShape& input1_shape, - const int32* input1_data, + const int32_t* input1_data, const RuntimeShape& input2_shape, - const int32* input2_data, + const int32_t* input2_data, const RuntimeShape& output_shape, - int32* output_data) { - ruy::profiler::ScopeLabel label("Mul/int32"); + int32_t* output_data) { + ruy::profiler::ScopeLabel label("Mul/int32_t"); auto input1_map = MapAsVector(input1_data, input1_shape); auto input2_map = MapAsVector(input2_data, input2_shape); @@ -2002,9 +2055,9 @@ inline void MulNoActivation(const ArithmeticParams& params, } inline void Mul(const ArithmeticParams& params, - const RuntimeShape& input1_shape, const int16* input1_data, - const RuntimeShape& input2_shape, const int16* input2_data, - const RuntimeShape& output_shape, int16* output_data) { + const RuntimeShape& input1_shape, const int16_t* input1_data, + const RuntimeShape& input2_shape, const int16_t* input2_data, + const RuntimeShape& output_shape, int16_t* output_data) { ruy::profiler::ScopeLabel label("Mul/Int16/NoActivation"); // This is a copy of the reference implementation. We do not currently have a // properly optimized version. @@ -2023,15 +2076,15 @@ inline void Mul(const ArithmeticParams& params, } inline void Mul(const ArithmeticParams& params, - const RuntimeShape& input1_shape, const int16* input1_data, - const RuntimeShape& input2_shape, const int16* input2_data, - const RuntimeShape& output_shape, uint8* output_data) { + const RuntimeShape& input1_shape, const int16_t* input1_data, + const RuntimeShape& input2_shape, const int16_t* input2_data, + const RuntimeShape& output_shape, uint8_t* output_data) { ruy::profiler::ScopeLabel label("Mul/Int16Uint8"); // This is a copy of the reference implementation. We do not currently have a // properly optimized version. 
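+// [Illustrative sketch] The int32_t MulElementwise added above computes, for
+// every element, clamp(lhs[i] * rhs[i], quantized_activation_min,
+// quantized_activation_max); the NEON block handles 16 lanes per iteration and
+// the trailing scalar loop covers the remainder. A self-contained scalar
+// reference of that per-element computation (handy for checking the vector
+// path; the function name here is made up for illustration) could look like:
+//
+//   #include <algorithm>
+//   #include <cstdint>
+//
+//   inline void ClampedMulInt32Reference(int n, int32_t act_min,
+//                                        int32_t act_max, const int32_t* lhs,
+//                                        const int32_t* rhs, int32_t* out) {
+//     for (int i = 0; i < n; ++i) {
+//       out[i] = std::min(act_max, std::max(act_min, lhs[i] * rhs[i]));
+//     }
+//   }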
- const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; - const int32 output_offset = params.output_offset; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + const int32_t output_offset = params.output_offset; TFLITE_DCHECK_LE(output_activation_min, output_activation_max); const int flat_size = @@ -2043,12 +2096,12 @@ inline void Mul(const ArithmeticParams& params, F0 unclamped_result = F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]); - int16 rescaled_result = + int16_t rescaled_result = gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8); - int16 clamped_result = - std::min(output_activation_max - output_offset, rescaled_result); - clamped_result = - std::max(output_activation_min - output_offset, clamped_result); + int16_t clamped_result = std::min( + output_activation_max - output_offset, rescaled_result); + clamped_result = std::max(output_activation_min - output_offset, + clamped_result); output_data[i] = output_offset + clamped_result; } } @@ -2056,8 +2109,8 @@ inline void Mul(const ArithmeticParams& params, // Element-wise mul that can often be used for inner loop of broadcast Mul as // well as the non-broadcast Mul. inline void MulElementwise(int size, const ArithmeticParams& params, - const uint8* input1_data, const uint8* input2_data, - uint8* output_data) { + const uint8_t* input1_data, + const uint8_t* input2_data, uint8_t* output_data) { int i = 0; TFLITE_DCHECK_GT(params.input1_offset, -256); TFLITE_DCHECK_LT(params.input1_offset, 256); @@ -2115,25 +2168,26 @@ inline void MulElementwise(int size, const ArithmeticParams& params, #endif // NEON for (; i < size; ++i) { - const int32 input1_val = params.input1_offset + input1_data[i]; - const int32 input2_val = params.input2_offset + input2_data[i]; - const int32 unclamped_result = + const int32_t input1_val = params.input1_offset + input1_data[i]; + const int32_t input2_val = params.input2_offset + input2_data[i]; + const int32_t unclamped_result = params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val, params.output_multiplier, params.output_shift); - const int32 clamped_output = + const int32_t clamped_output = std::min(params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result)); - output_data[i] = static_cast(clamped_output); + output_data[i] = static_cast(clamped_output); } } // Broadcast mul that can often be used for inner loop of broadcast Mul. 
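+// [Illustrative sketch] The scalar tail of the uint8_t MulElementwise above
+// requantizes each product with MultiplyByQuantizedMultiplier and then adds
+// output_offset. Assuming the usual TFLite convention that output_multiplier
+// is a Q31 significand paired with a power-of-two output_shift, a rough
+// floating-point reference of that step (not the fixed-point implementation
+// itself) is:
+//
+//   #include <algorithm>
+//   #include <cmath>
+//   #include <cstdint>
+//
+//   inline uint8_t RequantizeProductApprox(int32_t in1_centered,
+//                                          int32_t in2_centered,
+//                                          int32_t output_multiplier,
+//                                          int output_shift,
+//                                          int32_t output_offset,
+//                                          int32_t qmin, int32_t qmax) {
+//     // real_scale ~= output_multiplier * 2^(output_shift - 31)
+//     const double real_scale =
+//         output_multiplier * std::ldexp(1.0, output_shift - 31);
+//     const int32_t unclamped =
+//         output_offset + static_cast<int32_t>(std::lround(
+//                             in1_centered * in2_centered * real_scale));
+//     return static_cast<uint8_t>(std::min(qmax, std::max(qmin, unclamped)));
+//   }
+//
+// The fixed-point code above computes the same value without floats, up to
+// rounding differences.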
inline void MulSimpleBroadcast(int size, const ArithmeticParams& params, - const uint8 broadcast_value, - const uint8* input2_data, uint8* output_data) { - const int16 input1_val = params.input1_offset + broadcast_value; + const uint8_t broadcast_value, + const uint8_t* input2_data, + uint8_t* output_data) { + const int16_t input1_val = params.input1_offset + broadcast_value; int i = 0; TFLITE_DCHECK_GT(params.input1_offset, -256); @@ -2185,16 +2239,16 @@ inline void MulSimpleBroadcast(int size, const ArithmeticParams& params, #endif // NEON for (; i < size; ++i) { - const int32 input2_val = params.input2_offset + input2_data[i]; - const int32 unclamped_result = + const int32_t input2_val = params.input2_offset + input2_data[i]; + const int32_t unclamped_result = params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val, params.output_multiplier, params.output_shift); - const int32 clamped_output = + const int32_t clamped_output = std::min(params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result)); - output_data[i] = static_cast(clamped_output); + output_data[i] = static_cast(clamped_output); } } @@ -2232,9 +2286,9 @@ inline void MulSimpleBroadcast(int size, const ArithmeticParams& params, } inline void Mul(const ArithmeticParams& params, - const RuntimeShape& input1_shape, const uint8* input1_data, - const RuntimeShape& input2_shape, const uint8* input2_data, - const RuntimeShape& output_shape, uint8* output_data) { + const RuntimeShape& input1_shape, const uint8_t* input1_data, + const RuntimeShape& input2_shape, const uint8_t* input2_data, + const RuntimeShape& output_shape, uint8_t* output_data) { TFLITE_DCHECK_LE(params.quantized_activation_min, params.quantized_activation_max); ruy::profiler::ScopeLabel label("Mul/8bit"); @@ -2265,11 +2319,11 @@ inline void BroadcastMulDispatch( inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, const RuntimeShape& unswitched_input1_shape, - const uint8* unswitched_input1_data, + const uint8_t* unswitched_input1_data, const RuntimeShape& unswitched_input2_shape, - const uint8* unswitched_input2_data, + const uint8_t* unswitched_input2_data, const RuntimeShape& output_shape, - uint8* output_data) { + uint8_t* output_data) { BroadcastMulDispatch(unswitched_params, unswitched_input1_shape, unswitched_input1_data, unswitched_input2_shape, unswitched_input2_data, output_shape, output_data); @@ -2347,11 +2401,11 @@ void BroadcastDivSlow(const ArithmeticParams& params, template inline void BroadcastDivSlow(const ArithmeticParams& params, const RuntimeShape& unextended_input1_shape, - const uint8* input1_data, + const uint8_t* input1_data, const RuntimeShape& unextended_input2_shape, - const uint8* input2_data, + const uint8_t* input2_data, const RuntimeShape& unextended_output_shape, - uint8* output_data) { + uint8_t* output_data) { TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N); TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N); TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N); @@ -2372,9 +2426,9 @@ inline void BroadcastDivSlow(const ArithmeticParams& params, TFLITE_DCHECK_LT(params.output_offset, 256); auto div_func = [&](int indexes[N]) { - int32 input1_val = + int32_t input1_val = params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)]; - int32 input2_val = + int32_t input2_val = params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)]; TFLITE_DCHECK_NE(input2_val, 0); if (input2_val < 0) { @@ -2384,20 
+2438,21 @@ inline void BroadcastDivSlow(const ArithmeticParams& params, input2_val = -input2_val; } int recip_shift; - const int32 input2_inv = GetReciprocal(input2_val, 31, &recip_shift); + const int32_t input2_inv = GetReciprocal(input2_val, 31, &recip_shift); const int headroom = CountLeadingSignBits(input1_val); - const int32 unscaled_quotient = MultiplyByQuantizedMultiplierGreaterThanOne( - input1_val, input2_inv, headroom); + const int32_t unscaled_quotient = + MultiplyByQuantizedMultiplierGreaterThanOne(input1_val, input2_inv, + headroom); const int total_shift = params.output_shift - recip_shift - headroom; - const int32 unclamped_result = + const int32_t unclamped_result = params.output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp( unscaled_quotient, params.output_multiplier, total_shift); - const int32 clamped_output = + const int32_t clamped_output = std::min(params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result)); output_data[SubscriptToIndex(output_desc, indexes)] = - static_cast(clamped_output); + static_cast(clamped_output); }; NDOpsHelper(output_desc, div_func); } @@ -2578,25 +2633,25 @@ inline void LstmCell( template inline void LstmCell( const LstmCellParams& params, const RuntimeShape& unextended_input_shape, - const uint8* input_data_uint8, + const uint8_t* input_data_uint8, const RuntimeShape& unextended_prev_activ_shape, - const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape, - const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape, - const int32* bias_data_int32, + const uint8_t* prev_activ_data_uint8, const RuntimeShape& weights_shape, + const uint8_t* weights_data_uint8, + const RuntimeShape& unextended_bias_shape, const int32_t* bias_data_int32, const RuntimeShape& unextended_prev_state_shape, - const int16* prev_state_data_int16, + const int16_t* prev_state_data_int16, const RuntimeShape& unextended_output_state_shape, - int16* output_state_data_int16, + int16_t* output_state_data_int16, const RuntimeShape& unextended_output_activ_shape, - uint8* output_activ_data_uint8, + uint8_t* output_activ_data_uint8, const RuntimeShape& unextended_concat_temp_shape, - uint8* concat_temp_data_uint8, + uint8_t* concat_temp_data_uint8, const RuntimeShape& unextended_activ_temp_shape, - int16* activ_temp_data_int16, CpuBackendContext* cpu_backend_context) { + int16_t* activ_temp_data_int16, CpuBackendContext* cpu_backend_context) { ruy::profiler::ScopeLabel label( "LstmCell/quantized (8bit external, 16bit internal)"); - int32 weights_zero_point = params.weights_zero_point; - int32 accum_multiplier = params.accum_multiplier; + int32_t weights_zero_point = params.weights_zero_point; + int32_t accum_multiplier = params.accum_multiplier; int accum_shift = params.accum_shift; TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4); @@ -2651,8 +2706,8 @@ inline void LstmCell( TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth); // Depth-concatenate prev_activ and input data together. 
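+// [Illustrative sketch] The quantized div_func further up replaces the
+// per-element division by input2_val with a multiplication by a precomputed
+// reciprocal (GetReciprocal) plus shift bookkeeping. Stripped of the headroom
+// handling and output rescaling, the underlying trick looks like this
+// self-contained helper (it assumes d > 0, as the code above flips signs
+// first, and the result may differ from exact division by one unit):
+//
+//   #include <cstdint>
+//
+//   inline int32_t DivByReciprocalApprox(int32_t n, int32_t d) {
+//     const int64_t recip =
+//         ((int64_t{1} << 31) + d / 2) / d;  // ~ round(2^31 / d)
+//     return static_cast<int32_t>((static_cast<int64_t>(n) * recip) >> 31);
+//   }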
- uint8 const* concat_input_arrays_data[2] = {input_data_uint8, - prev_activ_data_uint8}; + uint8_t const* concat_input_arrays_data[2] = {input_data_uint8, + prev_activ_data_uint8}; const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape, &prev_activ_shape}; tflite::ConcatenationParams concat_params; @@ -2667,22 +2722,22 @@ inline void LstmCell( // integers, and the output is 16-bit fixed-point with 3 integer bits so // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that // is explained in the function comment above. - cpu_backend_gemm::MatrixParams lhs_params; + cpu_backend_gemm::MatrixParams lhs_params; lhs_params.rows = fc_output_depth; lhs_params.cols = fc_accum_depth; lhs_params.order = cpu_backend_gemm::Order::kRowMajor; lhs_params.zero_point = weights_zero_point; - cpu_backend_gemm::MatrixParams rhs_params; + cpu_backend_gemm::MatrixParams rhs_params; rhs_params.rows = fc_accum_depth; rhs_params.cols = fc_batches; rhs_params.order = cpu_backend_gemm::Order::kColMajor; rhs_params.zero_point = 128; - cpu_backend_gemm::MatrixParams dst_params; + cpu_backend_gemm::MatrixParams dst_params; dst_params.rows = fc_output_depth; dst_params.cols = fc_batches; dst_params.order = cpu_backend_gemm::Order::kColMajor; dst_params.zero_point = 0; - cpu_backend_gemm::GemmParams gemm_params; + cpu_backend_gemm::GemmParams gemm_params; gemm_params.bias = bias_data_int32; gemm_params.multiplier_fixedpoint = accum_multiplier; gemm_params.multiplier_exponent = accum_shift; @@ -2692,21 +2747,23 @@ inline void LstmCell( // Rest of the LSTM cell: tanh and logistic math functions, and some adds // and muls, all done in 16-bit fixed-point. - const int16* input_gate_input_ptr = activ_temp_data_int16; - const int16* input_modulation_gate_input_ptr = + const int16_t* input_gate_input_ptr = activ_temp_data_int16; + const int16_t* input_modulation_gate_input_ptr = activ_temp_data_int16 + output_depth; - const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth; - const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth; - const int16* prev_state_ptr = prev_state_data_int16; - int16* output_state_data_ptr = output_state_data_int16; - uint8* output_activ_data_ptr = output_activ_data_uint8; + const int16_t* forget_gate_input_ptr = + activ_temp_data_int16 + 2 * output_depth; + const int16_t* output_gate_input_ptr = + activ_temp_data_int16 + 3 * output_depth; + const int16_t* prev_state_ptr = prev_state_data_int16; + int16_t* output_state_data_ptr = output_state_data_int16; + uint8_t* output_activ_data_ptr = output_activ_data_uint8; for (int b = 0; b < outer_size; ++b) { int c = 0; #ifdef GEMMLOWP_NEON for (; c <= output_depth - 8; c += 8) { // Define the fixed-point data types that we will use here. All use - // int16 as the underlying integer type i.e. all are 16-bit fixed-point. + // int16_t as the underlying integer type i.e. all are 16-bit fixed-point. // They only differ by the number of integral vs. fractional bits, // determining the range of values that they can represent. // @@ -2780,7 +2837,7 @@ inline void LstmCell( #endif for (; c < output_depth; ++c) { // Define the fixed-point data types that we will use here. All use - // int16 as the underlying integer type i.e. all are 16-bit fixed-point. + // int16_t as the underlying integer type i.e. all are 16-bit fixed-point. // They only differ by the number of integral vs. fractional bits, // determining the range of values that they can represent. 
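+// [Illustrative note] The "16-bit fixed-point with 3 integer bits" format used
+// for the gate pre-activations above maps a raw int16_t value r to the real
+// number r * 2^-12 (1 sign bit, 3 integer bits, 12 fractional bits), so the
+// representable range is [-8, +8) and, for example, raw 4096 stands for 1.0.
+// A tiny pair of conversion helpers, just to make the encoding concrete (the
+// real code uses gemmlowp::FixedPoint, which also handles saturation):
+//
+//   #include <cstdint>
+//
+//   inline float FixedPoint3IntBitsToFloat(int16_t raw) {
+//     return raw / 4096.0f;
+//   }
+//   inline int16_t FloatToFixedPoint3IntBits(float x) {
+//     return static_cast<int16_t>(x * 4096.0f);  // no saturation here
+//   }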
// @@ -2837,10 +2894,10 @@ inline void LstmCell( *output_state_data_ptr++ = new_state.raw(); // Down-scale the output activations to 8-bit integers, saturating, // and store back to memory. - int16 rescaled_output_activ = + int16_t rescaled_output_activ = gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8); - int16 clamped_output_activ = - std::max(-128, std::min(127, rescaled_output_activ)); + int16_t clamped_output_activ = std::max( + -128, std::min(127, rescaled_output_activ)); *output_activ_data_ptr++ = 128 + clamped_output_activ; } input_gate_input_ptr += 3 * output_depth; @@ -2923,8 +2980,9 @@ inline bool AveragePool(const PoolParams& params, inline bool AveragePool(const PoolParams& params, const RuntimeShape& input_shape, - const uint8* input_data, - const RuntimeShape& output_shape, uint8* output_data) { + const uint8_t* input_data, + const RuntimeShape& output_shape, + uint8_t* output_data) { ruy::profiler::ScopeLabel label("AveragePool/8bit"); // Here, and in other pooling ops, in order to maintain locality of reference, @@ -2947,7 +3005,7 @@ inline bool AveragePool(const PoolParams& params, const int stride_height = params.stride_height; const int stride_width = params.stride_width; - uint32 acc[kPoolingAccTrancheSize]; + uint32_t acc[kPoolingAccTrancheSize]; for (int batch = 0; batch < batches; ++batch) { // We proceed through the depth in tranches (see comment above). The // depth_base is the depth at the beginning of the tranche. The @@ -2972,15 +3030,15 @@ inline bool AveragePool(const PoolParams& params, (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start); if (filter_count == 0) return false; memset(acc, 0, tranche_depth * sizeof(acc[0])); - const uint8* input_ptr = + const uint8_t* input_ptr = input_data + depth_base + depth * (in_x_origin + input_width * (in_y_origin + input_height * batch)); for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const uint8* input_row_ptr = + const uint8_t* input_row_ptr = input_ptr + depth * (fy * input_width + filter_x_start); for (int fx = filter_x_start; fx < filter_x_end; fx++) { - const uint8* input_channel_ptr = input_row_ptr; + const uint8_t* input_channel_ptr = input_row_ptr; int channel = 0; #ifdef USE_NEON for (; channel <= tranche_depth - 16; channel += 16) { @@ -3016,14 +3074,14 @@ inline bool AveragePool(const PoolParams& params, input_row_ptr += depth; } } - uint8* output_ptr = output_data + Offset(output_shape, batch, out_y, - out_x, depth_base); + uint8_t* output_ptr = output_data + Offset(output_shape, batch, out_y, + out_x, depth_base); int channel = 0; #ifdef USE_NEON #define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \ if (filter_count == FILTER_COUNT) { \ for (; channel <= tranche_depth - 8; channel += 8) { \ - uint16 buf[8]; \ + uint16_t buf[8]; \ for (int i = 0; i < 8; i++) { \ buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \ } \ @@ -3037,7 +3095,7 @@ inline bool AveragePool(const PoolParams& params, AVGPOOL_DIVIDING_BY(15) #undef AVGPOOL_DIVIDING_BY for (; channel <= tranche_depth - 8; channel += 8) { - uint16 buf[8]; + uint16_t buf[8]; for (int i = 0; i < 8; i++) { buf[i] = (acc[channel + i] + filter_count / 2) / filter_count; } @@ -3048,10 +3106,10 @@ inline bool AveragePool(const PoolParams& params, } #endif for (; channel < tranche_depth; ++channel) { - uint16 a = (acc[channel] + filter_count / 2) / filter_count; - a = std::max(a, params.quantized_activation_min); - a = std::min(a, params.quantized_activation_max); - output_ptr[channel] = static_cast(a); + uint16_t a = 
(acc[channel] + filter_count / 2) / filter_count; + a = std::max(a, params.quantized_activation_min); + a = std::min(a, params.quantized_activation_max); + output_ptr[channel] = static_cast(a); } } } @@ -3115,8 +3173,8 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, } inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& output_shape, - uint8* output_data) { + const uint8_t* input_data, const RuntimeShape& output_shape, + uint8_t* output_data) { ruy::profiler::ScopeLabel label("MaxPool/8bit"); // Here, and in other pooling ops, in order to maintain locality of reference, @@ -3139,7 +3197,7 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, const int stride_height = params.stride_height; const int stride_width = params.stride_width; - uint8 acc[kPoolingAccTrancheSize]; + uint8_t acc[kPoolingAccTrancheSize]; for (int batch = 0; batch < batches; ++batch) { // We proceed through the depth in tranches (see comment above). The // depth_base is the depth at the beginning of the tranche. The @@ -3161,15 +3219,15 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, const int filter_y_end = std::min(params.filter_height, input_height - in_y_origin); memset(acc, 0, tranche_depth * sizeof(acc[0])); - const uint8* input_ptr = + const uint8_t* input_ptr = input_data + depth_base + depth * (in_x_origin + input_width * (in_y_origin + input_height * batch)); for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const uint8* input_row_ptr = + const uint8_t* input_row_ptr = input_ptr + depth * (fy * input_width + filter_x_start); for (int fx = filter_x_start; fx < filter_x_end; fx++) { - const uint8* input_channel_ptr = input_row_ptr; + const uint8_t* input_channel_ptr = input_row_ptr; int channel = 0; #ifdef USE_NEON for (; channel <= tranche_depth - 16; channel += 16) { @@ -3194,8 +3252,8 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, input_row_ptr += depth; } } - uint8* output_ptr = output_data + Offset(output_shape, batch, out_y, - out_x, depth_base); + uint8_t* output_ptr = output_data + Offset(output_shape, batch, out_y, + out_x, depth_base); int channel = 0; #ifdef USE_NEON for (; channel <= tranche_depth - 16; channel += 16) { @@ -3212,10 +3270,10 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, } #endif for (; channel < tranche_depth; ++channel) { - uint8 a = acc[channel]; - a = std::max(a, params.quantized_activation_min); - a = std::min(a, params.quantized_activation_max); - output_ptr[channel] = static_cast(a); + uint8_t a = acc[channel]; + a = std::max(a, params.quantized_activation_min); + a = std::min(a, params.quantized_activation_max); + output_ptr[channel] = static_cast(a); } } } @@ -3498,7 +3556,7 @@ inline void Softmax(const SoftmaxParams& params, // softmax(x) = e^(x - CONST) / sum(e^(x - CONST), 0...n) // // For quantization, `x` in our case is (input_q - input_zp) * input_s -// For uint8 case (int8 can be handled similarly), the range is [0, 255] +// For uint8_t case (int8_t can be handled similarly), the range is [0, 255] // // so if we let // CONST = (255 - input_zp) * input_s @@ -3508,7 +3566,7 @@ inline void Softmax(const SoftmaxParams& params, // sum(e^(input_q - 255) * input_s, 0...n) -------- (2) // // the good thing about (1) is it's within the range of (0, 1), so we can -// approximate its result with uint16. 
+// approximate its result with uint16_t. // (1) = uint8_out * 1 / 2^16. // // so (1) is lookup_uint8_table(input_zp) * 1 / 2^16. @@ -3522,8 +3580,8 @@ inline void Softmax(const SoftmaxParams& params, // + // output_zp // -// We can actually further improve the performance by using uint8 instead of -// uint16. But that we may lose some accuracy, so we need to pay attention +// We can actually further improve the performance by using uint8_t instead of +// uint16_t. But that we may lose some accuracy, so we need to pay attention // to that. inline void PopulateSoftmaxUInt8LookupTable(SoftmaxParams* data, float input_scale, float beta) { @@ -3553,7 +3611,7 @@ inline int FindMaxValue(int size, const uint8_t* input_data, uint8_t offset) { input_value = veorq_u8(input_value, offset_dup); max_val_dup = vmaxq_u8(input_value, max_val_dup); } - max_val = std::max(max_val, static_cast(vmaxvq_u8(max_val_dup))); + max_val = std::max(max_val, static_cast(vmaxvq_u8(max_val_dup))); #endif for (; j < size; ++j) { @@ -3608,12 +3666,12 @@ inline void SoftmaxInt8LUT(const SoftmaxParams& params, const int32_t clamp_min = std::numeric_limits::min(); // Offset is used to interpret the input data "correctly". - // If the input is uint8, the data will be unchanged. - // If the input is int8, since it will be reinterpret as uint8. + // If the input is uint8_t, the data will be unchanged. + // If the input is int8_t, since it will be reinterpret as uint8_t. // e.g., - // int8 127 will be applied "offset" to become 255 in uint8. + // int8_t 127 will be applied "offset" to become 255 in uint8_t. uint8_t offset = 0; - if (std::is_same::value) { + if (std::is_same::value) { offset = 0x80; } @@ -3641,7 +3699,7 @@ inline void SoftmaxInt8LUT(const SoftmaxParams& params, // Find max quantized value. int32_t max_val = FindMaxValue(last_dim, input_data_uint, offset); - int32 sum_exp = 0; + int32_t sum_exp = 0; const int32_t max_uint8 = std::numeric_limits::max(); const uint8_t table_offset = max_uint8 - max_val; @@ -3686,7 +3744,7 @@ inline void SoftmaxInt8LUT(const SoftmaxParams& params, const float inv_sum_exp = 1.0f / (sum_exp * params.scale); - int32 multiplier, shift; + int32_t multiplier, shift; QuantizeMultiplier(inv_sum_exp, &multiplier, &shift); // Normalize and quantize probabilities. @@ -3782,8 +3840,9 @@ inline void LogSoftmax(const SoftmaxParams& params, // Backwards compatibility. Less optimized than below version. inline void LogSoftmax(const SoftmaxParams& params, - const RuntimeShape& input_shape, const uint8* input_data, - const RuntimeShape& output_shape, uint8* output_data) { + const RuntimeShape& input_shape, + const uint8_t* input_data, + const RuntimeShape& output_shape, uint8_t* output_data) { reference_ops::LogSoftmax(params, input_shape, input_data, output_shape, output_data); } @@ -3794,7 +3853,7 @@ inline void LogSoftmax(const SoftmaxParams& params, // // To handle quantization, first dequantize the inputs (from doing // e^(input scale * val) where we ignore the zero point since it cancels -// out during subtraction due to the ln) and do a rescale at the end to int8. +// out during subtraction due to the ln) and do a rescale at the end to int8_t. // // Notably this makes use of float and is intended as the optimized // form for quantized execution on CPU. 
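+// [Illustrative sketch] A standalone rendering of the table-based uint8_t
+// softmax idea described in the comments above: build a table of exponentials
+// indexed by the quantized input, shift indices so the row maximum maps to
+// exponent zero, then normalize by the sum. This is only the idea, not the
+// exact PopulateSoftmaxUInt8LookupTable / SoftmaxInt8LUT implementation (which
+// also folds in output quantization); the input zero point cancels out of
+// softmax, so the sketch ignores it.
+//
+//   #include <algorithm>
+//   #include <cmath>
+//   #include <cstdint>
+//   #include <vector>
+//
+//   // table[q] ~ exp(input_scale * beta * (q - 255)), stored as a fraction of
+//   // 65535 so every entry fits in uint16_t.
+//   std::vector<uint16_t> BuildSoftmaxExpTable(float input_scale, float beta) {
+//     std::vector<uint16_t> table(256);
+//     for (int q = 0; q < 256; ++q) {
+//       const float e = std::exp(input_scale * beta * static_cast<float>(q - 255));
+//       table[q] = static_cast<uint16_t>(std::lround(e * 65535.0f));
+//     }
+//     return table;
+//   }
+//
+//   void SoftmaxFromTable(const std::vector<uint16_t>& table,
+//                         const uint8_t* input, int n, float* prob) {
+//     int max_val = 0;
+//     for (int i = 0; i < n; ++i) max_val = std::max<int>(max_val, input[i]);
+//     const int offset = 255 - max_val;  // max element hits table index 255
+//     int64_t sum = 0;
+//     for (int i = 0; i < n; ++i) sum += table[input[i] + offset];
+//     for (int i = 0; i < n; ++i) {
+//       prob[i] = static_cast<float>(table[input[i] + offset]) /
+//                 static_cast<float>(sum);
+//     }
+//   }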
For a fully integer version, @@ -3825,7 +3884,7 @@ inline void LogSoftmax(const SoftmaxParams& params, float input_scale, } float sum_exp = 0.0f; - const int32_t max_uint8 = std::numeric_limits::max(); + const int32_t max_uint8 = std::numeric_limits::max(); // Offset into table to compute exp(scale*(x - xmax)) instead of // exp(scale*(x)) to prevent overflow. const float* table_offset = ¶ms.table[max_uint8 - max_val]; @@ -3875,8 +3934,8 @@ inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape, } inline void Logistic(const LogisticParams& params, - const RuntimeShape& input_shape, const int16* input_data, - const RuntimeShape& output_shape, int16* output_data) { + const RuntimeShape& input_shape, const int16_t* input_data, + const RuntimeShape& output_shape, int16_t* output_data) { ruy::profiler::ScopeLabel label("Logistic/Int16"); const int flat_size = MatchingFlatSize(input_shape, output_shape); @@ -3884,8 +3943,8 @@ inline void Logistic(const LogisticParams& params, } int c = 0; - const int16* input_data_ptr = input_data; - int16* output_data_ptr = output_data; + const int16_t* input_data_ptr = input_data; + int16_t* output_data_ptr = output_data; #ifdef GEMMLOWP_NEON { // F0 uses 0 integer bits, range [-1, 1]. @@ -3988,8 +4047,8 @@ inline void Tanh(const TanhParams&, const RuntimeShape& input_shape, } inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape, - const int16* input_data, const RuntimeShape& output_shape, - int16* output_data) { + const int16_t* input_data, const RuntimeShape& output_shape, + int16_t* output_data) { ruy::profiler::ScopeLabel label("Tanh/Int16"); const int input_left_shift = params.input_left_shift; // Support for shifts is limited until we have a parameterized version of @@ -4000,8 +4059,8 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape, const int flat_size = MatchingFlatSize(input_shape, output_shape); int c = 0; - const int16* input_data_ptr = input_data; - int16* output_data_ptr = output_data; + const int16_t* input_data_ptr = input_data; + int16_t* output_data_ptr = output_data; #ifdef GEMMLOWP_NEON { // F0 uses 0 integer bits, range [-1, 1]. @@ -4201,11 +4260,14 @@ inline void GetIndexRange(int spatial_index_dim, int block_shape_dim, } template -inline void BatchToSpaceND( - const RuntimeShape& unextended_input1_shape, const T* input1_data, - const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, - const RuntimeShape& unextended_input3_shape, const int32* crops_data, - const RuntimeShape& unextended_output_shape, T* output_data) { +inline void BatchToSpaceND(const RuntimeShape& unextended_input1_shape, + const T* input1_data, + const RuntimeShape& unextended_input2_shape, + const int32_t* block_shape_data, + const RuntimeShape& unextended_input3_shape, + const int32_t* crops_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { ruy::profiler::ScopeLabel label("BatchToSpaceND"); TFLITE_DCHECK_GE(unextended_input1_shape.DimensionsCount(), 3); @@ -4305,8 +4367,8 @@ TFLITE_NOINLINE void TypedMemset(void* ptr, T value, size_t num) { // equivalent to a simple input1_data. For Pad, it should point to a zero // value. // -// Note that two typenames are required, so that T=P=int32 is considered a -// specialization distinct from P=int32. +// Note that two typenames are required, so that T=P=int32_t is considered a +// specialization distinct from P=int32_t. 
template <typename T, typename P> inline void PadImpl(const tflite::PadParams& op_params, const RuntimeShape& input_shape, const T* input_data, @@ -4449,11 +4511,11 @@ inline void Pad(const tflite::PadParams& op_params, output_data); } -// The second (pad-value) input can be int32 when, say, the first is uint8. +// The second (pad-value) input can be int32_t when, say, the first is uint8_t. template <typename T> inline void Pad(const tflite::PadParams& op_params, const RuntimeShape& input_shape, const T* input_data, - const int32* pad_value_ptr, const RuntimeShape& output_shape, + const int32_t* pad_value_ptr, const RuntimeShape& output_shape, T* output_data) { const T converted_pad_value = static_cast<T>(*pad_value_ptr); PadImpl(op_params, input_shape, input_data, &converted_pad_value, @@ -4463,9 +4525,9 @@ inline void Pad(const tflite::PadParams& op_params, // This version avoids conflicting template matching. template <> inline void Pad(const tflite::PadParams& op_params, - const RuntimeShape& input_shape, const int32* input_data, - const int32* pad_value_ptr, const RuntimeShape& output_shape, - int32* output_data) { + const RuntimeShape& input_shape, const int32_t* input_data, + const int32_t* pad_value_ptr, const RuntimeShape& output_shape, + int32_t* output_data) { PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape, output_data); } @@ -4474,15 +4536,15 @@ inline void Pad(const tflite::PadParams& op_params, // // This pad requires that (a) left and right paddings are in the 4D patterns // {0, h_pad, w_pad, 0}, and (b) memset can be used: *pad_value_ptr == 0 and/or -// T is uint8. +// T is uint8_t. // // There are two versions of pad: Pad and PadV2. In PadV2 there is a second // scalar input that provides the padding value. Therefore pad_value_ptr can be // equivalent to a simple input1_data. For Pad, it should point to a zero // value. // -// Note that two typenames are required, so that T=P=int32 is considered a -// specialization distinct from P=int32. +// Note that two typenames are required, so that T=P=int32_t is considered a +// specialization distinct from P=int32_t.
template inline void PadImageStyleMemset(const tflite::PadParams& op_params, const RuntimeShape& input_shape, @@ -4604,9 +4666,9 @@ inline void PadImageStyle(const tflite::PadParams& op_params, template inline void PadImageStyle(const tflite::PadParams& op_params, const RuntimeShape& input_shape, - const uint8* input_data, const P* pad_value_ptr, + const uint8_t* input_data, const P* pad_value_ptr, const RuntimeShape& output_shape, - uint8* output_data) { + uint8_t* output_data) { PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr, output_shape, output_data); } @@ -4723,7 +4785,7 @@ inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data, } template -void TransposeIm2col(const ConvParams& params, uint8 zero_byte, +void TransposeIm2col(const ConvParams& params, uint8_t zero_byte, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& filter_shape, const RuntimeShape& output_shape, T* im2col_data) { @@ -4935,7 +4997,7 @@ inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size, int32_t output_zp, const int32_t output_min, const int32_t output_max, int32_t* scratch, uint8_t* output) { - ruy::profiler::ScopeLabel label("Quantize/uint8"); + ruy::profiler::ScopeLabel label("Quantize/uint8_t"); int i = 0; #ifdef USE_NEON @@ -5000,7 +5062,7 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, int32_t channel_size, int32_t total_size, int32_t output_zp, int32_t output_min, int32_t output_max, int32_t* scratch, int8_t* output) { - ruy::profiler::ScopeLabel label("Quantize/int8"); + ruy::profiler::ScopeLabel label("Quantize/int8_t"); // Here we're trying to quantize the raw accumulators: // output_channels @@ -5062,7 +5124,7 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, acc_2 = vmaxq_s32(acc_2, output_activation_min_vec); acc_2 = vminq_s32(acc_2, output_activation_max_vec); - // Saturating cast to int8 and store to destination. + // Saturating cast to int8_t and store to destination. const int16x4_t acc_s16_1 = vqmovn_s32(acc_1); const int16x4_t acc_s16_2 = vqmovn_s32(acc_2); const int16x8_t res_s16 = vcombine_s16(acc_s16_1, acc_s16_2); @@ -5076,12 +5138,12 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, for (; c < channel_size; c++) { for (int n = 0; n < rows; ++n) { int loc = n * channel_size + c; - int32 acc = scratch[loc]; + int32_t acc = scratch[loc]; acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]); acc += output_zp; acc = std::max(acc, output_min); acc = std::min(acc, output_max); - output[loc] = static_cast(acc); + output[loc] = static_cast(acc); } } } @@ -5090,7 +5152,7 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, int32_t channel_size, int32_t total_size, int32_t output_zp, int32_t output_min, int32_t output_max, int32_t* scratch, int16_t* output) { - ruy::profiler::ScopeLabel label("Quantize(Single-rounding)/int16"); + ruy::profiler::ScopeLabel label("Quantize(Single-rounding)/int16_t"); // Here we're trying to quantize the raw accumulators: // output_channels @@ -5152,7 +5214,7 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, acc_2 = vmaxq_s32(acc_2, output_activation_min_vec); acc_2 = vminq_s32(acc_2, output_activation_max_vec); - // Saturating cast to int16 and store to destination. + // Saturating cast to int16_t and store to destination. 
const int16x4_t acc_s16_1 = vqmovn_s32(acc_1); const int16x4_t acc_s16_2 = vqmovn_s32(acc_2); vst1_s16(reinterpret_cast(output) + loc, acc_s16_1); @@ -5165,12 +5227,12 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, for (; c < channel_size; c++) { for (int n = 0; n < rows; ++n) { int loc = n * channel_size + c; - int32 acc = scratch[loc]; + int32_t acc = scratch[loc]; acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]); acc += output_zp; acc = std::max(acc, output_min); acc = std::min(acc, output_max); - output[loc] = static_cast(acc); + output[loc] = static_cast(acc); } } } @@ -5180,7 +5242,7 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, int32_t channel_size, int32_t total_size, int32_t output_zp, int32_t output_min, int32_t output_max, int32_t* scratch, int8_t* output) { - ruy::profiler::ScopeLabel label("Quantize/int8"); + ruy::profiler::ScopeLabel label("Quantize/int8_t"); // Here we're trying to quantize the raw accumulators: // output_channels @@ -5243,7 +5305,7 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, acc_2 = vmaxq_s32(acc_2, output_activation_min_vec); acc_2 = vminq_s32(acc_2, output_activation_max_vec); - // Saturating cast to int8 and store to destination. + // Saturating cast to int8_t and store to destination. const int16x4_t acc_s16_1 = vqmovn_s32(acc_1); const int16x4_t acc_s16_2 = vqmovn_s32(acc_2); const int16x8_t res_s16 = vcombine_s16(acc_s16_1, acc_s16_2); @@ -5257,12 +5319,12 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, for (; c < channel_size; c++) { for (int n = 0; n < rows; ++n) { int loc = n * channel_size + c; - int32 acc = scratch[loc]; + int32_t acc = scratch[loc]; acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]); acc += output_zp; acc = std::max(acc, output_min); acc = std::min(acc, output_max); - output[loc] = static_cast(acc); + output[loc] = static_cast(acc); } } } @@ -5271,7 +5333,7 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, int32_t channel_size, int32_t total_size, int32_t output_zp, int32_t output_min, int32_t output_max, int32_t* scratch, int16_t* output) { - ruy::profiler::ScopeLabel label("Quantize(Double-rounding)/int16"); + ruy::profiler::ScopeLabel label("Quantize(Double-rounding)/int16_t"); // Here we're trying to quantize the raw accumulators: // output_channels @@ -5334,7 +5396,7 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, acc_2 = vmaxq_s32(acc_2, output_activation_min_vec); acc_2 = vminq_s32(acc_2, output_activation_max_vec); - // Saturating cast to int16 and store to destination. + // Saturating cast to int16_t and store to destination. 
const int16x4_t acc_s16_1 = vqmovn_s32(acc_1); const int16x4_t acc_s16_2 = vqmovn_s32(acc_2); vst1_s16(reinterpret_cast(output) + loc, acc_s16_1); @@ -5347,12 +5409,12 @@ inline void Quantize(const int32_t* multiplier, const int32_t* shift, for (; c < channel_size; c++) { for (int n = 0; n < rows; ++n) { int loc = n * channel_size + c; - int32 acc = scratch[loc]; + int32_t acc = scratch[loc]; acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]); acc += output_zp; acc = std::max(acc, output_min); acc = std::min(acc, output_max); - output[loc] = static_cast(acc); + output[loc] = static_cast(acc); } } } @@ -5363,11 +5425,11 @@ inline void TransposeConvV2( const ConvParams& params, const RuntimeShape& input_shape, const uint8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape, const uint8_t* hwoi_ordered_filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, uint8_t* output_data, const RuntimeShape& col2im_shape, int32_t* col2im_data, int32_t* scratch_data, CpuBackendContext* cpu_backend_context) { - ruy::profiler::ScopeLabel label("TransposeConvV2/uint8"); + ruy::profiler::ScopeLabel label("TransposeConvV2/uint8_t"); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4); TFLITE_DCHECK(col2im_data); @@ -5396,8 +5458,8 @@ inline void TransposeConvV2( const int stride_height = params.stride_height; const int stride_width = params.stride_width; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; const int hwoi_ordered_filter_total_size = filter_height * filter_width * output_depth; @@ -5409,7 +5471,8 @@ inline void TransposeConvV2( lhs_params.zero_point = -params.weights_offset; int32_t* scratch_data_p = scratch_data; - std::fill_n(scratch_data, output_offset * batch_size, static_cast(0)); + std::fill_n(scratch_data, output_offset * batch_size, + static_cast(0)); for (int i = 0; i < batch_size; ++i) { cpu_backend_gemm::MatrixParams rhs_params; rhs_params.order = cpu_backend_gemm::Order::kColMajor; @@ -5450,9 +5513,9 @@ inline void TransposeConvV2( // version. inline void ResizeNearestNeighbor( const tflite::ResizeNearestNeighborParams& op_params, - const RuntimeShape& unextended_input_shape, const uint8* input_data, - const RuntimeShape& output_size_shape, const int32* output_size_data, - const RuntimeShape& unextended_output_shape, uint8* output_data) { + const RuntimeShape& unextended_input_shape, const uint8_t* input_data, + const RuntimeShape& output_size_shape, const int32_t* output_size_data, + const RuntimeShape& unextended_output_shape, uint8_t* output_data) { if (op_params.align_corners || op_params.half_pixel_centers) { // TODO(b/149823713): Add support for align_corners & half_pixel_centers in // this kernel. 
@@ -5469,42 +5532,42 @@ inline void ResizeNearestNeighbor( const RuntimeShape output_shape = RuntimeShape::ExtendedShape(4, unextended_output_shape); - int32 batches = MatchingDim(input_shape, 0, output_shape, 0); - int32 input_height = input_shape.Dims(1); - int32 input_width = input_shape.Dims(2); - int32 depth = MatchingDim(input_shape, 3, output_shape, 3); + int32_t batches = MatchingDim(input_shape, 0, output_shape, 0); + int32_t input_height = input_shape.Dims(1); + int32_t input_width = input_shape.Dims(2); + int32_t depth = MatchingDim(input_shape, 3, output_shape, 3); // The Tensorflow version of this op allows resize on the width and height // axis only. TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2); - int32 output_height = output_size_data[0]; - int32 output_width = output_size_data[1]; + int32_t output_height = output_size_data[0]; + int32_t output_width = output_size_data[1]; // Convert scales to fixed-point with 16 fractional bits. We add 1 as an // error factor and to avoid zero scales. For example, with input_height = 1, // output_height = 3, the float scaling factor would be non-zero at 1/3. // With fixed-point, this is zero. - int32 height_scale = (input_height << 16) / output_height + 1; - int32 width_scale = (input_width << 16) / output_width + 1; + int32_t height_scale = (input_height << 16) / output_height + 1; + int32_t width_scale = (input_width << 16) / output_width + 1; const int col_offset = input_shape.Dims(3); const int row_offset = input_shape.Dims(2) * col_offset; const int batch_offset = input_shape.Dims(1) * row_offset; - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; + const uint8_t* input_ptr = input_data; + uint8_t* output_ptr = output_data; for (int b = 0; b < batches; ++b) { for (int y = 0; y < output_height; ++y) { - int32 in_y = std::min((y * height_scale) >> 16, input_height - 1); + int32_t in_y = std::min((y * height_scale) >> 16, input_height - 1); // Check offset calculation is the same as the reference version. See // function comment for details. We check using a non-float version of: // TFLITE_DCHECK_EQ(in_y, std::floor(y * (static_cast(input_height) // / output_height))); TFLITE_DCHECK_LT(y * input_height, output_height + in_y * output_height); TFLITE_DCHECK_GE(y * input_height, in_y * output_height); - const uint8* y_input_ptr = input_ptr + in_y * row_offset; + const uint8_t* y_input_ptr = input_ptr + in_y * row_offset; for (int x = 0; x < output_width; ++x) { - int32 in_x = std::min((x * width_scale) >> 16, input_width - 1); + int32_t in_x = std::min((x * width_scale) >> 16, input_width - 1); // Check offset calculation is the same as the reference version. See // function comment for details. 
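+// [Worked example of the Q16 scale above] With input_height = 10 and
+// output_height = 4: height_scale = (10 << 16) / 4 + 1 = 163841. For output
+// row y = 3, (3 * 163841) >> 16 = 491523 >> 16 = 7, which matches
+// floor(3 * 10 / 4) = floor(7.5) = 7; the std::min against input_height - 1
+// then keeps the index in bounds. The surrounding DCHECKs assert exactly this
+// "same result as the float computation" property without using floats.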
We check using a non-float version of: // TFLITE_DCHECK_EQ(in_y, @@ -5512,7 +5575,7 @@ inline void ResizeNearestNeighbor( // / output_width))); TFLITE_DCHECK_LT(x * input_width, output_width + in_x * output_width); TFLITE_DCHECK_GE(x * input_width, in_x * output_width); - const uint8* x_input_ptr = y_input_ptr + in_x * col_offset; + const uint8_t* x_input_ptr = y_input_ptr + in_x * col_offset; memcpy(output_ptr, x_input_ptr, depth); output_ptr += depth; } @@ -6178,7 +6241,7 @@ inline void Dequantize(const tflite::DequantizationParams& op_params, const uint8_t* input_data, const RuntimeShape& output_shape, float* output_data) { ruy::profiler::ScopeLabel label("Dequantize/Uint8"); - const int32 zero_point = op_params.zero_point; + const int32_t zero_point = op_params.zero_point; const double scale = op_params.scale; const int flat_size = MatchingFlatSize(input_shape, output_shape); @@ -6207,7 +6270,7 @@ inline void Dequantize(const tflite::DequantizationParams& op_params, } #endif // NEON for (; i < flat_size; ++i) { - const int32 val = input_data[i]; + const int32_t val = input_data[i]; const float result = static_cast(scale * (val - zero_point)); output_data[i] = result; } @@ -6218,7 +6281,7 @@ inline void Dequantize(const tflite::DequantizationParams& op_params, const int8_t* input_data, const RuntimeShape& output_shape, float* output_data) { ruy::profiler::ScopeLabel label("Dequantize/Int8"); - const int32 zero_point = op_params.zero_point; + const int32_t zero_point = op_params.zero_point; const double scale = op_params.scale; const int flat_size = MatchingFlatSize(input_shape, output_shape); @@ -6246,7 +6309,7 @@ inline void Dequantize(const tflite::DequantizationParams& op_params, } #endif // NEON for (; i < flat_size; ++i) { - const int32 val = input_data[i]; + const int32_t val = input_data[i]; const float result = static_cast(scale * (val - zero_point)); output_data[i] = result; } @@ -6257,7 +6320,7 @@ inline void Dequantize(const tflite::DequantizationParams& op_params, const int16_t* input_data, const RuntimeShape& output_shape, float* output_data) { ruy::profiler::ScopeLabel label("Dequantize/Int16"); - const int32 zero_point = op_params.zero_point; + const int32_t zero_point = op_params.zero_point; const double scale = op_params.scale; const int flat_size = MatchingFlatSize(input_shape, output_shape); @@ -6283,7 +6346,7 @@ inline void Dequantize(const tflite::DequantizationParams& op_params, } #endif // NEON for (; i < flat_size; ++i) { - const int32 val = input_data[i]; + const int32_t val = input_data[i]; const float result = static_cast(scale * (val - zero_point)); output_data[i] = result; } @@ -6311,11 +6374,11 @@ inline void AffineQuantize(const tflite::QuantizationParams& op_params, const RuntimeShape& output_shape, int8_t* output_data) { ruy::profiler::ScopeLabel label("Quantize/Int8"); - const int32 zero_point = op_params.zero_point; + const int32_t zero_point = op_params.zero_point; const double scale = static_cast(op_params.scale); const int flat_size = MatchingFlatSize(input_shape, output_shape); - static constexpr int32 min_val = std::numeric_limits::min(); - static constexpr int32 max_val = std::numeric_limits::max(); + static constexpr int32_t min_val = std::numeric_limits::min(); + static constexpr int32_t max_val = std::numeric_limits::max(); int i = 0; #ifdef USE_NEON @@ -6354,9 +6417,9 @@ inline void AffineQuantize(const tflite::QuantizationParams& op_params, for (; i < flat_size; ++i) { const float val = input_data[i]; - const int32 unclamped = - 
static_cast(TfLiteRound(val / scale)) + zero_point; - const int32 clamped = std::min(std::max(unclamped, min_val), max_val); + const int32_t unclamped = + static_cast(TfLiteRound(val / scale)) + zero_point; + const int32_t clamped = std::min(std::max(unclamped, min_val), max_val); output_data[i] = clamped; } } @@ -6368,11 +6431,11 @@ inline void AffineQuantize(const tflite::QuantizationParams& op_params, const RuntimeShape& output_shape, uint8_t* output_data) { ruy::profiler::ScopeLabel label("Quantize/Uint8"); - const int32 zero_point = op_params.zero_point; + const int32_t zero_point = op_params.zero_point; const double scale = static_cast(op_params.scale); const int flat_size = MatchingFlatSize(input_shape, output_shape); - static constexpr int32 min_val = std::numeric_limits::min(); - static constexpr int32 max_val = std::numeric_limits::max(); + static constexpr int32_t min_val = std::numeric_limits::min(); + static constexpr int32_t max_val = std::numeric_limits::max(); int i = 0; #ifdef USE_NEON @@ -6412,9 +6475,9 @@ inline void AffineQuantize(const tflite::QuantizationParams& op_params, for (; i < flat_size; ++i) { const float val = input_data[i]; - const int32 unclamped = - static_cast(TfLiteRound(val / scale)) + zero_point; - const int32 clamped = std::min(std::max(unclamped, min_val), max_val); + const int32_t unclamped = + static_cast(TfLiteRound(val / scale)) + zero_point; + const int32_t clamped = std::min(std::max(unclamped, min_val), max_val); output_data[i] = clamped; } } @@ -6426,11 +6489,11 @@ inline void AffineQuantize(const tflite::QuantizationParams& op_params, const RuntimeShape& output_shape, int16_t* output_data) { ruy::profiler::ScopeLabel label("Quantize/Int16"); - const int32 zero_point = op_params.zero_point; + const int32_t zero_point = op_params.zero_point; const double scale = static_cast(op_params.scale); const int flat_size = MatchingFlatSize(input_shape, output_shape); - static constexpr int32 min_val = std::numeric_limits::min(); - static constexpr int32 max_val = std::numeric_limits::max(); + static constexpr int32_t min_val = std::numeric_limits::min(); + static constexpr int32_t max_val = std::numeric_limits::max(); int i = 0; #ifdef USE_NEON @@ -6468,9 +6531,9 @@ inline void AffineQuantize(const tflite::QuantizationParams& op_params, for (; i < flat_size; ++i) { const float val = input_data[i]; - const int32 unclamped = - static_cast(TfLiteRound(val / scale)) + zero_point; - const int32 clamped = std::min(std::max(unclamped, min_val), max_val); + const int32_t unclamped = + static_cast(TfLiteRound(val / scale)) + zero_point; + const int32_t clamped = std::min(std::max(unclamped, min_val), max_val); output_data[i] = clamped; } } @@ -6484,9 +6547,9 @@ inline int16x8x4_t SaturatingRounding( int16x8_t input_val_0, int16x8_t input_val_1, int16x8_t input_val_2, int16x8_t input_val_3, int input_left_shift, int input_multiplier) { // This performs what is expressed in the scalar code as - // const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul( - // static_cast(input_val_centered * (1 << input_left_shift)), - // static_cast(input_multiplier)); + // const int16_t input_val_rescaled = SaturatingRoundingDoublingHighMul( + // static_cast(input_val_centered * (1 << input_left_shift)), + // static_cast(input_multiplier)); const int16x8_t left_shift_dup = vdupq_n_s16(input_left_shift); const int16x8_t input_val_shifted_0 = vshlq_s16(input_val_0, left_shift_dup); const int16x8_t input_val_shifted_1 = vshlq_s16(input_val_1, left_shift_dup); @@ -6623,15 
+6686,17 @@ inline void ClampWithRangeAndStore(int8_t* output_dst, int8x16_t input_val, inline void Tanh16bitPrecision(const TanhParams& params, const RuntimeShape& input_shape, - const uint8* input_data, + const uint8_t* input_data, const RuntimeShape& output_shape, - uint8* output_data) { + uint8_t* output_data) { // Note that this is almost the exact same code as in Logistic(). ruy::profiler::ScopeLabel label("Tanh/Uint8"); - const int32 input_zero_point = params.input_zero_point; - const int32 input_range_radius = params.input_range_radius; - const int16 input_multiplier = static_cast(params.input_multiplier); - const int16 input_left_shift = static_cast(params.input_left_shift); + const int32_t input_zero_point = params.input_zero_point; + const int32_t input_range_radius = params.input_range_radius; + const int16_t input_multiplier = + static_cast(params.input_multiplier); + const int16_t input_left_shift = + static_cast(params.input_left_shift); const int size = MatchingFlatSize(input_shape, output_shape); int c = 0; @@ -6647,7 +6712,7 @@ inline void Tanh16bitPrecision(const TanhParams& params, // Handle 32 values at a time for (; c <= size - 32; c += 32) { - // Read input uint8 values, cast to int16 and subtract input_zero_point + // Read input uint8_t values, cast to int16_t and subtract input_zero_point using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint; const int16x8x2_t input_val_centered_0_1 = Load16AndSubtractZeroPoint(input_data + c, input_zero_point); @@ -6684,7 +6749,7 @@ inline void Tanh16bitPrecision(const TanhParams& params, output_val_s16.val[3] = vaddq_s16(output_val_s16.val[3], output_zero_point_s16); - // Cast output values to uint8, saturating + // Cast output values to uint8_t, saturating uint8x16_t output_val_u8_0_1 = vcombine_u8( vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1])); uint8x16_t output_val_u8_2_3 = vcombine_u8( @@ -6697,32 +6762,32 @@ inline void Tanh16bitPrecision(const TanhParams& params, #endif // GEMMLOWP_NEON // Leftover loop: handle one value at a time with scalar code. 
for (; c < size; ++c) { - const uint8 input_val_u8 = input_data[c]; - const int16 input_val_centered = - static_cast(input_val_u8) - input_zero_point; - uint8 output_val; + const uint8_t input_val_u8 = input_data[c]; + const int16_t input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8_t output_val; if (input_val_centered < -input_range_radius) { output_val = 0; } else if (input_val_centered > input_range_radius) { output_val = 255; } else { using gemmlowp::SaturatingRoundingDoublingHighMul; - const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul( - static_cast(input_val_centered * (1 << input_left_shift)), - static_cast(input_multiplier)); - using FixedPoint4 = gemmlowp::FixedPoint; - using FixedPoint0 = gemmlowp::FixedPoint; + const int16_t input_val_rescaled = SaturatingRoundingDoublingHighMul( + static_cast(input_val_centered * (1 << input_left_shift)), + static_cast(input_multiplier)); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4); using gemmlowp::RoundingDivideByPOT; - int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8); + int16_t output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8); output_val_s16 += output_zero_point; if (output_val_s16 == 256) { output_val_s16 = 255; } TFLITE_DCHECK_GE(output_val_s16, 0); TFLITE_DCHECK_LE(output_val_s16, 255); - output_val = static_cast(output_val_s16); + output_val = static_cast(output_val_s16); } output_data[c] = output_val; } @@ -6730,15 +6795,17 @@ inline void Tanh16bitPrecision(const TanhParams& params, inline void Tanh16bitPrecision(const TanhParams& params, const RuntimeShape& input_shape, - const int8* input_data, + const int8_t* input_data, const RuntimeShape& output_shape, - int8* output_data) { + int8_t* output_data) { // Note that this is almost the exact same code as in Logistic(). 
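+// [Worked example] In both the uint8_t path above and the int8_t path below,
+// gemmlowp::tanh produces a Q0.15 value: raw in [-32768, 32767] representing
+// [-1.0, +1.0). RoundingDivideByPOT(raw, 8) rescales that to roughly
+// [-128, 128]. For uint8_t output a zero point of 128 is then added, giving
+// [0, 256] with the single overflow 256 folded to 255 (e.g. raw 32767:
+// (32767 + 128) >> 8 = 128, then +128 = 256 -> stored as 255; raw -32768:
+// -32768 / 256 = -128, then +128 = 0). For int8_t output no offset is added
+// and the overflow value 128 is folded to 127.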
ruy::profiler::ScopeLabel label("Tanh/Int8"); - const int32 input_zero_point = params.input_zero_point; - const int32 input_range_radius = params.input_range_radius; - const int16 input_multiplier = static_cast(params.input_multiplier); - const int16 input_left_shift = static_cast(params.input_left_shift); + const int32_t input_zero_point = params.input_zero_point; + const int32_t input_range_radius = params.input_range_radius; + const int16_t input_multiplier = + static_cast(params.input_multiplier); + const int16_t input_left_shift = + static_cast(params.input_left_shift); const int size = MatchingFlatSize(input_shape, output_shape); int c = 0; @@ -6751,7 +6818,7 @@ inline void Tanh16bitPrecision(const TanhParams& params, // Handle 32 values at a time for (; c <= size - 32; c += 32) { - // Read input int8 values, cast to int16 and subtract input_zero_point + // Read input int8_t values, cast to int16_t and subtract input_zero_point using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint; const int16x8x2_t input_val_centered_0_1 = Load16AndSubtractZeroPoint(input_data + c, input_zero_point); @@ -6778,7 +6845,7 @@ inline void Tanh16bitPrecision(const TanhParams& params, int16x8x4_t output_val_s16 = FixedPoint4Tanh(input_val_rescaled); - // Cast output values to uint8, saturating + // Cast output values to uint8_t, saturating int8x16_t output_val_s8_0_1 = vcombine_s8( vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1])); int8x16_t output_val_s8_2_3 = vcombine_s8( @@ -6791,31 +6858,31 @@ inline void Tanh16bitPrecision(const TanhParams& params, #endif // GEMMLOWP_NEON // Leftover loop: handle one value at a time with scalar code. for (; c < size; ++c) { - const int8 input_val_s8 = input_data[c]; - const int16 input_val_centered = - static_cast(input_val_s8) - input_zero_point; - int8 output_val; + const int8_t input_val_s8 = input_data[c]; + const int16_t input_val_centered = + static_cast(input_val_s8) - input_zero_point; + int8_t output_val; if (input_val_centered <= -input_range_radius) { output_val = -128; } else if (input_val_centered >= input_range_radius) { output_val = 127; } else { using gemmlowp::SaturatingRoundingDoublingHighMul; - const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul( - static_cast(input_val_centered * (1 << input_left_shift)), - static_cast(input_multiplier)); - using FixedPoint4 = gemmlowp::FixedPoint; - using FixedPoint0 = gemmlowp::FixedPoint; + const int16_t input_val_rescaled = SaturatingRoundingDoublingHighMul( + static_cast(input_val_centered * (1 << input_left_shift)), + static_cast(input_multiplier)); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4); using gemmlowp::RoundingDivideByPOT; - int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8); + int16_t output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8); if (output_val_s16 == 128) { output_val_s16 = 127; } TFLITE_DCHECK_GE(output_val_s16, -128); TFLITE_DCHECK_LE(output_val_s16, 127); - output_val = static_cast(output_val_s16); + output_val = static_cast(output_val_s16); } output_data[c] = output_val; } @@ -6823,14 +6890,15 @@ inline void Tanh16bitPrecision(const TanhParams& params, inline void Logistic16bitPrecision(const LogisticParams& params, const RuntimeShape& input_shape, - const uint8* input_data, + const uint8_t* input_data, const RuntimeShape& output_shape, - uint8* 
output_data) { + uint8_t* output_data) { ruy::profiler::ScopeLabel label("Logistic/Uint8"); - const int32 input_zero_point = params.input_zero_point; - const int32 input_range_radius = params.input_range_radius; - const int32 input_multiplier = params.input_multiplier; - const int16 input_left_shift = static_cast(params.input_left_shift); + const int32_t input_zero_point = params.input_zero_point; + const int32_t input_range_radius = params.input_range_radius; + const int32_t input_multiplier = params.input_multiplier; + const int16_t input_left_shift = + static_cast(params.input_left_shift); const int size = MatchingFlatSize(input_shape, output_shape); int c = 0; @@ -6843,7 +6911,7 @@ inline void Logistic16bitPrecision(const LogisticParams& params, // Handle 32 values at a time for (; c <= size - 32; c += 32) { - // Read input uint8 values, cast to int16 and subtract input_zero_point + // Read input uint8_t values, cast to int16_t and subtract input_zero_point using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint; const int16x8x2_t input_val_centered_0_1 = Load16AndSubtractZeroPoint(input_data + c, input_zero_point); @@ -6870,7 +6938,7 @@ inline void Logistic16bitPrecision(const LogisticParams& params, int16x8x4_t output_val_s16 = FixedPoint4Logistic(input_val_rescaled); - // Cast output values to uint8, saturating + // Cast output values to uint8_t, saturating uint8x16_t output_val_u8_0_1 = vcombine_u8( vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1])); uint8x16_t output_val_u8_2_3 = vcombine_u8( @@ -6883,31 +6951,31 @@ inline void Logistic16bitPrecision(const LogisticParams& params, #endif // GEMMLOWP_NEON // Leftover loop: handle one value at a time with scalar code. for (; c < size; ++c) { - const uint8 input_val_u8 = input_data[c]; - const int16 input_val_centered = - static_cast(input_val_u8) - input_zero_point; - uint8 output_val; + const uint8_t input_val_u8 = input_data[c]; + const int16_t input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8_t output_val; if (input_val_centered < -input_range_radius) { output_val = 0; } else if (input_val_centered > input_range_radius) { output_val = 255; } else { using gemmlowp::SaturatingRoundingDoublingHighMul; - const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul( - static_cast(input_val_centered * (1 << input_left_shift)), - static_cast(input_multiplier)); - using FixedPoint4 = gemmlowp::FixedPoint; - using FixedPoint0 = gemmlowp::FixedPoint; + const int16_t input_val_rescaled = SaturatingRoundingDoublingHighMul( + static_cast(input_val_centered * (1 << input_left_shift)), + static_cast(input_multiplier)); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4); using gemmlowp::RoundingDivideByPOT; - int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7); + int16_t output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7); if (output_val_s16 == 256) { output_val_s16 = 255; } TFLITE_DCHECK_GE(output_val_s16, 0); TFLITE_DCHECK_LE(output_val_s16, 255); - output_val = static_cast(output_val_s16); + output_val = static_cast(output_val_s16); } output_data[c] = output_val; } @@ -6915,18 +6983,19 @@ inline void Logistic16bitPrecision(const LogisticParams& params, inline void Logistic16bitPrecision(const LogisticParams& params, const RuntimeShape& input_shape, - const int8* input_data, + const 
int8_t* input_data, const RuntimeShape& output_shape, - int8* output_data) { + int8_t* output_data) { ruy::profiler::ScopeLabel label("Logistic/Int8"); - const int32 input_zero_point = params.input_zero_point; - const int32 input_range_radius = params.input_range_radius; - const int32 input_multiplier = params.input_multiplier; - const int16 input_left_shift = static_cast(params.input_left_shift); + const int32_t input_zero_point = params.input_zero_point; + const int32_t input_range_radius = params.input_range_radius; + const int32_t input_multiplier = params.input_multiplier; + const int16_t input_left_shift = + static_cast(params.input_left_shift); const int size = MatchingFlatSize(input_shape, output_shape); int c = 0; - const int16 output_zero_point = 128; + const int16_t output_zero_point = 128; // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed. // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types. @@ -6937,7 +7006,7 @@ inline void Logistic16bitPrecision(const LogisticParams& params, // Handle 32 values at a time for (; c <= size - 32; c += 32) { - // Read input int8 values, cast to int16 and subtract input_zero_point + // Read input int8_t values, cast to int16_t and subtract input_zero_point using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint; const int16x8x2_t input_val_centered_0_1 = Load16AndSubtractZeroPoint(input_data + c, input_zero_point); @@ -6974,7 +7043,7 @@ inline void Logistic16bitPrecision(const LogisticParams& params, output_val_s16.val[3] = vsubq_s16(output_val_s16.val[3], output_zero_point_dup); - // Cast output values to int8, saturating + // Cast output values to int8_t, saturating int8x16_t output_val_s8_0_1 = vcombine_s8( vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1])); int8x16_t output_val_s8_2_3 = vcombine_s8( @@ -6987,32 +7056,32 @@ inline void Logistic16bitPrecision(const LogisticParams& params, #endif // GEMMLOWP_NEON // Leftover loop: handle one value at a time with scalar code. 
for (; c < size; ++c) { - const int8 input_val_s8 = input_data[c]; - const int16 input_val_centered = - static_cast(input_val_s8) - input_zero_point; - int8 output_val; + const int8_t input_val_s8 = input_data[c]; + const int16_t input_val_centered = + static_cast(input_val_s8) - input_zero_point; + int8_t output_val; if (input_val_centered < -input_range_radius) { output_val = -128; } else if (input_val_centered > input_range_radius) { output_val = 127; } else { using gemmlowp::SaturatingRoundingDoublingHighMul; - const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul( - static_cast(input_val_centered * (1 << input_left_shift)), - static_cast(input_multiplier)); - using FixedPoint4 = gemmlowp::FixedPoint; - using FixedPoint0 = gemmlowp::FixedPoint; + const int16_t input_val_rescaled = SaturatingRoundingDoublingHighMul( + static_cast(input_val_centered * (1 << input_left_shift)), + static_cast(input_multiplier)); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4); using gemmlowp::RoundingDivideByPOT; - int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7); + int16_t output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7); output_val_s16 -= output_zero_point; if (output_val_s16 == 128) { output_val_s16 = 127; } TFLITE_DCHECK_GE(output_val_s16, -128); TFLITE_DCHECK_LE(output_val_s16, 127); - output_val = static_cast(output_val_s16); + output_val = static_cast(output_val_s16); } output_data[c] = output_val; } @@ -7343,8 +7412,8 @@ void Transpose(const TransposeParams& unshrinked_params, // Assume input1 & input2 have the same scale & zero point. inline void MaximumElementwise(int size, const ArithmeticParams& params, - const int8* input1_data, const int8* input2_data, - int8* output_data) { + const int8_t* input1_data, + const int8_t* input2_data, int8_t* output_data) { ruy::profiler::ScopeLabel label("MaximumElementwiseInt8/8bit"); int i = 0; #ifdef USE_NEON @@ -7357,15 +7426,16 @@ inline void MaximumElementwise(int size, const ArithmeticParams& params, } #endif // USE_NEON for (; i < size; ++i) { - const int8 input1_val = input1_data[i]; - const int8 input2_val = input2_data[i]; + const int8_t input1_val = input1_data[i]; + const int8_t input2_val = input2_data[i]; output_data[i] = std::max(input1_val, input2_val); } } inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params, - int8 input1_data, const int8* input2_data, - int8* output_data) { + int8_t input1_data, + const int8_t* input2_data, + int8_t* output_data) { ruy::profiler::ScopeLabel label("MaximumScalarBroadcastInt8/8bit"); int i = 0; @@ -7379,15 +7449,15 @@ inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params, } #endif // USE_NEON for (; i < size; ++i) { - const int8 input2_val = input2_data[i]; + const int8_t input2_val = input2_data[i]; output_data[i] = std::max(input1_data, input2_val); } } // Assume input1 & input2 have the same scale & zero point. 
inline void MinimumElementwise(int size, const ArithmeticParams& params, - const int8* input1_data, const int8* input2_data, - int8* output_data) { + const int8_t* input1_data, + const int8_t* input2_data, int8_t* output_data) { ruy::profiler::ScopeLabel label("MinimumElementwiseInt8/8bit"); int i = 0; #ifdef USE_NEON @@ -7400,15 +7470,16 @@ inline void MinimumElementwise(int size, const ArithmeticParams& params, } #endif // USE_NEON for (; i < size; ++i) { - const int8 input1_val = input1_data[i]; - const int8 input2_val = input2_data[i]; + const int8_t input1_val = input1_data[i]; + const int8_t input2_val = input2_data[i]; output_data[i] = std::min(input1_val, input2_val); } } inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params, - int8 input1_data, const int8* input2_data, - int8* output_data) { + int8_t input1_data, + const int8_t* input2_data, + int8_t* output_data) { ruy::profiler::ScopeLabel label("MinimumScalarBroadcastInt8/8bit"); int i = 0; @@ -7422,7 +7493,7 @@ inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params, } #endif // USE_NEON for (; i < size; ++i) { - const int8 input2_val = input2_data[i]; + const int8_t input2_val = input2_data[i]; output_data[i] = std::min(input1_data, input2_val); } } @@ -7430,11 +7501,11 @@ inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params, template inline void BroadcastMaximumDispatch(const ArithmeticParams& params, const RuntimeShape& input1_shape, - const int8* input1_data, + const int8_t* input1_data, const RuntimeShape& input2_shape, - const int8* input2_data, + const int8_t* input2_data, const RuntimeShape& output_shape, - int8* output_data, Op op) { + int8_t* output_data, Op op) { if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) { return reference_ops::MaximumMinimumBroadcastSlow( input1_shape, input1_data, input2_shape, input2_data, output_shape, @@ -7449,11 +7520,11 @@ inline void BroadcastMaximumDispatch(const ArithmeticParams& params, template inline void BroadcastMinimumDispatch(const ArithmeticParams& params, const RuntimeShape& input1_shape, - const int8* input1_data, + const int8_t* input1_data, const RuntimeShape& input2_shape, - const int8* input2_data, + const int8_t* input2_data, const RuntimeShape& output_shape, - int8* output_data, Op op) { + int8_t* output_data, Op op) { if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) { return reference_ops::MaximumMinimumBroadcastSlow( input1_shape, input1_data, input2_shape, input2_data, output_shape, @@ -7979,7 +8050,7 @@ inline TfLiteStatus Conv3D( ruy::profiler::ScopeLabel label("Conv3D"); // NB: the float 0.0f value is represented by all zero bytes. 
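// A minimal sketch of the int8_t element-wise maximum/minimum pattern used by
// MaximumElementwise and MinimumElementwise above: a 16-lane NEON fast path
// with a scalar tail. vld1q_s8 / vmaxq_s8 / vst1q_s8 are standard Arm NEON
// intrinsics; the function name is illustrative.
#include <algorithm>
#include <cstdint>
#ifdef USE_NEON
#include <arm_neon.h>
#endif

inline void ElementwiseMaxInt8(int size, const int8_t* input1,
                               const int8_t* input2, int8_t* output) {
  int i = 0;
#ifdef USE_NEON
  for (; i <= size - 16; i += 16) {
    const int8x16_t a = vld1q_s8(input1 + i);
    const int8x16_t b = vld1q_s8(input2 + i);
    vst1q_s8(output + i, vmaxq_s8(a, b));
  }
#endif  // USE_NEON
  // Scalar tail handles the remaining (size % 16) elements.
  for (; i < size; ++i) {
    output[i] = std::max(input1[i], input2[i]);
  }
}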
- const uint8 float_zero_byte = 0x00; + const uint8_t float_zero_byte = 0x00; const float* gemm_input_data = nullptr; const RuntimeShape* gemm_input_shape = nullptr; const int filter_width = filter_shape.Dims(2); diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 469abbdc7b3c17..7bf649b0065fd5 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -151,11 +151,11 @@ inline void ReluX(const tflite::ReluParams& params, ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { - const int32 val = static_cast(input_data[i]); - int32 clamped = params.output_offset + - MultiplyByQuantizedMultiplier(val - params.input_offset, - params.output_multiplier, - params.output_shift); + const int32_t val = static_cast(input_data[i]); + int32_t clamped = params.output_offset + + MultiplyByQuantizedMultiplier(val - params.input_offset, + params.output_multiplier, + params.output_shift); clamped = std::max(params.quantized_activation_min, clamped); clamped = std::min(params.quantized_activation_max, clamped); output_data[i] = static_cast(clamped); @@ -185,11 +185,11 @@ inline void ReluX(const tflite::ActivationParams& params, // generate max(D1, D2) nested for loops. inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, const RuntimeShape& unswitched_input1_shape, - const uint8* unswitched_input1_data, + const uint8_t* unswitched_input1_data, const RuntimeShape& unswitched_input2_shape, - const uint8* unswitched_input2_data, + const uint8_t* unswitched_input2_data, const RuntimeShape& output_shape, - uint8* output_data) { + uint8_t* output_data) { ArithmeticParams switched_params = unswitched_params; switched_params.input1_offset = unswitched_params.input2_offset; switched_params.input2_offset = unswitched_params.input1_offset; @@ -200,25 +200,25 @@ inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, const ArithmeticParams& params = use_unswitched ? unswitched_params : switched_params; - const uint8* input1_data = + const uint8_t* input1_data = use_unswitched ? unswitched_input1_data : unswitched_input2_data; - const uint8* input2_data = + const uint8_t* input2_data = use_unswitched ? unswitched_input2_data : unswitched_input1_data; // Fivefold nested loops. The second input resets its position for each // iteration of the second loop. The first input resets its position at the // beginning of the fourth loop. The innermost loop is an elementwise Mul of // sections of the arrays. 
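// A minimal sketch of the per-element requantization performed by the
// quantized ReluX hunk above, assuming MultiplyByQuantizedMultiplier(int32_t,
// int32_t, int) from kernels/internal/common.h is in scope. The uint8_t
// input/output type and parameter names are illustrative; the real kernel is
// templated on the tensor type.
#include <algorithm>
#include <cstdint>

inline uint8_t RequantizeReluX(uint8_t input, int32_t input_offset,
                               int32_t output_offset,
                               int32_t output_multiplier, int output_shift,
                               int32_t activation_min, int32_t activation_max) {
  const int32_t val = static_cast<int32_t>(input);
  int32_t clamped =
      output_offset + MultiplyByQuantizedMultiplier(val - input_offset,
                                                    output_multiplier,
                                                    output_shift);
  // Clamp into the quantized activation range before narrowing back to uint8_t.
  clamped = std::max(activation_min, clamped);
  clamped = std::min(activation_max, clamped);
  return static_cast<uint8_t>(clamped);
}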
- uint8* output_data_ptr = output_data; - const uint8* input1_data_ptr = input1_data; - const uint8* input2_data_reset = input2_data; + uint8_t* output_data_ptr = output_data; + const uint8_t* input1_data_ptr = input1_data; + const uint8_t* input2_data_reset = input2_data; int y0 = params.broadcast_shape[0]; int y1 = params.broadcast_shape[1]; int y2 = params.broadcast_shape[2]; int y3 = params.broadcast_shape[3]; int y4 = params.broadcast_shape[4]; for (int i0 = 0; i0 < y0; ++i0) { - const uint8* input2_data_ptr; + const uint8_t* input2_data_ptr; for (int i1 = 0; i1 < y1; ++i1) { input2_data_ptr = input2_data_reset; for (int i2 = 0; i2 < y2; ++i2) { @@ -236,9 +236,9 @@ inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, } inline void Mul(const ArithmeticParams& params, - const RuntimeShape& input1_shape, const int16* input1_data, - const RuntimeShape& input2_shape, const int16* input2_data, - const RuntimeShape& output_shape, int16* output_data) { + const RuntimeShape& input1_shape, const int16_t* input1_data, + const RuntimeShape& input2_shape, const int16_t* input2_data, + const RuntimeShape& output_shape, int16_t* output_data) { ruy::profiler::ScopeLabel label("Mul/Int16"); const int flat_size = @@ -255,13 +255,13 @@ inline void Mul(const ArithmeticParams& params, } inline void Mul(const ArithmeticParams& params, - const RuntimeShape& input1_shape, const int16* input1_data, - const RuntimeShape& input2_shape, const int16* input2_data, - const RuntimeShape& output_shape, uint8* output_data) { + const RuntimeShape& input1_shape, const int16_t* input1_data, + const RuntimeShape& input2_shape, const int16_t* input2_data, + const RuntimeShape& output_shape, uint8_t* output_data) { ruy::profiler::ScopeLabel label("Mul/Int16Uint8"); - int32 output_offset = params.output_offset; - int32 output_activation_min = params.quantized_activation_min; - int32 output_activation_max = params.quantized_activation_max; + int32_t output_offset = params.output_offset; + int32_t output_activation_min = params.quantized_activation_min; + int32_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_LE(output_activation_min, output_activation_max); const int flat_size = @@ -273,12 +273,12 @@ inline void Mul(const ArithmeticParams& params, F0 unclamped_result = F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]); - int16 rescaled_result = + int16_t rescaled_result = gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8); - int16 clamped_result = - std::min(output_activation_max - output_offset, rescaled_result); - clamped_result = - std::max(output_activation_min - output_offset, clamped_result); + int16_t clamped_result = std::min( + output_activation_max - output_offset, rescaled_result); + clamped_result = std::max(output_activation_min - output_offset, + clamped_result); output_data[i] = output_offset + clamped_result; } } @@ -291,14 +291,15 @@ inline void Sub16(const ArithmeticParams& params, const int input1_shift = params.input1_shift; const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape); - const int16 output_activation_min = params.quantized_activation_min; - const int16 output_activation_max = params.quantized_activation_max; + const int16_t output_activation_min = params.quantized_activation_min; + const int16_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0); TFLITE_DCHECK_LE(input1_shift, 0); TFLITE_DCHECK_LE(params.input2_shift, 0); - const int16* 
not_shift_input = input1_shift == 0 ? input1_data : input2_data; - const int16* shift_input = input1_shift == 0 ? input2_data : input1_data; + const int16_t* not_shift_input = + input1_shift == 0 ? input1_data : input2_data; + const int16_t* shift_input = input1_shift == 0 ? input2_data : input1_data; const int input_right_shift = input1_shift == 0 ? -params.input2_shift : -input1_shift; @@ -310,8 +311,8 @@ inline void Sub16(const ArithmeticParams& params, F0 scaled_input = F0::FromRaw( gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift)); F0 result = SaturatingSub(input_ready_scaled, scaled_input); - const int16 raw_output = result.raw(); - const int16 clamped_output = std::min( + const int16_t raw_output = result.raw(); + const int16_t clamped_output = std::min( output_activation_max, std::max(output_activation_min, raw_output)); output_data[i] = clamped_output; } @@ -323,8 +324,8 @@ inline void Sub16(const ArithmeticParams& params, F0 scaled_input = F0::FromRaw( gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift)); F0 result = SaturatingSub(scaled_input, input_ready_scaled); - const int16 raw_output = result.raw(); - const int16 clamped_output = std::min( + const int16_t raw_output = result.raw(); + const int16_t clamped_output = std::min( output_activation_max, std::max(output_activation_min, raw_output)); output_data[i] = clamped_output; } @@ -395,15 +396,15 @@ void Unpack(const UnpackParams& params, const RuntimeShape& input_shape, template void PackWithScaling(const PackParams& params, const RuntimeShape* const* input_shapes, - const uint8* const* input_data, - const RuntimeShape& output_shape, uint8* output_data) { + const uint8_t* const* input_data, + const RuntimeShape& output_shape, uint8_t* output_data) { ruy::profiler::ScopeLabel label("PackWithScaling"); const int dimensions = output_shape.DimensionsCount(); int axis = params.axis; - const int32* input_zeropoint = params.input_zeropoint; + const int32_t* input_zeropoint = params.input_zeropoint; const float* input_scale = params.input_scale; int inputs_count = params.inputs_count; - const int32 output_zeropoint = params.output_zeropoint; + const int32_t output_zeropoint = params.output_zeropoint; const float output_scale = params.output_scale; int outer_size = 1; @@ -599,7 +600,7 @@ inline GatherNdHelperResult GatherNdHelper(const RuntimeShape& params_shape, // Implements GatherNd. // Returns an error if any of the indices_data would cause an out of bounds // memory read. -template +template inline TfLiteStatus GatherNd(const RuntimeShape& params_shape, const ParamsT* params_data, const RuntimeShape& indices_shape, @@ -627,7 +628,7 @@ inline TfLiteStatus GatherNd(const RuntimeShape& params_shape, // Implements GatherNd on strings. // Returns an error if any of the indices_data would cause an out of bounds // memory read. -template +template inline TfLiteStatus GatherNdString(const RuntimeShape& params_shape, const TfLiteTensor* params_data, const RuntimeShape& indices_shape, diff --git a/tensorflow/lite/kernels/mul_test.cc b/tensorflow/lite/kernels/mul_test.cc index 34b484a4ca9c2c..f5f0d40da261f7 100644 --- a/tensorflow/lite/kernels/mul_test.cc +++ b/tensorflow/lite/kernels/mul_test.cc @@ -541,6 +541,64 @@ TEST_P(MulOpTest, Int32VariousInputShapes) { } } +// Neon intrinsics are only dispatched when tensor has at least 16 elements. 
+TEST_P(MulOpTest, Int32LargeInputShapeNoActivation) { + bool constant_tensors = GetParam(); + if (SingleOpModel::GetForceUseNnapi() && constant_tensors) { + // NNAPI does not support graphs with all constant inputs. + return; + } + const std::vector test_shape = {4, 4, 4, 4}; + constexpr int kFlatSize = 4 * 4 * 4 * 4; + + std::vector lhs_data(kFlatSize); + std::iota(lhs_data.begin(), lhs_data.end(), 0); + + std::vector rhs_data(kFlatSize); + std::iota(rhs_data.begin(), rhs_data.end(), 0); + + IntegerMulOpModel m( + {TensorType_INT32, test_shape}, {TensorType_INT32, test_shape}, + {TensorType_INT32, {}}, ActivationFunctionType_NONE, lhs_data, rhs_data, + constant_tensors); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + + const std::vector output = m.GetOutput(); + ASSERT_EQ(output.size(), kFlatSize); + for (int i = 0; i < kFlatSize; ++i) { + EXPECT_EQ(output[i], i * i); + } +} + +// Neon intrinsics are only dispatched when tensor has at least 16 elements. +TEST_P(MulOpTest, Int32LargeInputShapeRELU6) { + bool constant_tensors = GetParam(); + if (SingleOpModel::GetForceUseNnapi() && constant_tensors) { + // NNAPI does not support graphs with all constant inputs. + return; + } + const std::vector test_shape = {4, 4, 4, 4}; + constexpr int kFlatSize = 4 * 4 * 4 * 4; + + std::vector lhs_data(kFlatSize); + std::iota(lhs_data.begin(), lhs_data.end(), 0); + + std::vector rhs_data(kFlatSize); + std::iota(rhs_data.begin(), rhs_data.end(), 0); + + IntegerMulOpModel m( + {TensorType_INT32, test_shape}, {TensorType_INT32, test_shape}, + {TensorType_INT32, {}}, ActivationFunctionType_RELU6, lhs_data, rhs_data, + constant_tensors); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + + const std::vector output = m.GetOutput(); + ASSERT_EQ(output.size(), kFlatSize); + for (int i = 0; i < kFlatSize; ++i) { + EXPECT_EQ(output[i], std::min(i * i, 6)); + } +} + TEST_P(MulOpTest, Int32WithBroadcast) { bool constant_tensors = GetParam(); if (SingleOpModel::GetForceUseNnapi() && constant_tensors) { diff --git a/tensorflow/lite/kernels/register.h b/tensorflow/lite/kernels/register.h index 6721dc69a328bd..e444accec6511e 100644 --- a/tensorflow/lite/kernels/register.h +++ b/tensorflow/lite/kernels/register.h @@ -15,7 +15,9 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_KERNELS_REGISTER_H_ #define TENSORFLOW_LITE_KERNELS_REGISTER_H_ -#include "tensorflow/lite/core/kernels/register.h" +/// For documentation, see third_party/tensorflow/lite/core/kernels/register.h + +#include "tensorflow/lite/core/kernels/register.h" // IWYU pragma: export namespace tflite { namespace ops { diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc index 7db39695c13de6..7cf5f3710de2c5 100644 --- a/tensorflow/lite/kernels/register_ref.cc +++ b/tensorflow/lite/kernels/register_ref.cc @@ -194,6 +194,7 @@ TfLiteRegistration* Register_STABLEHLO_MULTIPLY(); TfLiteRegistration* Register_STABLEHLO_REDUCE_WINDOW(); TfLiteRegistration* Register_STABLEHLO_MAXIMUM(); TfLiteRegistration* Register_STABLEHLO_MINIMUM(); +TfLiteRegistration* Register_STABLEHLO_PAD(); namespace { @@ -491,7 +492,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N()); AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND(), /* min_version = */ 1, - /* max_version = */ 4); + /* max_version = */ 5); AddBuiltin(BuiltinOperator_WHERE, Register_WHERE(), /* min_version = */ 1, /* max_version = */ 2); AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE()); @@ -558,6 +559,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { AddBuiltin(BuiltinOperator_STABLEHLO_REDUCE_WINDOW, Register_STABLEHLO_REDUCE_WINDOW()); AddBuiltin(BuiltinOperator_STABLEHLO_GATHER, Register_STABLEHLO_GATHER()); + AddBuiltin(BuiltinOperator_STABLEHLO_PAD, Register_STABLEHLO_PAD()); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY_REF()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that diff --git a/tensorflow/lite/kernels/split.cc b/tensorflow/lite/kernels/split.cc index 1491f4bbb98823..83add14be0173e 100644 --- a/tensorflow/lite/kernels/split.cc +++ b/tensorflow/lite/kernels/split.cc @@ -87,7 +87,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || input_type == kTfLiteInt8 || input_type == kTfLiteInt16 || - input_type == kTfLiteInt32); + input_type == kTfLiteInt32 || input_type == kTfLiteInt64); for (int i = 0; i < NumOutputs(node); ++i) { TfLiteTensor* tensor; TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor)); @@ -158,6 +158,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_SPLIT(int32_t); break; } + case kTfLiteInt64: { + TF_LITE_SPLIT(int64_t); + break; + } default: TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.", TfLiteTypeGetName(op_context.input->type)); diff --git a/tensorflow/lite/kernels/stablehlo_pad.cc b/tensorflow/lite/kernels/stablehlo_pad.cc new file mode 100644 index 00000000000000..13f6b74eae8906 --- /dev/null +++ b/tensorflow/lite/kernels/stablehlo_pad.cc @@ -0,0 +1,291 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + // +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/core/c/builtin_op_data.h" +#include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace stablehlo_pad { +namespace { + +static constexpr int kMaxDims = 6; + +// Fills a buffer with the given data. +// +// WARNING: This expects buffer_bytes to be a multiple of data_bytes. +void FillBuffer(char* buffer, int64_t buffer_bytes, const char* data, + int64_t data_bytes) { + if (buffer_bytes == 0) { + return; + } + assert(buffer_bytes % data_bytes == 0); + std::memcpy(buffer, data, data_bytes); + buffer_bytes -= data_bytes; + while (buffer_bytes) { + const int64_t bytes = std::min(buffer_bytes, data_bytes); + std::memcpy(buffer + data_bytes, buffer, bytes); + buffer_bytes -= bytes; + data_bytes += bytes; + } +} + +// Recursive implementation of a strided copy of a tensor. +void StridedCopy(const int rank, const char* input, const int64_t* input_shape, + const int64_t* input_strides, char* output, + const int64_t* output_strides, const int64_t element_size, + const int depth) { + if (depth + 1 == rank) { + for (int64_t i = 0; i < input_shape[depth]; ++i) { + std::memcpy(output, input, element_size); + input += input_strides[depth]; + output += output_strides[depth]; + } + } else { + for (int64_t i = 0; i < input_shape[depth]; ++i) { + StridedCopy(rank, input, input_shape, input_strides, output, + output_strides, element_size, depth + 1); + input += input_strides[depth]; + output += output_strides[depth]; + } + } +} + +// Holds the main implementation of the Pad operation. +// +// The StableHLO pad operation can add interior padding and edge padding to a +// tensor. The edge padding may be negative in which case it is considered as a +// cropping specification. +// +// This is implemented as a strided copy where: +// +// - interior padding affects the output strides. +// - positive edge padding affects the output shape, strides and initial offset. +// - negative edge padding affects the input shape and initial offset as well as +// the output initial offset. +// +// See https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad for more +// information. +class PadData { + public: + enum { kInput, kPaddingValue, kInputTensorCount }; + enum { kOutput, kOutputTensorCount }; + + explicit PadData(const TfLiteStablehloPadParams& params) { + std::memcpy( + edge_pad_low_, params.edge_padding_low, + TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT * sizeof(int64_t)); + std::memcpy( + edge_pad_high_, params.edge_padding_high, + TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT * sizeof(int64_t)); + std::memcpy( + interior_pad_, params.interior_padding, + TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT * sizeof(int64_t)); + } + + // Computes the shapes and strides that are needed for the final strided copy. + void Setup(const int* dims, const int rank, const int64_t element_size) { + rank_ = rank; + element_size_ = element_size; + input_offset_ = 0; + output_offset_ = 0; + output_size_ = 0; + + // Compute the output shape. 
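    // Worked example (illustrative values) of the shape computation below and
    // of the stablehlo.pad semantics this class implements:
    //
    //   input dims = {3} -> [a, b, c], interior_pad = {1}, edge_pad_low = {1},
    //   edge_pad_high = {-1}, padding value p.
    //
    //   output dim = (3 - 1) * (1 + 1) + 1 + 1 + (-1) = 5
    //
    //   dilate:    [a, p, b, p, c]     (interior padding inserts p between elements)
    //   low pad:   [p, a, p, b, p, c]  (one leading p)
    //   high crop: [p, a, p, b, p]     (negative high padding drops one element)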
+ for (int i = 0; i < rank; ++i) { + output_shape_[i] = (dims[i] - 1) * (interior_pad_[i] + 1) + 1 + + edge_pad_low_[i] + edge_pad_high_[i]; + } + if (std::any_of(output_shape_, output_shape_ + rank, + [](auto s) { return s <= 0; })) { + std::memset(input_shape_, 0, sizeof(input_shape_)); + std::memset(output_shape_, 0, sizeof(output_shape_)); + output_size_ = 0; + return; + } + // Compute the output size for each dimension. + // + // This is different from the output strides because of the interior + // padding: the output strides take it into account to "jump" over the + // interior padding elements. + output_dimension_sizes_[rank - 1] = element_size; + for (int i = rank - 2; i >= 0; --i) { + output_dimension_sizes_[i] = + output_shape_[i + 1] * output_dimension_sizes_[i + 1]; + } + // Compute the output stride for each dimension. + // + // This is the stride between two elements that are copied from the input + // tensor (i.e. not generated by interior padding). + output_strides_[rank - 1] = element_size * (interior_pad_[rank - 1] + 1); + for (int i = rank - 2; i >= 0; --i) { + output_strides_[i] = output_dimension_sizes_[i] * (interior_pad_[i] + 1); + } + // Compute the output offset from the eventual pads. + for (int i = 0; i < rank; ++i) { + output_offset_ += + std::max(edge_pad_low_[i], 0) * output_dimension_sizes_[i]; + } + // Compute the final output size. + output_size_ = std::accumulate(output_shape_, output_shape_ + rank, + element_size, std::multiplies<>()); + // Compute input strides. + input_strides_[rank - 1] = element_size; + for (int i = rank - 1; i >= 1; --i) { + input_strides_[i - 1] = dims[i] * input_strides_[i]; + } + // Helper that computes the division between a negative num and a positive + // denum, rounding away from 0, or returns 0 if num is positive. + auto DivNegRoundAwayOrZero = [](int64_t num, int64_t denum) -> int64_t { + assert(denum > 0); + return num < 0 ? (num - denum + 1) / denum : 0; + }; + // Compute the input bounds from the eventual crops. + // + // If negative padding is applied, we can treat this as copying a subtensor + // of the input. We modify the input shape in place as we don't use it for + // anything else. + for (int i = 0; i < rank; ++i) { + input_shape_[i] = + dims[i] + + DivNegRoundAwayOrZero(edge_pad_low_[i], interior_pad_[i] + 1) + + DivNegRoundAwayOrZero(edge_pad_high_[i], interior_pad_[i] + 1); + } + // Compute the input offset from the eventual crops. + // + // When computing the subtensor from the negative padding, we need to find + // out the offset to its first element in addition to its shape (see + // previous comment). + // + // Cropping also means that the interior padding can become edge padding so + // we also need to update the output offset: + // + // > `1 0 0 0 2 0 0 0 3` cropped by 1 low element becomes `0 0 0 2 0 0 0 3` + // > which effectlvely means pad `2 3` with an interior padding of 3 and a + // > low edge padding of 3. + for (int i = 0; i < rank; ++i) { + input_offset_ -= + DivNegRoundAwayOrZero(edge_pad_low_[i], interior_pad_[i] + 1) * + input_strides_[i]; + if (edge_pad_low_[i] < 0) { + int64_t tmp_offset = ((interior_pad_[i] + 1 + edge_pad_low_[i]) % + (interior_pad_[i] + 1)); + if (tmp_offset < 0) { + tmp_offset += interior_pad_[i] + 1; + } + output_offset_ += tmp_offset * output_dimension_sizes_[i]; + } + } + } + + void Apply(const char* input, const char* padding_value, char* output) const { + // Fill the output tensor with the padding value. 
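    // Note on the FillBuffer call below (illustrative sizes): the helper
    // doubles the initialized prefix on every iteration, so initializing the
    // whole output from a single element takes O(log2(output_size /
    // element_size)) memcpy calls. For a 4-byte padding value and a 32-byte
    // output it copies 4 bytes, then 4, 8 and 16 bytes.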
+ FillBuffer(output, output_size_, padding_value, element_size_); + StridedCopy(rank_, input + input_offset_, input_shape_, input_strides_, + output + output_offset_, output_strides_, element_size_, + /*depth=*/0); + } + + TfLiteIntArray* BuildOuputTensorDims() const { + TfLiteIntArray* dims = TfLiteIntArrayCreate(rank_); + for (int64_t i = 0; i < rank_; ++i) { + dims->data[i] = output_shape_[i]; + } + return dims; + } + + private: + int64_t edge_pad_low_[kMaxDims]; + int64_t edge_pad_high_[kMaxDims]; + int64_t interior_pad_[kMaxDims]; + int64_t rank_ = 0; + int64_t element_size_ = 0; + int64_t input_shape_[kMaxDims]; + int64_t output_shape_[kMaxDims]; + int64_t input_strides_[kMaxDims]; + int64_t output_strides_[kMaxDims]; + int64_t output_dimension_sizes_[kMaxDims]; + int64_t input_offset_ = 0; + int64_t output_offset_ = 0; + int64_t output_size_ = 0; +}; + +void* Init(TfLiteContext* context, const char* options, size_t options_len) { + return new PadData( + *reinterpret_cast(options)); +} + +void Free(TfLiteContext* context, void* node_data) { + delete reinterpret_cast(node_data); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Input checks. + const TfLiteTensor* input_tensor = GetInput(context, node, PadData::kInput); + const TfLiteTensor* padding_value_tensor = + GetInput(context, node, PadData::kPaddingValue); + TF_LITE_ENSURE(context, input_tensor->type == padding_value_tensor->type); + // PadData computations. + size_t element_size; + TF_LITE_ENSURE(context, GetSizeOfType(context, input_tensor->type, + &element_size) == kTfLiteOk); + PadData& pad_data = *reinterpret_cast(node->user_data); + pad_data.Setup(input_tensor->dims->data, input_tensor->dims->size, + element_size); + // Output tensor setup. + TfLiteTensor* output_tensor = GetOutput(context, node, PadData::kOutput); + TF_LITE_ENSURE(context, input_tensor->type == output_tensor->type); + context->ResizeTensor(context, output_tensor, + pad_data.BuildOuputTensorDims()); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input_tensor = GetInput(context, node, PadData::kInput); + const TfLiteTensor* padding_value_tensor = + GetInput(context, node, PadData::kPaddingValue); + TfLiteTensor* output_tensor = GetOutput(context, node, PadData::kOutput); + // Pad using PadData + PadData& pad_data = *reinterpret_cast(node->user_data); + pad_data.Apply(input_tensor->data.raw_const, + padding_value_tensor->data.raw_const, output_tensor->data.raw); + return kTfLiteOk; +} + +} // namespace +} // namespace stablehlo_pad + +TfLiteRegistration* Register_STABLEHLO_PAD() { + static TfLiteRegistration r = {/*.init=*/stablehlo_pad::Init, + /*.free=*/stablehlo_pad::Free, + /*.prepare=*/stablehlo_pad::Prepare, + /*.invoke=*/stablehlo_pad::Eval}; + return &r; +} +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/stablehlo_pad_test.cc b/tensorflow/lite/kernels/stablehlo_pad_test.cc new file mode 100644 index 00000000000000..f7a3aede43d40e --- /dev/null +++ b/tensorflow/lite/kernels/stablehlo_pad_test.cc @@ -0,0 +1,471 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + // +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// #include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/algorithm/container.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/random.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/stablehlo_reduce_window_test_util.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace stablehlo_pad { +namespace { + +using testing::ElementsAre; +using testing::ElementsAreArray; +using testing::HasSubstr; + +template +class StablehloPadModel : public SingleOpModel { + public: + static constexpr TensorType kTensorType = GetTensorType(); + + void SetEdgePadding(std::vector low, std::vector high) { + edge_padding_low_ = std::move(low); + edge_padding_high_ = std::move(high); + } + + const std::vector& GetEdgePaddingLow() const { + return edge_padding_low_; + } + + const std::vector& GetEdgePaddingHigh() const { + return edge_padding_high_; + } + + void SetInteriorPadding(std::vector padding) { + interior_padding_ = std::move(padding); + } + + const std::vector& GetInteriorPadding() const { + return interior_padding_; + } + + void SetInput(std::vector shape) { + input_.shape = shape; + input_.data.resize(absl::c_accumulate(shape, 1, std::multiplies<>())); + absl::c_iota(input_.data, static_cast(1)); + } + + void SetInput(std::vector shape, std::vector data) { + input_.shape = shape; + input_.data = data; + } + + void SetInput(absl::Span shape, absl::BitGenRef bitgen, T min, + T max) { + input_.shape.assign(shape.begin(), shape.end()); + input_.data.resize(absl::c_accumulate(shape, 1, std::multiplies<>())); + absl::c_generate(input_.data, [&] { + return absl::Uniform(absl::IntervalClosed, bitgen, min, max); + }); + } + + const reduce_window::reference::Tensor& GetInput() const { return input_; } + + void SetPaddingValue(const T& v) { padding_value_ = v; } + + T GetPaddingValue() const { return padding_value_; } + + absl::Span GetOutputData() { + return absl::Span(interpreter_->typed_tensor(output_tensor_id_), + GetTensorSize(output_tensor_id_)); + } + + absl::Span GetOutputShape() { + const TfLiteIntArray& shape = + *(interpreter_->tensor(output_tensor_id_)->dims); + return absl::Span(shape.data, shape.size); + } + + absl::Status CheckPreconditions() { + const size_t rank = input_.shape.size(); + if (rank == 0) { + return absl::FailedPreconditionError("Input rank is 0."); + } + if (edge_padding_low_.empty()) { + edge_padding_low_ = std::vector(rank, 0); + } else if (edge_padding_low_.size() != rank) { + return absl::FailedPreconditionError( + "Low edge padding does not have the right size."); + } + if (edge_padding_high_.empty()) { + edge_padding_high_ = std::vector(rank, 0); + } else if (edge_padding_high_.size() != rank) { + return 
absl::FailedPreconditionError( + "High edge padding does not have the right size."); + } + if (interior_padding_.empty()) { + interior_padding_ = std::vector(rank, 0); + } else if (interior_padding_.size() != rank) { + return absl::FailedPreconditionError( + "Interior padding does not have the right size."); + } + return absl::OkStatus(); + } + + absl::Status Build() { + if (absl::Status status = CheckPreconditions(); !status.ok()) { + return status; + } + input_tensor_id_ = + AddInput({kTensorType, + std::vector(input_.shape.begin(), input_.shape.end())}); + padding_value_tensor_id_ = + AddConstInput(kTensorType, /*data=*/{padding_value_}, /*shape=*/{1}); + output_tensor_id_ = AddOutput(kTensorType); + + SetBuiltinOp(BuiltinOperator_STABLEHLO_PAD, + BuiltinOptions2_StablehloPadOptions, + CreateStablehloPadOptions( + builder_, builder_.CreateVector(edge_padding_low_), + builder_.CreateVector(edge_padding_high_), + builder_.CreateVector(interior_padding_)) + .Union()); + BuildInterpreter( + /*input_shapes=*/{std::vector(input_.shape.begin(), + input_.shape.end())}, + /*num_threads=*/-1, /*allow_fp32_relax_to_fp16=*/false, + /*apply_delegate=*/true, /*allocate_and_delegate=*/false, + /*use_simple_allocator=*/false); + AllocateAndDelegate(/*apply_delegate=*/true); + PopulateTensor(input_tensor_id_, input_.data); + return absl::OkStatus(); + } + + absl::Status BuildAndInvoke() { + if (absl::Status status = Build(); !status.ok()) { + return status; + } + if (TfLiteStatus status = Invoke(); status != kTfLiteOk) { + const std::string msg = + absl::StrFormat("Invoke failed with status %d.", status); + return absl::InternalError(msg); + } + return absl::OkStatus(); + } + + friend std::ostream& operator<<(std::ostream& os, + const StablehloPadModel& model) { + auto print_vec = [&os](const auto& vec) { + os << "["; + if (!vec.empty()) { + auto it = vec.begin(); + os << +*(it++); + for (; it != vec.end(); ++it) { + os << ", " << +*it; + } + } + os << "]"; + }; + os << " edge_padding_low: "; + print_vec(model.GetEdgePaddingLow()); + os << "\n edge_padding_high: "; + print_vec(model.GetEdgePaddingHigh()); + os << "\n interior_padding: "; + print_vec(model.GetInteriorPadding()); + os << "\n padding_value: " << +model.GetPaddingValue(); + os << "\n input shape: "; + print_vec(model.GetInput().shape); + return os; + } + + private: + std::vector edge_padding_low_; + std::vector edge_padding_high_; + std::vector interior_padding_; + reduce_window::reference::Tensor input_; + T padding_value_ = 0; + + int input_tensor_id_; + int padding_value_tensor_id_; + int output_tensor_id_; +}; + +template +absl::StatusOr> ComputeReference( + StablehloPadModel& model) { + if (absl::Status status = model.CheckPreconditions(); !status.ok()) { + return status; + } + std::vector dilations, padding; + for (size_t i = 0; i < model.GetInput().shape.size(); ++i) { + padding.push_back(model.GetEdgePaddingLow()[i]); + padding.push_back(model.GetEdgePaddingHigh()[i]); + dilations.push_back(model.GetInteriorPadding()[i] + 1); + } + + auto dilated_tensor = reduce_window::reference::Dilate( + model.GetInput(), dilations, model.GetPaddingValue()); + auto padded_tensor = reduce_window::reference::Pad(dilated_tensor, padding, + model.GetPaddingValue()); + return reduce_window::reference::Crop(padded_tensor, padding); +} + +TEST(StablehloPadModelTest, DefaultModelFails) { + StablehloPadModel model; + const auto expected_status = ComputeReference(model); + EXPECT_FALSE(expected_status.ok()); + EXPECT_EQ(expected_status.status().code(), + 
absl::StatusCode::kFailedPrecondition); + EXPECT_THAT(expected_status.status().message(), + HasSubstr("Input rank is 0.")); +} + +TEST(StablehloPadModelTest, DefaultModelReturnsIdentity) { + StablehloPadModel model; + model.SetInput({3, 1}); + EXPECT_THAT(model.GetInput().shape, ElementsAre(3, 1)); + const auto expected_status = ComputeReference(model); + ASSERT_TRUE(expected_status.ok()); + EXPECT_THAT(expected_status.value().data, + ElementsAreArray(model.GetInput().data)); +} + +TEST(StablehloPadModelTest, WrongEdgePaddingSizeIsAnError) { + StablehloPadModel model; + model.SetInput({3, 1}); + model.SetEdgePadding(/*low=*/{3, 4, 5}, /*high=*/{6, 7}); + { + const auto expected_status = ComputeReference(model); + EXPECT_FALSE(expected_status.ok()); + EXPECT_EQ(expected_status.status().code(), + absl::StatusCode::kFailedPrecondition); + EXPECT_THAT(expected_status.status().message(), + HasSubstr("Low edge padding does not have the right size.")); + } + model.SetEdgePadding(/*low=*/{3, 4}, /*high=*/{5, 6, 7}); + { + const auto expected_status = ComputeReference(model); + EXPECT_FALSE(expected_status.ok()); + EXPECT_EQ(expected_status.status().code(), + absl::StatusCode::kFailedPrecondition); + EXPECT_THAT(expected_status.status().message(), + HasSubstr("High edge padding does not have the right size.")); + } +} + +TEST(StablehloPadModelTest, WrongInteriorPaddingSizeIsAnError) { + StablehloPadModel model; + model.SetInput({3, 1}); + model.SetInteriorPadding({3, 4, 5}); + const auto expected_status = ComputeReference(model); + EXPECT_FALSE(expected_status.ok()); + EXPECT_EQ(expected_status.status().code(), + absl::StatusCode::kFailedPrecondition); + EXPECT_THAT(expected_status.status().message(), + HasSubstr("Interior padding does not have the right size.")); +} + +TEST(StablehloPadTest, IdentityParams) { + StablehloPadModel model; + model.SetInput({3, 3}); + ASSERT_TRUE(model.BuildAndInvoke().ok()); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(model.GetInput().shape)); + EXPECT_THAT(model.GetOutputData(), ElementsAreArray(model.GetInput().data)); +} + +TEST(StablehloPadTest, InteriorPad) { + StablehloPadModel model; + model.SetInput({3, 3}); + model.SetInteriorPadding({1, 2}); + const auto expected_status = ComputeReference(model); + ASSERT_TRUE(expected_status.ok()); + const auto& expected = expected_status.value(); + ASSERT_TRUE(model.BuildAndInvoke().ok()); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected.shape)); + EXPECT_THAT(model.GetOutputData(), ElementsAreArray(expected.data)); +} + +TEST(StablehloPadTest, LowPad) { + StablehloPadModel model; + model.SetInput({3, 3}); + model.SetEdgePadding({1, 1}, {0, 0}); + const auto expected_status = ComputeReference(model); + ASSERT_TRUE(expected_status.ok()); + const auto& expected = expected_status.value(); + ASSERT_TRUE(model.BuildAndInvoke().ok()); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected.shape)); + EXPECT_THAT(model.GetOutputData(), ElementsAreArray(expected.data)); +} + +TEST(StablehloPadTest, HighPad) { + StablehloPadModel model; + model.SetInput({3, 3}); + model.SetEdgePadding({0, 0}, {1, 1}); + const auto expected_status = ComputeReference(model); + ASSERT_TRUE(expected_status.ok()); + const auto& expected = expected_status.value(); + ASSERT_TRUE(model.BuildAndInvoke().ok()); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected.shape)); + EXPECT_THAT(model.GetOutputData(), ElementsAreArray(expected.data)); +} + +TEST(StablehloPadTest, AllPad) { + StablehloPadModel model; + 
model.SetInput({3, 3}); + model.SetEdgePadding({1, 1}, {1, 1}); + const auto expected_status = ComputeReference(model); + ASSERT_TRUE(expected_status.ok()); + const auto& expected = expected_status.value(); + ASSERT_TRUE(model.BuildAndInvoke().ok()); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected.shape)); + EXPECT_THAT(model.GetOutputData(), ElementsAreArray(expected.data)); +} + +TEST(StablehloPadTest, LowCrop) { + StablehloPadModel model; + model.SetInput({3, 3}); + model.SetEdgePadding({-1, -1}, {0, 0}); + const auto expected_status = ComputeReference(model); + ASSERT_TRUE(expected_status.ok()); + const auto& expected = expected_status.value(); + ASSERT_TRUE(model.BuildAndInvoke().ok()); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected.shape)); + EXPECT_THAT(model.GetOutputData(), ElementsAreArray(expected.data)); +} + +TEST(StablehloPadTest, HighCrop) { + StablehloPadModel model; + model.SetInput({3, 3}); + model.SetEdgePadding({0, 0}, {-1, -1}); + const auto expected_status = ComputeReference(model); + ASSERT_TRUE(expected_status.ok()); + const auto& expected = expected_status.value(); + ASSERT_TRUE(model.BuildAndInvoke().ok()); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected.shape)); + EXPECT_THAT(model.GetOutputData(), ElementsAreArray(expected.data)); +} + +TEST(StablehloPadTest, AllCrop) { + StablehloPadModel model; + model.SetInput({3, 3}); + model.SetEdgePadding({-1, -1}, {-1, -1}); + const auto expected_status = ComputeReference(model); + ASSERT_TRUE(expected_status.ok()); + const auto& expected = expected_status.value(); + ASSERT_TRUE(model.BuildAndInvoke().ok()); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected.shape)); + EXPECT_THAT(model.GetOutputData(), ElementsAreArray(expected.data)); +} + +TEST(StablehloPadTest, PadCrop) { + StablehloPadModel model; + model.SetInput({3, 3}); + model.SetEdgePadding({1, -1}, {1, -1}); + const auto expected_status = ComputeReference(model); + ASSERT_TRUE(expected_status.ok()); + const auto& expected = expected_status.value(); + ASSERT_TRUE(model.BuildAndInvoke().ok()); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected.shape)); + EXPECT_THAT(model.GetOutputData(), ElementsAreArray(expected.data)); +} + +TEST(StablehloPadTest, InteriorEdgePadding) { + StablehloPadModel model; + model.SetInput({3, 3}); + model.SetEdgePadding({-1, -4}, {0, 0}); + model.SetInteriorPadding({1, 2}); + const auto expected_status = ComputeReference(model); + ASSERT_TRUE(expected_status.ok()); + const auto& expected = expected_status.value(); + ASSERT_TRUE(model.BuildAndInvoke().ok()); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected.shape)); + EXPECT_THAT(model.GetOutputData(), ElementsAreArray(expected.data)); +} + +TEST(StablehloPadTest, CallPrepareTwiceDoesNotFail) { + StablehloPadModel model; + model.SetInput({3, 3}); + model.SetEdgePadding({-1, -4}, {0, 0}); + model.SetInteriorPadding({1, 2}); + const auto expected_status = ComputeReference(model); + ASSERT_TRUE(expected_status.ok()); + const auto& expected = expected_status.value(); + // Applying delegates forces Prepare to be called twice. + model.SetApplyDefaultDelegates(); + ASSERT_TRUE(model.BuildAndInvoke().ok()); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected.shape)); + EXPECT_THAT(model.GetOutputData(), ElementsAreArray(expected.data)); +} + +// Returns a vector of given size with elements in the range [min, max]. 
+template +std::vector RandomVector(absl::BitGen& bitgen, size_t size, T min, T max) { + std::vector vec(size); + for (T& v : vec) { + v = absl::Uniform(absl::IntervalClosed, bitgen, min, max); + } + return vec; +} + +template +class StablehloPadFuzzyTest : public testing::Test {}; + +using TestList = + testing::Types; +TYPED_TEST_SUITE(StablehloPadFuzzyTest, TestList); + +TYPED_TEST(StablehloPadFuzzyTest, FuzzyTest) { + absl::BitGen bitgen; + + for (size_t iteration = 0; iteration < 10000; ++iteration) { + const int rank = absl::Uniform(absl::IntervalClosed, bitgen, 1, 2); + + StablehloPadModel model; + model.SetInput( + /*shape=*/RandomVector(bitgen, rank, /*min=*/1, /*max=*/3), + bitgen, /*min=*/-5, /*max=*/5); + model.SetInteriorPadding( + RandomVector(bitgen, rank, /*min=*/0, /*max=*/2)); + model.SetEdgePadding( + RandomVector(bitgen, rank, /*min=*/-5, /*max=*/5), + RandomVector(bitgen, rank, /*min=*/-5, /*max=*/5)); + model.SetPaddingValue( + absl::Uniform(absl::IntervalClosed, bitgen, -127, 127)); + + const auto expected_status = ComputeReference(model); + ASSERT_TRUE(expected_status.ok()); + const auto& expected = expected_status.value(); + ASSERT_TRUE(model.BuildAndInvoke().ok()); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected.shape)) + << model; + EXPECT_THAT(model.GetOutputData(), ElementsAreArray(expected.data)) + << model; + } +} + +} // namespace +} // namespace stablehlo_pad +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/stablehlo_reduce_window.cc b/tensorflow/lite/kernels/stablehlo_reduce_window.cc index 32bf358239bc85..78385506c8d768 100644 --- a/tensorflow/lite/kernels/stablehlo_reduce_window.cc +++ b/tensorflow/lite/kernels/stablehlo_reduce_window.cc @@ -202,7 +202,7 @@ void Dilate(const DilateData& ctx, const char* input, const char* init_value, // Fill the output tensor with the padding value. { std::memcpy(output, init_value, ctx.init_element_size); - int64_t remaining_bytes = ctx.output_size; + int64_t remaining_bytes = ctx.output_size - ctx.init_element_size; int64_t copied_bytes = ctx.init_element_size; while (remaining_bytes) { int64_t bytes = std::min(remaining_bytes, copied_bytes); diff --git a/tensorflow/lite/kernels/stablehlo_reduce_window_test.cc b/tensorflow/lite/kernels/stablehlo_reduce_window_test.cc index a26c286ea350e9..fa95ac51075738 100644 --- a/tensorflow/lite/kernels/stablehlo_reduce_window_test.cc +++ b/tensorflow/lite/kernels/stablehlo_reduce_window_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -688,10 +689,17 @@ std::vector RandomVector(absl::BitGen& bitgen, size_t size, T min, T max) { } struct Body { - static Body GetRandomSupported(absl::BitGen& bitgen) { - return Body{/*.body=*/static_cast(absl::Uniform( + static Body GetRandomSupported(absl::BitGen& bitgen, bool allow_mul) { + Body b; + b = Body{/*.body=*/static_cast(absl::Uniform( absl::IntervalClosed, bitgen, static_cast(BodyFunction::kAdd), static_cast(BodyFunction::kAny)))}; + // This skews the uniformity of the random generation in favor of add. We + // only need to ensure that all the cases are tested. 
+ if (!allow_mul && b.func == BodyFunction::kMul) { + b.func = BodyFunction::kAdd; + } + return b; } template @@ -746,7 +754,9 @@ TYPED_TEST(StablehloReduceWindowTest, FuzzyTest) { const int rank = absl::Uniform(absl::IntervalClosed, bitgen, 1, 3); ReduceWindowOpModel model; - Body body = Body::GetRandomSupported(bitgen); + // To avoid reduction overflows, we only test mul with floating point types. + Body body = Body::GetRandomSupported( + bitgen, /*allow_mul=*/std::is_floating_point::value); model.SetInput( /*shape=*/RandomVector(bitgen, rank, /*min=*/1, /*max=*/10), bitgen, /*min=*/-5, /*max=*/5); diff --git a/tensorflow/lite/kernels/stablehlo_reduce_window_test_util.h b/tensorflow/lite/kernels/stablehlo_reduce_window_test_util.h index c514587a394014..d5cd7cc640e4de 100644 --- a/tensorflow/lite/kernels/stablehlo_reduce_window_test_util.h +++ b/tensorflow/lite/kernels/stablehlo_reduce_window_test_util.h @@ -91,6 +91,9 @@ inline std::vector DilateShape(std::vector shape, for (size_t i = 0; i < shape.size(); ++i) { shape[i] = (shape[i] - 1) * dilations[i] + 1; } + if (absl::c_any_of(shape, [](auto s) { return s <= 0; })) { + absl::c_fill(shape, 0); + } return shape; } @@ -100,6 +103,10 @@ Tensor Dilate(const Tensor& input, const std::vector& dilations, Tensor output = Tensor::FromShape(DilateShape(input.shape, dilations), padding_value); + if (absl::c_all_of(output.shape, [](auto s) { return s == 0; })) { + return output; + } + const std::vector strides = input.Strides(); const std::vector output_strides = output.Strides(); const std::vector safe_dilations = ExtendToMaxDim(dilations); @@ -142,6 +149,9 @@ inline std::vector PadCropShape(std::vector shape, for (size_t i = 0; i < shape.size(); ++i) { shape[i] = shape[i] + padding[2 * i] + padding[2 * i + 1]; } + if (absl::c_any_of(shape, [](auto s) { return s <= 0; })) { + absl::c_fill(shape, 0); + } return shape; } @@ -160,6 +170,10 @@ Tensor Pad(const Tensor& input, const std::vector& padding, Tensor output = Tensor::FromShape( PadCropShape(input.shape, safe_padding), padding_value); + if (absl::c_all_of(output.shape, [](auto s) { return s == 0; })) { + return output; + } + const std::vector strides = input.Strides(); const std::vector output_strides = output.Strides(); const std::vector safe_input_shape = ExtendToMaxDim(input.shape); @@ -209,6 +223,10 @@ Tensor Crop(const Tensor& input, const std::vector& cropping) { Tensor output = Tensor::FromShape(PadCropShape(input.shape, safe_cropping)); + if (absl::c_all_of(output.shape, [](auto s) { return s == 0; })) { + return output; + } + const std::vector strides = input.Strides(); const std::vector output_strides = output.Strides(); const std::vector safe_output_shape = ExtendToMaxDim(output.shape); diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index 78d4fe18e39acb..4781eae5dab108 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -1100,24 +1100,38 @@ class SingleOpTest : public ::testing::TestWithParam { } }; +// Maps the native C++ types to the corresponding TFLite tensor type enum +// values. 
+template +struct TensorTypeFor; + +#define TFLITE_TENSOR_TYPE_ASSOC(CPP_TYPE, TENSORTYPE_VALUE) \ + template <> \ + struct TensorTypeFor { \ + static constexpr TensorType value = TENSORTYPE_VALUE; \ + }; + +TFLITE_TENSOR_TYPE_ASSOC(bool, TensorType_BOOL); +TFLITE_TENSOR_TYPE_ASSOC(int8_t, TensorType_INT8); +TFLITE_TENSOR_TYPE_ASSOC(int16_t, TensorType_INT16); +TFLITE_TENSOR_TYPE_ASSOC(int32_t, TensorType_INT32); +TFLITE_TENSOR_TYPE_ASSOC(int64_t, TensorType_INT64); +TFLITE_TENSOR_TYPE_ASSOC(uint8_t, TensorType_UINT8); +TFLITE_TENSOR_TYPE_ASSOC(uint16_t, TensorType_UINT16); +TFLITE_TENSOR_TYPE_ASSOC(uint32_t, TensorType_UINT32); +TFLITE_TENSOR_TYPE_ASSOC(uint64_t, TensorType_UINT64); +TFLITE_TENSOR_TYPE_ASSOC(TfLiteFloat16, TensorType_FLOAT16); +TFLITE_TENSOR_TYPE_ASSOC(Eigen::half, TensorType_FLOAT16); +TFLITE_TENSOR_TYPE_ASSOC(float, TensorType_FLOAT32); +TFLITE_TENSOR_TYPE_ASSOC(double, TensorType_FLOAT64); +TFLITE_TENSOR_TYPE_ASSOC(std::string, TensorType_STRING); + +#undef TFLITE_TENSOR_TYPE_ASSOC + // Returns the corresponding TensorType given the type T. template -TensorType GetTensorType() { - if (std::is_same::value) return TensorType_FLOAT32; - if (std::is_same::value) return TensorType_FLOAT16; - if (std::is_same::value) return TensorType_FLOAT16; - if (std::is_same::value) return TensorType_FLOAT64; - if (std::is_same::value) return TensorType_INT8; - if (std::is_same::value) return TensorType_INT16; - if (std::is_same::value) return TensorType_UINT16; - if (std::is_same::value) return TensorType_INT32; - if (std::is_same::value) return TensorType_UINT32; - if (std::is_same::value) return TensorType_INT64; - if (std::is_same::value) return TensorType_UINT64; - if (std::is_same::value) return TensorType_UINT8; - if (std::is_same::value) return TensorType_STRING; - if (std::is_same::value) return TensorType_BOOL; - return TensorType_MIN; // default value +constexpr TensorType GetTensorType() { + return TensorTypeFor::value; } // Strings have a special implementation that is in test_util.cc diff --git a/tensorflow/lite/kernels/variants/BUILD b/tensorflow/lite/kernels/variants/BUILD index 13cb9d1567b297..3a806135fa7a50 100644 --- a/tensorflow/lite/kernels/variants/BUILD +++ b/tensorflow/lite/kernels/variants/BUILD @@ -179,6 +179,24 @@ cc_test( ], ) +cc_test( + name = "list_push_back_test", + srcs = ["list_kernels/list_push_back_test.cc"], + deps = [ + ":list_ops_lib", + ":tensor_array", + ":test_util", + "//tensorflow/lite:type_to_tflitetype", + "//tensorflow/lite/core/c:c_api_types", + "//tensorflow/lite/core/c:common", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/kernels/internal:compatibility", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "variant_add_n_test", srcs = ["list_kernels/variant_add_n_test.cc"], @@ -198,13 +216,12 @@ cc_test( ) cc_test( - name = "list_push_back_test", - srcs = ["list_kernels/list_push_back_test.cc"], + name = "variant_zeros_like_test", + srcs = ["list_kernels/variant_zeros_like_test.cc"], deps = [ ":list_ops_lib", ":tensor_array", ":test_util", - "//tensorflow/lite:type_to_tflitetype", "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/core/c:common", "//tensorflow/lite/kernels:kernel_util", diff --git a/tensorflow/lite/kernels/variants/list_kernels/list_reserve.cc b/tensorflow/lite/kernels/variants/list_kernels/list_reserve.cc index 094bf38104caa6..7637a4064a5451 100644 --- 
a/tensorflow/lite/kernels/variants/list_kernels/list_reserve.cc +++ b/tensorflow/lite/kernels/variants/list_kernels/list_reserve.cc @@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include "tensorflow/lite/array.h" #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -21,6 +23,7 @@ limitations under the License. #include "tensorflow/lite/kernels/variants/list_ops_util.h" #include "tensorflow/lite/kernels/variants/tensor_array.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/util.h" namespace tflite { namespace variants { @@ -46,23 +49,132 @@ TfLiteType ConvertTensorType(TensorType src) { } } -constexpr int kElementShapeInput = 0; -constexpr int kNumElementsInput = 1; constexpr int kListOut = 0; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); +struct SemanticOutType { + TfLiteType element_type; + IntArrayUniquePtr element_shape; + int num_elements; +}; + +class ReserveSemantic { + public: + ReserveSemantic(TfLiteContext* context, TfLiteNode* node) + : context_(context), node_(node) {} + + constexpr static int kElementShapeInput = 0; + constexpr static int kNumElementsInput = 1; + + TfLiteStatus CheckInputs() const { + TF_LITE_ENSURE_EQ(context_, NumInputs(node_), 2); + const TfLiteTensor* element_shape; + TF_LITE_ENSURE_OK( + context_, + GetInputSafe(context_, node_, kElementShapeInput, &element_shape)); + TF_LITE_ENSURE(context_, element_shape->type == kTfLiteInt32); + const TfLiteTensor* num_elements; + TF_LITE_ENSURE_OK(context_, GetInputSafe(context_, node_, kNumElementsInput, + &num_elements)); + TF_LITE_ENSURE_TYPES_EQ(context_, num_elements->type, kTfLiteInt32); + return kTfLiteOk; + } + + TfLiteStatus Compute(SemanticOutType& result) const { + // Parse element type from custom options. + auto* options = + reinterpret_cast(node_->custom_initial_data); + TfLiteType element_type = ConvertTensorType(options->element_type); + TF_LITE_ENSURE(context_, element_type != kTfLiteNoType); + + const TfLiteTensor* num_elements; + TF_LITE_ENSURE_OK(context_, GetInputSafe(context_, node_, kNumElementsInput, + &num_elements)); + TF_LITE_ENSURE_TYPES_EQ(context_, num_elements->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context_, num_elements->dims->size, 0); + const int num_elements_value = num_elements->data.i32[0]; + TF_LITE_ENSURE(context_, num_elements_value >= 0); + + // Create int array representing constraint on list's constituent elements. 
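+    // (This is the kElementShapeInput tensor whose type was validated in
+    // CheckInputs above.)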
+ const TfLiteTensor* element_shape_tensor; + TF_LITE_ENSURE_OK(context_, + GetInputSafe(context_, node_, kElementShapeInput, + &element_shape_tensor)); + IntArrayUniquePtr element_shape = TensorAsShape(*element_shape_tensor); + + result = SemanticOutType{element_type, std::move(element_shape), + num_elements_value}; + return kTfLiteOk; + } - const TfLiteTensor* element_shape; - TF_LITE_ENSURE_OK( - context, GetInputSafe(context, node, kElementShapeInput, &element_shape)); - TF_LITE_ENSURE(context, element_shape->type == kTfLiteInt32); + TfLiteStatus PopulateOutput(TensorArray* const output) const { + return kTfLiteOk; + } + + private: + TfLiteContext* const context_; + TfLiteNode* const node_; +}; + +class ZerosLikeSemantic { + public: + ZerosLikeSemantic(TfLiteContext* context, TfLiteNode* node) + : context_(context), node_(node) {} + + constexpr static int kListInput = 0; + + TfLiteStatus CheckInputs() const { + TF_LITE_ENSURE_EQ(context_, NumInputs(node_), 1); + const TfLiteTensor* list_input; + TF_LITE_ENSURE_OK(context_, + GetInputSafe(context_, node_, kListInput, &list_input)); + TF_LITE_ENSURE(context_, list_input->type == kTfLiteVariant); + return kTfLiteOk; + } - const TfLiteTensor* num_elements; - TF_LITE_ENSURE_OK( - context, GetInputSafe(context, node, kNumElementsInput, &num_elements)); - TF_LITE_ENSURE_TYPES_EQ(context, num_elements->type, kTfLiteInt32); + TfLiteStatus Compute(SemanticOutType& result) const { + const TfLiteTensor* list_input; + TF_LITE_ENSURE_OK(context_, + GetInputSafe(context_, node_, kListInput, &list_input)); + const TensorArray* const input = + reinterpret_cast(list_input->data.data); + + result = SemanticOutType{input->ElementType(), + BuildTfLiteArray(*input->ElementShape()), + input->NumElements()}; + return kTfLiteOk; + } + TfLiteStatus PopulateOutput(TensorArray* const output) const { + const TfLiteTensor* list_input; + TF_LITE_ENSURE_OK(context_, + GetInputSafe(context_, node_, kListInput, &list_input)); + const TensorArray* const input = + reinterpret_cast(list_input->data.data); + for (int i = 0; i < input->NumElements(); ++i) { + const TfLiteTensor* const at = input->At(i); + if (at == nullptr) continue; + // Tensorflow supports lazy allocation in this case which is not possible + // with tflite tensors. If this proves to be a performance bottleneck we + // can investigate storing more info in TensorArray putting off allocation + // for later. + TensorUniquePtr output_at = BuildTfLiteTensor( + at->type, BuildTfLiteArray(*at->dims), kTfLiteDynamic); + memset(output_at->data.data, 0, output_at->bytes); + TF_LITE_ENSURE(context_, output->Set(i, std::move(output_at))); + } + return kTfLiteOk; + } + + private: + TfLiteContext* const context_; + TfLiteNode* const node_; +}; + +template +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const Semantic sem(context, node); + TF_LITE_ENSURE_OK(context, sem.CheckInputs()); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); TfLiteTensor* output; TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kListOut, &output)); TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteVariant); @@ -70,40 +182,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - // Parse element type from custom options. 
- auto* options = - reinterpret_cast(node->custom_initial_data); - TfLiteType element_type = ConvertTensorType(options->element_type); - TF_LITE_ENSURE(context, element_type != kTfLiteNoType); - - const TfLiteTensor* num_elements; - TF_LITE_ENSURE_OK( - context, GetInputSafe(context, node, kNumElementsInput, &num_elements)); - TF_LITE_ENSURE_TYPES_EQ(context, num_elements->type, kTfLiteInt32); - TF_LITE_ENSURE_EQ(context, num_elements->dims->size, 0); - const int num_elements_value = num_elements->data.i32[0]; - TF_LITE_ENSURE(context, num_elements_value >= 0); - - // Create int array representing constraint on list's constituent elements. - const TfLiteTensor* element_shape_tensor; - TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kElementShapeInput, - &element_shape_tensor)); - IntArrayUniquePtr element_shape = TensorAsShape(*element_shape_tensor); + const Semantic sem(context, node); + SemanticOutType data; + TF_LITE_ENSURE_OK(context, sem.Compute(data)); TfLiteTensor* output; TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kListOut, &output)); // Construct new `TensorArray` underneath the output tensor. - TfLiteStatus stat = - TfLiteTensorVariantRealloc( - output, std::move(element_type), std::move(element_shape)); + TfLiteStatus stat = TfLiteTensorVariantRealloc( + output, data.element_type, std::move(data.element_shape)); TF_LITE_ENSURE_OK(context, stat); // Set size of array. - TensorArray* arr = + TensorArray* const arr = static_cast(static_cast(output->data.data)); - arr->Resize(num_elements_value); + arr->Resize(data.num_elements); + TF_LITE_ENSURE_OK(context, sem.PopulateOutput(arr)); return kTfLiteOk; } @@ -111,8 +208,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace list_reserve TfLiteRegistration* Register_LIST_RESERVE() { - static TfLiteRegistration r = {nullptr, nullptr, list_reserve::Prepare, - list_reserve::Eval}; + static TfLiteRegistration r = { + nullptr, nullptr, list_reserve::Prepare, + list_reserve::Eval}; + return &r; +} + +TfLiteRegistration* Register_VARIANT_ZEROS_LIKE() { + static TfLiteRegistration r = { + nullptr, nullptr, list_reserve::Prepare, + list_reserve::Eval}; return &r; } diff --git a/tensorflow/lite/kernels/variants/list_kernels/variant_zeros_like_test.cc b/tensorflow/lite/kernels/variants/list_kernels/variant_zeros_like_test.cc new file mode 100644 index 00000000000000..54647833550bfa --- /dev/null +++ b/tensorflow/lite/kernels/variants/list_kernels/variant_zeros_like_test.cc @@ -0,0 +1,132 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include +#include +#include "tensorflow/lite/core/c/c_api_types.h" +#include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/kernels/variants/list_kernels/test_util.h" +#include "tensorflow/lite/kernels/variants/list_ops_lib.h" +#include "tensorflow/lite/kernels/variants/tensor_array.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace variants { +namespace ops { +namespace { + +using ::testing::AllOf; +using ::testing::Combine; +using ::testing::ValuesIn; +using ::tflite::variants::TensorArray; + +class VariantZerosLikeModel : public ListOpModel { + public: + explicit VariantZerosLikeModel() { + list_input_ = AddInput({TensorType_VARIANT, {}}); + list_output_ = AddOutput({TensorType_VARIANT, {}}); + SetCustomOp("VariantZerosLike", {}, Register_VARIANT_ZEROS_LIKE); + BuildInterpreter({{}}); + } + + const TensorArray* GetOutputTensorArray() { + TfLiteTensor* tensor = interpreter_->tensor(list_output_); + TFLITE_CHECK(tensor != nullptr && tensor->type == kTfLiteVariant && + tensor->allocation_type == kTfLiteVariantObject); + return static_cast( + static_cast(tensor->data.data)); + } + + int list_input_; + int list_output_; +}; + +using VariantZerosLikeTestParam = std::tuple, TfLiteType, int>; +class VariantZerosLikeTest + : public testing::TestWithParam { + public: + enum { kShape, kType, kLen }; +}; + +TEST_P(VariantZerosLikeTest, OutputsEmptyListWithSameAttrs) { + const auto& param = GetParam(); + const std::vector& shape = std::get(param); + const TfLiteType t = std::get(param); + const int len = std::get(param); + VariantZerosLikeModel m; + m.PopulateListTensor(m.list_input_, shape, len, t); + + ASSERT_EQ(m.Invoke(), kTfLiteOk); + const TensorArray* const out = m.GetOutputTensorArray(); + ASSERT_EQ(out->NumElements(), len); + ASSERT_EQ(out->ElementType(), t); + ASSERT_THAT(out->ElementShape(), DimsAre(shape)); + for (int i = 0; i < len; ++i) { + EXPECT_EQ(out->At(i), nullptr); + } +} + +using VariantZerosLikeItemTestParam = std::tuple>; +class VariantZerosLikeItemTest + : public testing::TestWithParam { + public: + enum { kLen, kShape }; +}; + +TEST_P(VariantZerosLikeItemTest, OutputsEmptyListContainsZeroedElement) { + const auto& param = GetParam(); + const int len = std::get(param); + const std::vector& item_shape = std::get(param); + VariantZerosLikeModel m; + m.PopulateListTensor(m.list_input_, {}, len, kTfLiteInt32); + const int num_elements = NumElements(item_shape.data(), item_shape.size()); + m.ListSetItem(m.list_input_, 0, item_shape, kTfLiteInt32, + std::vector(num_elements, 1).data()); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + + const TensorArray* const out = m.GetOutputTensorArray(); + ASSERT_EQ(out->NumElements(), len); + ASSERT_EQ(out->ElementType(), kTfLiteInt32); + ASSERT_THAT(out->ElementShape(), DimsAre({})); + const TfLiteTensor* const zero = out->At(0); + ASSERT_NE(zero, nullptr); + EXPECT_THAT(zero, AllOf(DimsAre(item_shape), IsAllocatedAs(kTfLiteInt32), + FilledWith(0))); + for (int i = 1; i < len; ++i) { + EXPECT_EQ(out->At(i), nullptr); + } +} + +INSTANTIATE_TEST_SUITE_P(VariantZerosLikeTests, VariantZerosLikeTest, + Combine(ValuesIn(std::vector>{ + {}, {-1}, {2, 2}, {3, 3, 3}}), + ValuesIn({kTfLiteInt32, kTfLiteInt64, + kTfLiteFloat32, kTfLiteBool}), + 
ValuesIn({0, 2, 10}))); + +INSTANTIATE_TEST_SUITE_P(VariantZerosLikeTests, VariantZerosLikeItemTest, + Combine(ValuesIn({1, 2, 10}), + ValuesIn(std::vector>{ + {1}, {2, 2}, {3, 3, 3}}))); + +} // namespace +} // namespace ops +} // namespace variants +} // namespace tflite diff --git a/tensorflow/lite/kernels/variants/list_ops_lib.h b/tensorflow/lite/kernels/variants/list_ops_lib.h index 33f1b16a8e59c8..52efd3abae82e8 100644 --- a/tensorflow/lite/kernels/variants/list_ops_lib.h +++ b/tensorflow/lite/kernels/variants/list_ops_lib.h @@ -49,6 +49,8 @@ TfLiteRegistration* Register_LIST_PUSH_BACK(); TfLiteRegistration* Register_VARIANT_ADD_N(); +TfLiteRegistration* Register_VARIANT_ZEROS_LIKE(); + } // namespace ops } // namespace variants } // namespace tflite diff --git a/tensorflow/lite/kernels/variants/py/BUILD b/tensorflow/lite/kernels/variants/py/BUILD index 4373d1e389f5fe..da3e0dfb81b650 100644 --- a/tensorflow/lite/kernels/variants/py/BUILD +++ b/tensorflow/lite/kernels/variants/py/BUILD @@ -22,11 +22,12 @@ py_strict_test( tags = ["nochromiumos_arm"], deps = [ ":register_list_ops_py", + "@absl_py//absl/testing:parameterized", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow:tensorflow_py", "//tensorflow/lite/python:interpreter", "//tensorflow/python/ops:list_ops", "//tensorflow/python/platform:test", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/lite/kernels/variants/register_list_ops.cc b/tensorflow/lite/kernels/variants/register_list_ops.cc index 9158c6f8b9f74d..36247e35960477 100644 --- a/tensorflow/lite/kernels/variants/register_list_ops.cc +++ b/tensorflow/lite/kernels/variants/register_list_ops.cc @@ -32,6 +32,7 @@ void RegisterListOps(MutableOpResolver* resolver) { resolver->AddCustom("TensorListPopBack", Register_LIST_POP_BACK()); resolver->AddCustom("TensorListPushBack", Register_LIST_PUSH_BACK()); resolver->AddCustom("VariantAddN", Register_VARIANT_ADD_N()); + resolver->AddCustom("VariantZerosLike", Register_VARIANT_ZEROS_LIKE()); } } // namespace ops diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index adb0409b2fb791..1547947dc80fe1 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -59,13 +59,14 @@ py_strict_test( ], deps = [ ":interpreter", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow:tensorflow_py", "//tensorflow/lite/python/metrics", "//tensorflow/lite/python/testdata:_pywrap_test_registerer", "//tensorflow/python/framework:test_lib", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/platform:resource_loader", - "//third_party/py/numpy", ], ) @@ -122,6 +123,7 @@ py_strict_test( python_version = "PY3", deps = [ ":test_util", + #internal proto upb dep "//tensorflow/python/framework:test_lib", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/platform:gfile", @@ -149,14 +151,16 @@ py_strict_test( ":convert", ":test_util", ":tflite_convert_main_lib", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow:tensorflow_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:tf2", "//tensorflow/python/client:session", "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:importer", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", @@ -168,7 +172,6 @@ py_strict_test( 
"//tensorflow/python/saved_model:save", "//tensorflow/python/trackable:autotrackable", "//tensorflow/python/training:training_util", - "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) @@ -197,11 +200,11 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:function", - "//tensorflow/python/framework", "//tensorflow/python/framework:byte_swap_tensor", "//tensorflow/python/framework:convert_to_constants", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:importer", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:versions", "//tensorflow/python/platform:gfile", @@ -240,6 +243,8 @@ py_strict_test( ":lite_constants", ":schema_py", ":util", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow:tensorflow_py", "//tensorflow/python/client:session", "//tensorflow/python/eager:context", @@ -262,7 +267,6 @@ py_strict_test( "//tensorflow/python/platform:resource_loader", "//tensorflow/python/saved_model", "//tensorflow/python/training:training_util", - "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) @@ -289,6 +293,8 @@ py_strict_test( ":schema_py", ":test_util", ":util", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow:tensorflow_py", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_options_proto_py", "//tensorflow/lite/python/testdata:_pywrap_test_registerer", @@ -308,7 +314,6 @@ py_strict_test( "//tensorflow/python/saved_model:save", "//tensorflow/python/saved_model:save_options", "//tensorflow/python/trackable:autotrackable", - "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", "@pypi_jax//:pkg", ], @@ -351,14 +356,16 @@ py_strict_test( ":interpreter", ":lite", ":test_util", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow:tensorflow_py", "//tensorflow/core:protos_all_py", "//tensorflow/lite/python/testdata:double_op", "//tensorflow/python/client:session", "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:importer", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", @@ -368,7 +375,6 @@ py_strict_test( "//tensorflow/python/platform:client_testlib", "//tensorflow/python/saved_model", "//tensorflow/python/trackable:autotrackable", - "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) @@ -409,6 +415,8 @@ py_strict_test( ], deps = [ ":util", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow:tensorflow_py", "//tensorflow/lite/tools:flatbuffer_utils", "//tensorflow/python/client:session", @@ -420,7 +428,6 @@ py_strict_test( "//tensorflow/python/ops:math_ops", "//tensorflow/python/ops:while_loop", "//tensorflow/python/platform:client_testlib", - "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) @@ -561,6 +568,7 @@ py_strict_test( visibility = ["//visibility:public"], deps = [ ":convert_saved_model", + #internal proto upb dep "//tensorflow/python/client:session", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", @@ -644,6 +652,7 @@ py_strict_test( python_version = "PY3", deps = [ ":analyzer", + #internal proto upb dep "//tensorflow:tensorflow_py", "//tensorflow/python/framework:test_lib", "//tensorflow/python/platform:client_testlib", diff 
--git a/tensorflow/lite/python/authoring/BUILD b/tensorflow/lite/python/authoring/BUILD index b6f13fc4f9a8be..d08bb377620a34 100644 --- a/tensorflow/lite/python/authoring/BUILD +++ b/tensorflow/lite/python/authoring/BUILD @@ -27,6 +27,7 @@ py_strict_test( srcs_version = "PY2AND3", deps = [ ":authoring", + #internal proto upb dep "//tensorflow:tensorflow_py", ], ) diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py index 52eb953a4a6638..d3aea0399683af 100644 --- a/tensorflow/lite/python/interpreter_test.py +++ b/tensorflow/lite/python/interpreter_test.py @@ -308,10 +308,18 @@ def testCreationCounter(self, increase_call): class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase): + # Model must have at least 7 bytes to hold model identifier + def testTooShortModelContent(self): + with self.assertRaisesRegex( + ValueError, + 'Model provided must have at least 7 bytes to hold identifier.', + ): + interpreter_wrapper.Interpreter(model_content=b'short') + def testInvalidModelContent(self): with self.assertRaisesRegex(ValueError, 'Model provided has model identifier \''): - interpreter_wrapper.Interpreter(model_content=b'garbage') + interpreter_wrapper.Interpreter(model_content=b'wrong_identifier') def testInvalidModelFile(self): with self.assertRaisesRegex(ValueError, diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 956825ba639b6f..d3dbfcf2ce5b21 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -1160,7 +1160,7 @@ def _validate_inference_input_output_types(self, quant_mode): if quant_mode.is_post_training_int16x8_quantization(): all_types = default_types + [_dtypes.int16] else: - all_types = default_types + [_dtypes.int8, _dtypes.uint8] + all_types = default_types + [_dtypes.int8, _dtypes.uint8, _dtypes.int16] if ( self.inference_input_type not in all_types or self.inference_output_type not in all_types diff --git a/tensorflow/lite/python/metrics/BUILD b/tensorflow/lite/python/metrics/BUILD index 1dc0a837124aca..cc86ea7b46dc50 100644 --- a/tensorflow/lite/python/metrics/BUILD +++ b/tensorflow/lite/python/metrics/BUILD @@ -69,6 +69,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":metrics_wrapper", + #internal proto upb dep "//tensorflow:tensorflow_py", "//tensorflow/lite/python:convert", "//tensorflow/lite/python:lite", @@ -131,6 +132,9 @@ py_strict_test( deps = [ ":converter_error_data_proto_py", ":metrics", + "@absl_py//absl/testing:parameterized", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow:tensorflow_py", "//tensorflow/core:protos_all_py", "//tensorflow/lite/python:convert", @@ -138,9 +142,9 @@ py_strict_test( "//tensorflow/python/client:session", "//tensorflow/python/eager:context", "//tensorflow/python/eager:monitoring", - "//tensorflow/python/framework", "//tensorflow/python/framework:convert_to_constants", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:importer", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", @@ -152,8 +156,6 @@ py_strict_test( "//tensorflow/python/platform:resource_loader", "//tensorflow/python/saved_model", "//tensorflow/python/trackable:autotrackable", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/lite/python/optimize/BUILD b/tensorflow/lite/python/optimize/BUILD index e8bf37b60f77c2..9df94b4a9a8054 100644 --- a/tensorflow/lite/python/optimize/BUILD +++ 
b/tensorflow/lite/python/optimize/BUILD @@ -88,6 +88,9 @@ py_strict_test( tags = ["no_oss"], deps = [ ":calibrator", + "@absl_py//absl/testing:parameterized", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/lite/python:lite", "//tensorflow/lite/python:schema_py", @@ -96,7 +99,5 @@ py_strict_test( "//tensorflow/python/framework:test_lib", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/platform:resource_loader", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/lite/python/testdata/BUILD b/tensorflow/lite/python/testdata/BUILD index 5faaea63bc77af..f05dec25d9cab6 100644 --- a/tensorflow/lite/python/testdata/BUILD +++ b/tensorflow/lite/python/testdata/BUILD @@ -148,8 +148,8 @@ tf_custom_op_py_strict_library( srcs_version = "PY3", deps = [ ":gen_double_op_wrapper", - "//tensorflow/python/framework", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:load_library", "//tensorflow/python/platform:resource_loader", ], ) diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 6bffeadfbad9cb..382462f938d93b 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -459,7 +459,7 @@ enum BuiltinOperator : int32 { STABLEHLO_CONVERT = 192, // WARNING: No runtime support STABLEHLO_DYNAMIC_SLICE = 193, // WARNING: No runtime support STABLEHLO_DYNAMIC_UPDATE_SLICE = 194, // WARNING: No runtime support - STABLEHLO_PAD = 195, // WARNING: No runtime support + STABLEHLO_PAD = 195, STABLEHLO_IOTA = 196, // WARNING: No runtime support STABLEHLO_DOT_GENERAL = 197, // WARNING: No runtime support STABLEHLO_REDUCE_WINDOW = 198, diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index 23d1ed50486507..f1299b56ac7502 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -15,23 +15,41 @@ limitations under the License. #include "tensorflow/lite/simple_memory_arena.h" -#include -#include - #include +#include +#include +#include #include -#include #include -#include #include #include #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/macros.h" + #ifdef TF_LITE_TENSORFLOW_PROFILER #include "tensorflow/lite/tensorflow_profiler_logger.h" #endif // TF_LITE_TENSORFLOW_PROFILER +#if defined(__ANDROID__) +// Android has C11 aligned_alloc only with API 28 or newer, even with C++17 or +// C11 compilation (this is a non-standard behavior). +#define TF_LITE_HAS_ALIGNED_ALLOC (__ANDROID_API__ >= 28) +#elif defined(__APPLE__) +// Apple does not provide aligned_alloc, even with C++17 or C11 compilation +// (this is a non-standard behavior). +#define TF_LITE_HAS_ALIGNED_ALLOC 0 +#elif defined(_WIN32) +// Windows does not provide aligned_alloc, even with C++17 or C11 compilation +// (this is a non-standard behavior). However, it provides _aligned_malloc, +// _aligned_realloc, and _aligned_free, with a slightly different behavior than +// the C11/C++17 standard functions (size requirement, and free function name.) +#define TF_LITE_HAS_ALIGNED_ALLOC 0 +#elif __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L +// C++17 or C11 has (std::)aligned_alloc +#define TF_LITE_HAS_ALIGNED_ALLOC 1 +#endif + namespace { template @@ -40,10 +58,135 @@ T AlignTo(size_t alignment, T offset) { : offset + (alignment - offset % alignment); } +// Allocates memory and aligns it to the specified size. 
Returns a pair of the +// allocation pointer and the aligned pointer. +tflite::PointerAlignedPointerPair AlignedAlloc(size_t size, size_t alignment); + +// Frees up aligned memory. +void AlignedFree(const tflite::PointerAlignedPointerPair& buffer); + +// Reallocates aligned memory +// +// The function either extends the memory allocation in-place, or if that is not +// possible a new allocation is created, the data is copied, and the old buffer +// is deallocated. It is an error to change the alignment during reallocation. +// If the previous allocation is null, this is equivalent to AlignedAlloc. +// Returns pointers to the new allocation. +tflite::PointerAlignedPointerPair AlignedRealloc( + const tflite::PointerAlignedPointerPair& old_buffer, size_t old_size, + size_t new_size, size_t alignment); + +#if defined(_WIN32) +// On Windows provides _aligned_malloc, _aligned_free, and +// _aligned_realloc, use them to implement the Aligned functions. + +tflite::PointerAlignedPointerPair AlignedAlloc(size_t size, size_t alignment) { + char* pointer = reinterpret_cast(_aligned_malloc(size, alignment)); + char* aligned_ptr = pointer; + return {pointer, aligned_ptr}; +} + +void AlignedFree(const tflite::PointerAlignedPointerPair& buffer) { + _aligned_free(buffer.pointer); +} + +tflite::PointerAlignedPointerPair AlignedRealloc( + const tflite::PointerAlignedPointerPair& old_buffer, size_t old_size, + size_t new_size, size_t alignment) { + char* pointer = reinterpret_cast( + _aligned_realloc(old_buffer.pointer, new_size, alignment)); + char* aligned_ptr = pointer; + return {pointer, aligned_ptr}; +} +#else +// Default implementation: Use malloc, allocating extra memory, and align the +// pointer in the allocated buffer. + +tflite::PointerAlignedPointerPair AlignedAlloc(size_t size, size_t alignment) { +#if TF_LITE_HAS_ALIGNED_ALLOC + // (std::)aligned_alloc requires size to be multiple of alignment. + // TODO(b/311495100): when bug is fixed, remove `size + alignment - 1` part. + const size_t allocation_size = AlignTo(alignment, size + alignment - 1); + char* pointer = + reinterpret_cast(::aligned_alloc(alignment, allocation_size)); + char* aligned_ptr = pointer; +#else + // TODO(b/311495100): when bug is fixed, change this to + // `size + std::max(size_t{0}, alignment - alignof(std::max_align_t))` + const size_t allocation_size = size + alignment - 1; + char* pointer = reinterpret_cast(std::malloc(allocation_size)); + char* aligned_ptr = reinterpret_cast( + AlignTo(alignment, reinterpret_cast(pointer))); +#endif +#if defined(__clang__) +#if __has_feature(memory_sanitizer) + std::memset(pointer, 0, allocation_size); +#endif +#endif + return {pointer, aligned_ptr}; +} + +void AlignedFree(const tflite::PointerAlignedPointerPair& buffer) { + std::free(buffer.pointer); +} + +tflite::PointerAlignedPointerPair AlignedRealloc( + const tflite::PointerAlignedPointerPair& old_buffer, size_t old_size, + size_t new_size, size_t alignment) { + tflite::PointerAlignedPointerPair new_buffer = + AlignedAlloc(new_size, alignment); + if (new_size > 0 && old_size > 0) { + // Copy data when both old and new buffers are bigger than 0 bytes. + const size_t copy_amount = std::min(new_size, old_size); + std::memcpy(new_buffer.aligned_pointer, old_buffer.aligned_pointer, + copy_amount); + } + AlignedFree(old_buffer); + return new_buffer; +} +#endif } // namespace namespace tflite { +bool ResizableAlignedBuffer::Resize(size_t new_size) { + if (new_size <= data_size_) { + // Skip reallocation when resizing down. 
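+    // The backing block keeps its current (high-water) size, so existing
+    // pointers into the buffer stay valid and no copy is needed.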
+ return false; + } +#ifdef TF_LITE_TENSORFLOW_PROFILER + PauseHeapMonitoring(/*pause=*/true); + OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), + new_size); + if (data_size_ > 0) { + OnTfLiteArenaDealloc(subgraph_index_, + reinterpret_cast(this), data_size_); + } +#endif + auto new_buffer = AlignedRealloc(buffer_, data_size_, new_size, alignment_); + bool reallocated = (new_buffer.aligned_pointer != buffer_.aligned_pointer); + buffer_ = new_buffer; + data_size_ = new_size; +#ifdef TF_LITE_TENSORFLOW_PROFILER + PauseHeapMonitoring(/*pause=*/false); +#endif + return reallocated; +} + +void ResizableAlignedBuffer::Release() { + if (buffer_.pointer == nullptr) { + return; + } +#ifdef TF_LITE_TENSORFLOW_PROFILER + OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), + data_size_); +#endif + AlignedFree(buffer_); + buffer_.pointer = nullptr; + buffer_.aligned_pointer = nullptr; + data_size_ = 0; +} + void SimpleMemoryArena::PurgeAfter(int32_t node) { for (int i = 0; i < active_allocs_.size(); ++i) { if (active_allocs_[i].first_node > node) { @@ -91,7 +234,7 @@ TfLiteStatus SimpleMemoryArena::Allocate( TfLiteContext* context, size_t alignment, size_t size, int32_t tensor, int32_t first_node, int32_t last_node, ArenaAllocWithUsageInterval* new_alloc) { - TF_LITE_ENSURE(context, alignment <= arena_alignment_); + TF_LITE_ENSURE(context, alignment <= underlying_buffer_.GetAlignment()); new_alloc->tensor = tensor; new_alloc->first_node = first_node; new_alloc->last_node = last_node; @@ -141,50 +284,13 @@ TfLiteStatus SimpleMemoryArena::Allocate( return kTfLiteOk; } -TfLiteStatus SimpleMemoryArena::Commit(TfLiteContext* context, - bool* arena_reallocated) { - size_t required_size = RequiredBufferSize(); - if (required_size > underlying_buffer_size_) { - *arena_reallocated = true; -#ifdef TF_LITE_TENSORFLOW_PROFILER - PauseHeapMonitoring(/*pause=*/true); - OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), - required_size); -#endif - char* new_alloc = new char[required_size]; - char* new_underlying_buffer_aligned_ptr = reinterpret_cast( - AlignTo(arena_alignment_, reinterpret_cast(new_alloc))); - - // If the arena had been previously allocated, copy over the old memory. - // Since Alloc pointers are offset based, they will remain valid in the new - // memory block. - if (high_water_mark_ > 0 && underlying_buffer_size_ > 0) { - size_t copy_amount = std::min( - underlying_buffer_.get() + underlying_buffer_size_ - - underlying_buffer_aligned_ptr_, - new_alloc + required_size - new_underlying_buffer_aligned_ptr); - memcpy(new_underlying_buffer_aligned_ptr, underlying_buffer_aligned_ptr_, - copy_amount); - } - -#ifdef TF_LITE_TENSORFLOW_PROFILER - if (underlying_buffer_size_ > 0) { - OnTfLiteArenaDealloc(subgraph_index_, - reinterpret_cast(this), - underlying_buffer_size_); - } -#endif - underlying_buffer_.reset(new_alloc); - underlying_buffer_size_ = required_size; - underlying_buffer_aligned_ptr_ = new_underlying_buffer_aligned_ptr; -#ifdef TF_LITE_TENSORFLOW_PROFILER - PauseHeapMonitoring(/*pause=*/false); -#endif - } else { - *arena_reallocated = false; - } +TfLiteStatus SimpleMemoryArena::Commit(bool* arena_reallocated) { + // Resize the arena to the high water mark (calculated by Allocate), retaining + // old contents and alignment in the process. Since Alloc pointers are offset + // based, they will remain valid in the new memory block. + *arena_reallocated = underlying_buffer_.Resize(high_water_mark_); committed_ = true; - return underlying_buffer_ != nullptr ? 
kTfLiteOk : kTfLiteError; + return kTfLiteOk; } TfLiteStatus SimpleMemoryArena::ResolveAlloc( @@ -193,11 +299,11 @@ TfLiteStatus SimpleMemoryArena::ResolveAlloc( TF_LITE_ENSURE(context, committed_); TF_LITE_ENSURE(context, output_ptr != nullptr); TF_LITE_ENSURE(context, - underlying_buffer_size_ >= (alloc.offset + alloc.size)); + underlying_buffer_.GetSize() >= (alloc.offset + alloc.size)); if (alloc.size == 0) { *output_ptr = nullptr; } else { - *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; + *output_ptr = underlying_buffer_.GetPtr() + alloc.offset; } return kTfLiteOk; } @@ -211,13 +317,7 @@ TfLiteStatus SimpleMemoryArena::ClearPlan() { TfLiteStatus SimpleMemoryArena::ReleaseBuffer() { committed_ = false; -#ifdef TF_LITE_TENSORFLOW_PROFILER - OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), - underlying_buffer_size_); -#endif - underlying_buffer_size_ = 0; - underlying_buffer_aligned_ptr_ = nullptr; - underlying_buffer_.reset(); + underlying_buffer_.Release(); return kTfLiteOk; } @@ -229,7 +329,7 @@ TFLITE_ATTRIBUTE_WEAK void DumpArenaInfo( void SimpleMemoryArena::DumpDebugInfo( const std::string& name, const std::vector& execution_plan) const { - tflite::DumpArenaInfo(name, execution_plan, underlying_buffer_size_, + tflite::DumpArenaInfo(name, execution_plan, underlying_buffer_.GetSize(), active_allocs_); } diff --git a/tensorflow/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h index 8f8859a6c0d594..7275b3014f3660 100644 --- a/tensorflow/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -15,10 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_SIMPLE_MEMORY_ARENA_H_ #define TENSORFLOW_LITE_SIMPLE_MEMORY_ARENA_H_ -#include - +#include #include -#include #include #include @@ -55,6 +53,53 @@ struct ArenaAllocWithUsageInterval { } }; +struct PointerAlignedPointerPair { + char* pointer; + char* aligned_pointer; +}; + +class ResizableAlignedBuffer { + public: + ResizableAlignedBuffer(size_t alignment, int subgraph_index) + : buffer_{nullptr, nullptr}, + data_size_(0), + alignment_(alignment), + subgraph_index_(subgraph_index) { + // To silence unused private member warning, only used with + // TF_LITE_TENSORFLOW_PROFILER + (void)subgraph_index_; + } + + ~ResizableAlignedBuffer() { Release(); } + + // Resizes the buffer to make sure new_size bytes fit in the buffer. Keeps + // alignment and any existing the data. Returns true when any external + // pointers into the data array need to be adjusted (the buffer was moved). + bool Resize(size_t new_size); + // Releases any allocated memory. + void Release(); + + // Pointer to the data array. + char* GetPtr() const { return buffer_.aligned_pointer; } + // Size of the data array. Note: the allocated memory block might be larger + // due to excess alignment requirements. + size_t GetSize() const { return data_size_; } + // Alignment of the data array. + size_t GetAlignment() const { return alignment_; } + + private: + ResizableAlignedBuffer(const ResizableAlignedBuffer&) = delete; + ResizableAlignedBuffer& operator=(const ResizableAlignedBuffer&) = delete; + ResizableAlignedBuffer(ResizableAlignedBuffer&&) = delete; + ResizableAlignedBuffer& operator=(ResizableAlignedBuffer&&) = delete; + + PointerAlignedPointerPair buffer_; + size_t data_size_; + size_t alignment_; + + int subgraph_index_; +}; + // This small class is responsible for allocating, deallocating and reusing // dynamic memory from a common underlying buffer. 
The arena can be used in // scenarios when the pattern of memory allocations and deallocations is @@ -63,11 +108,9 @@ struct ArenaAllocWithUsageInterval { class SimpleMemoryArena { public: explicit SimpleMemoryArena(size_t arena_alignment, int subgraph_index = 0) - : subgraph_index_(subgraph_index), - committed_(false), - arena_alignment_(arena_alignment), + : committed_(false), high_water_mark_(0), - underlying_buffer_size_(0), + underlying_buffer_(arena_alignment, subgraph_index), active_allocs_() {} // Delete all allocs. This should be called when allocating the first node of @@ -99,14 +142,7 @@ class SimpleMemoryArena { int32_t tensor, int32_t first_node, int32_t last_node, ArenaAllocWithUsageInterval* new_alloc); - inline size_t RequiredBufferSize() { - // Add in a small amount of padding to reduce the chance of resize events - // for small allocations. - size_t padding = arena_alignment_; - return arena_alignment_ + high_water_mark_ + padding; - } - - TfLiteStatus Commit(TfLiteContext* context, bool* arena_reallocated); + TfLiteStatus Commit(bool* arena_reallocated); TfLiteStatus ResolveAlloc(TfLiteContext* context, const ArenaAllocWithUsageInterval& alloc, @@ -122,10 +158,10 @@ class SimpleMemoryArena { // again until Commit() is called & tensor allocations are resolved. TfLiteStatus ReleaseBuffer(); - size_t GetBufferSize() const { return underlying_buffer_size_; } + size_t GetBufferSize() const { return underlying_buffer_.GetSize(); } std::intptr_t BasePointer() const { - return reinterpret_cast(underlying_buffer_aligned_ptr_); + return reinterpret_cast(underlying_buffer_.GetPtr()); } // Dumps the memory allocation information of this memory arena (which could @@ -145,16 +181,10 @@ class SimpleMemoryArena { void DumpDebugInfo(const std::string& name, const std::vector& execution_plan) const; - protected: - int subgraph_index_; - private: bool committed_; - size_t arena_alignment_; size_t high_water_mark_; - std::unique_ptr underlying_buffer_; - size_t underlying_buffer_size_; - char* underlying_buffer_aligned_ptr_; + ResizableAlignedBuffer underlying_buffer_; std::vector active_allocs_; }; diff --git a/tensorflow/lite/simple_memory_arena_test.cc b/tensorflow/lite/simple_memory_arena_test.cc index fb21e145b62693..af5a4d8ed668ea 100644 --- a/tensorflow/lite/simple_memory_arena_test.cc +++ b/tensorflow/lite/simple_memory_arena_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/testing/util.h" namespace tflite { namespace { @@ -56,8 +55,8 @@ TEST(SimpleMemoryArenaTest, BasicZeroAlloc) { // The zero-sized alloc should resolve to null. char* resolved_ptr = nullptr; bool reallocated = false; - ASSERT_EQ(arena.Commit(&context, &reallocated), kTfLiteOk); - ASSERT_TRUE(reallocated); + ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); + EXPECT_FALSE(reallocated); // Don't allocate when zero bytes are needed. 
EXPECT_EQ(resolved_ptr, nullptr); } @@ -88,7 +87,7 @@ TEST(SimpleMemoryArenaTest, TestClearPlan) { arena.Allocate(&context, 32, 2047, 1, 1, 2, &allocs[1]); arena.Allocate(&context, 32, 2047, 2, 1, 2, &allocs[2]); bool reallocated = false; - arena.Commit(&context, &reallocated); + arena.Commit(&reallocated); ASSERT_TRUE(reallocated); EXPECT_EQ(allocs[0].offset, 0); @@ -101,7 +100,7 @@ TEST(SimpleMemoryArenaTest, TestClearPlan) { arena.Allocate(&context, 32, 1023, 3, 0, 2, &allocs[3]); arena.Allocate(&context, 32, 1023, 4, 1, 2, &allocs[4]); arena.Allocate(&context, 32, 1023, 5, 1, 2, &allocs[5]); - arena.Commit(&context, &reallocated); + arena.Commit(&reallocated); ASSERT_FALSE(reallocated); EXPECT_EQ(allocs[3].offset, 0); @@ -114,7 +113,7 @@ TEST(SimpleMemoryArenaTest, TestClearPlan) { arena.Allocate(&context, 32, 4095, 6, 0, 2, &allocs[6]); arena.Allocate(&context, 32, 4095, 7, 1, 2, &allocs[7]); arena.Allocate(&context, 32, 4095, 8, 1, 2, &allocs[8]); - arena.Commit(&context, &reallocated); + arena.Commit(&reallocated); ASSERT_TRUE(reallocated); EXPECT_EQ(allocs[6].offset, 0); @@ -136,7 +135,7 @@ TEST(SimpleMemoryArenaTest, TestPurgeAllocs) { /*first_node=*/2, /*last_node=*/3, &allocs[2]); bool reallocated = false; - ASSERT_EQ(arena.Commit(&context, &reallocated), kTfLiteOk); + ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); ASSERT_TRUE(reallocated); char* resolved_ptr0 = nullptr; char* resolved_ptr1 = nullptr; @@ -167,7 +166,7 @@ TEST(SimpleMemoryArenaTest, TestPurgeAllocs) { arena.PurgeActiveAllocs(4); arena.Allocate(&context, /*alignment=*/32, /*size=*/13, /*tensor=*/3, /*first_node=*/4, /*last_node=*/5, &allocs[4]); - ASSERT_EQ(arena.Commit(&context, &reallocated), kTfLiteOk); + ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); ASSERT_EQ(arena.ResolveAlloc(&context, allocs[4], &resolved_ptr3), kTfLiteOk); /* no tensors are allocated at node 4, so tensor 3's offset should be zero.*/ ASSERT_EQ(allocs[4].offset, 0); @@ -190,7 +189,7 @@ TEST(SimpleMemoryArenaTest, TestPurgeAllocs) { */ arena.Allocate(&context, /*alignment=*/32, /*size=*/2047, /*tensor=*/0, /*first_node=*/0, /*last_node=*/2, &allocs[0]); - ASSERT_EQ(arena.Commit(&context, &reallocated), kTfLiteOk); + ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); ASSERT_EQ(arena.ResolveAlloc(&context, allocs[3], &resolved_ptr3), kTfLiteOk); ASSERT_EQ(allocs[0].offset, 0); } @@ -209,7 +208,7 @@ TEST(SimpleMemoryArenaTest, TestResetAllocs) { /*first_node=*/2, /*last_node=*/3, &allocs[2]); bool reallocated = false; - ASSERT_EQ(arena.Commit(&context, &reallocated), kTfLiteOk); + ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); ASSERT_TRUE(reallocated); char* resolved_ptr0 = nullptr; char* resolved_ptr1 = nullptr; @@ -239,7 +238,7 @@ TEST(SimpleMemoryArenaTest, TestResetAllocs) { */ arena.Allocate(&context, /*alignment=*/32, /*size=*/13, /*tensor=*/0, /*first_node=*/0, /*last_node=*/3, &allocs[3]); - ASSERT_EQ(arena.Commit(&context, &reallocated), kTfLiteOk); + ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); /* This is the expected arena after tensor3 has been allocated. 
* |xxxxxxxxxxxxxxxxx| tensor3 * |xxxxx| tensor2 @@ -275,7 +274,7 @@ TEST(SimpleMemoryArenaTest, TestResetAllocs) { * ___________________ */ - ASSERT_EQ(arena.Commit(&context, &reallocated), kTfLiteOk); + ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); ASSERT_EQ(arena.ResolveAlloc(&context, allocs[3], &resolved_ptr3), kTfLiteOk); ASSERT_EQ(allocs[3].offset, 0); } @@ -294,7 +293,7 @@ TEST(SimpleMemoryArenaTest, TestClearBuffer) { // Commit and ensure resolved pointers are not null. bool reallocated = false; - ASSERT_EQ(arena.Commit(&context, &reallocated), kTfLiteOk); + ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); ASSERT_TRUE(reallocated); char* resolved_ptr = nullptr; ASSERT_EQ(arena.ResolveAlloc(&context, allocs[0], &resolved_ptr), kTfLiteOk); @@ -311,7 +310,7 @@ TEST(SimpleMemoryArenaTest, TestClearBuffer) { ASSERT_NE(arena.ResolveAlloc(&context, allocs[0], &resolved_ptr), kTfLiteOk); // Commit again and ensure resolved pointers are not null. - ASSERT_EQ(arena.Commit(&context, &reallocated), kTfLiteOk); + ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); ASSERT_TRUE(reallocated); ASSERT_NE(arena.BasePointer(), 0); resolved_ptr = nullptr; @@ -337,7 +336,7 @@ TEST_P(BufferAndPlanClearingTest, TestClearBufferAndClearPlan) { arena.Allocate(&context, 32, 2047, 1, 1, 2, &allocs[1]); bool reallocated = false; - ASSERT_EQ(arena.Commit(&context, &reallocated), kTfLiteOk); + ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); ASSERT_TRUE(reallocated); if (GetParam()) { @@ -349,15 +348,17 @@ TEST_P(BufferAndPlanClearingTest, TestClearBufferAndClearPlan) { } // Just committing won't work, allocations need to be made again. - ASSERT_EQ(arena.Commit(&context, &reallocated), kTfLiteOk); - ASSERT_TRUE(reallocated); + ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); + // There was no allocation, the buffer has 0 bytes (was released) and the high + // water mark is 0 (plan was cleared). + EXPECT_FALSE(reallocated); char* resolved_ptr = nullptr; ASSERT_NE(arena.ResolveAlloc(&context, allocs[0], &resolved_ptr), kTfLiteOk); // Re-allocate tensors & commit. arena.Allocate(&context, 32, 2047, 0, 0, 2, &allocs[0]); arena.Allocate(&context, 32, 2047, 1, 1, 2, &allocs[1]); - ASSERT_EQ(arena.Commit(&context, &reallocated), kTfLiteOk); + ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); ASSERT_TRUE(reallocated); // Pointer-resolution now works. 
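(For reference, a minimal usage sketch of the updated arena API exercised by the
tests above, assuming a TfLiteContext named context is in scope; error handling
omitted.)

  SimpleMemoryArena arena(/*arena_alignment=*/64);
  ArenaAllocWithUsageInterval alloc;
  arena.Allocate(&context, /*alignment=*/32, /*size=*/2047, /*tensor=*/0,
                 /*first_node=*/0, /*last_node=*/2, &alloc);
  bool reallocated = false;
  arena.Commit(&reallocated);  // no TfLiteContext* parameter anymore
  char* ptr = nullptr;
  arena.ResolveAlloc(&context, alloc, &ptr);  // ptr = arena base + alloc.offset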
diff --git a/tensorflow/lite/testing/op_tests/gather_nd.py b/tensorflow/lite/testing/op_tests/gather_nd.py index 66d8a30033b7bd..37eb052ebff73f 100644 --- a/tensorflow/lite/testing/op_tests/gather_nd.py +++ b/tensorflow/lite/testing/op_tests/gather_nd.py @@ -25,19 +25,40 @@ def make_gather_nd_tests(options): test_parameters = [ { - "params_dtype": [tf.float32, tf.int16, tf.int32, tf.int64, tf.string], + "params_dtype": [ + tf.float32, + tf.int16, + tf.int32, + tf.int64, + tf.string, + tf.bool, + ], "params_shape": [[5, 1]], "indices_dtype": [tf.int16, tf.int32, tf.int64], "indices_shape": [[1, 1]], }, { - "params_dtype": [tf.float32, tf.int16, tf.int32, tf.int64, tf.string], + "params_dtype": [ + tf.float32, + tf.int16, + tf.int32, + tf.int64, + tf.string, + tf.bool, + ], "params_shape": [[5, 5]], "indices_dtype": [tf.int16, tf.int32, tf.int64], "indices_shape": [[2, 1], [2, 2]], }, { - "params_dtype": [tf.float32, tf.int16, tf.int32, tf.int64, tf.string], + "params_dtype": [ + tf.float32, + tf.int16, + tf.int32, + tf.int64, + tf.string, + tf.bool, + ], "params_shape": [[5, 5, 10]], "indices_dtype": [tf.int16, tf.int32, tf.int64], "indices_shape": [[3, 1], [2, 2], [2, 3], [2, 1, 3]], diff --git a/tensorflow/lite/toco/logging/BUILD b/tensorflow/lite/toco/logging/BUILD index 17ec59f24d9f71..83daab2357364a 100644 --- a/tensorflow/lite/toco/logging/BUILD +++ b/tensorflow/lite/toco/logging/BUILD @@ -94,6 +94,7 @@ py_strict_test( deps = [ ":gen_html", ":toco_conversion_log_proto_py", + #internal proto upb dep "//tensorflow/python/framework:test_lib", "//tensorflow/python/lib/io:file_io", "//tensorflow/python/platform:client_testlib", diff --git a/tensorflow/lite/toco/tflite/BUILD b/tensorflow/lite/toco/tflite/BUILD index 77094a64b7d45b..a4ea52f9ac37ae 100644 --- a/tensorflow/lite/toco/tflite/BUILD +++ b/tensorflow/lite/toco/tflite/BUILD @@ -88,12 +88,14 @@ cc_library( hdrs = [ "export.h", ], - features = ["-layering_check"], visibility = ["//visibility:public"], deps = [ ":operator", ":types", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/lite:context", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite:util", "//tensorflow/lite/schema:schema_conversion_utils", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/toco:model", @@ -110,11 +112,13 @@ tf_cc_test( srcs = [ "export_test.cc", ], - features = ["-layering_check"], deps = [ ":export", ":operator", + ":types", + "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index 12e9d0db25a914..704021241d6e07 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -51,6 +51,7 @@ py_strict_test( deps = [ ":test_utils", ":visualize_lib", + #internal proto upb dep "//tensorflow/python/framework:test_lib", "//tensorflow/python/platform:client_testlib", ], @@ -95,11 +96,12 @@ py_strict_test( srcs_version = "PY3", deps = [ ":convert_image_to_csv_lib", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:test_lib", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/platform:resource_loader", - "//third_party/py/numpy", ], ) @@ -159,6 +161,7 @@ py_strict_test( deps = [ ":flatbuffer_utils", ":test_utils", + #internal proto upb dep "//tensorflow/python/framework:test_lib", 
"//tensorflow/python/platform:client_testlib", ], diff --git a/tensorflow/lite/tools/build_aar.sh b/tensorflow/lite/tools/build_aar.sh index 1847b794d757aa..644bd08f6dda8d 100755 --- a/tensorflow/lite/tools/build_aar.sh +++ b/tensorflow/lite/tools/build_aar.sh @@ -90,12 +90,14 @@ function generate_tflite_aar { popd > /dev/null # TODO(b/254278688): Enable 'xnn_enable_arm_fp16' with toolchain upgrade. # TODO(b/297897797): Enable 'xnn_enable_arm_i8mm' with toolchain upgrade. - bazel ${CACHE_DIR_FLAG} build -c opt --cxxopt='--std=c++17' \ + # TODO: b/315114212 - Remove `xnn_enable_vnni` when the compiler supports it. + bazel ${CACHE_DIR_FLAG} build -c opt --config=opt --cxxopt='--std=c++17' \ --fat_apk_cpu=${TARGET_ARCHS} \ --define=android_dexmerger_tool=d8_dexmerger \ --define=android_incremental_dexing_tool=d8_dexbuilder\ --define=xnn_enable_arm_fp16=false \ --define=xnn_enable_arm_i8mm=false \ + --define=xnn_enable_avxvnni=false \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ //tmp:tensorflow-lite @@ -130,12 +132,14 @@ function generate_flex_aar { # Build the aar package. # TODO(b/254278688): Enable 'xnn_enable_arm_fp16' with toolchain upgrade. # TODO(b/297897797): Enable 'xnn_enable_arm_i8mm' with toolchain upgrade. - bazel ${CACHE_DIR_FLAG} build -c opt --cxxopt='--std=c++17' \ + # TODO: b/315114212 - Remove `xnn_enable_vnni` when the compiler supports it. + bazel ${CACHE_DIR_FLAG} build -c opt --config=opt --cxxopt='--std=c++17' \ --fat_apk_cpu=${TARGET_ARCHS} \ --define=android_dexmerger_tool=d8_dexmerger \ --define=android_incremental_dexing_tool=d8_dexbuilder\ --define=xnn_enable_arm_fp16=false \ --define=xnn_enable_arm_i8mm=false \ + --define=xnn_enable_avxvnni=false \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ //tmp:tensorflow-lite-select-tf-ops @@ -191,14 +195,16 @@ fi # Build the standard aar package of no models provided. # TODO(b/254278688): Enable 'xnn_enable_arm_fp16' with toolchain upgrade. # TODO(b/297897797): Enable 'xnn_enable_arm_i8mm' with toolchain upgrade. +# TODO: b/315114212 - Remove `xnn_enable_vnni` when the compiler supports it. if [ -z ${FLAG_MODELS} ]; then - bazel ${CACHE_DIR_FLAG} build -c opt --cxxopt='--std=c++17' \ + bazel ${CACHE_DIR_FLAG} build -c opt --config=opt --cxxopt='--std=c++17' \ --config=monolithic \ --fat_apk_cpu=${TARGET_ARCHS} \ --define=android_dexmerger_tool=d8_dexmerger \ --define=android_incremental_dexing_tool=d8_dexbuilder\ --define=xnn_enable_arm_fp16=false \ --define=xnn_enable_arm_i8mm=false \ + --define=xnn_enable_avxvnni=false \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ //tensorflow/lite/java:tensorflow-lite diff --git a/tensorflow/lite/tools/build_aar_with_docker.sh b/tensorflow/lite/tools/build_aar_with_docker.sh index 27624f943d4f1a..fbcc9325cd5fa0 100755 --- a/tensorflow/lite/tools/build_aar_with_docker.sh +++ b/tensorflow/lite/tools/build_aar_with_docker.sh @@ -104,14 +104,16 @@ else cd /tensorflow_src # Run configure. + # -Wno-c++20-designator can be removed once tf supports C++20. + # -Wno-gnu-inline-cpp-without-extern is needed for NEON2SSE. Can remove after + # https://github.com/intel/ARM_NEON_2_x86_SSE/issues/57 is resolved. 
configs=( '/usr/bin/python3' '/usr/lib/python3/dist-packages' 'N' 'N' 'N' - 'N' - '-march=native -Wno-sign-compare' + '-Wno-sign-compare -Wno-c++20-designator -Wno-gnu-inline-cpp-without-extern' 'y' '/android/sdk' ) diff --git a/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake b/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake index 7866627555d030..d72fa2c18c07ca 100644 --- a/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake +++ b/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake @@ -22,8 +22,8 @@ include(OverridableFetchContent) OverridableFetchContent_Declare( cpuinfo GIT_REPOSITORY https://github.com/pytorch/cpuinfo - # Sync with tensorflow/third_party/cpuinfo/workspace.bzl - GIT_TAG 959002f82d7962a473d8bf301845f2af720e0aa4 + # Sync with tensorflow/workspace2.bzl + GIT_TAG ef634603954d88d2643d5809011288b890ac126e GIT_PROGRESS TRUE SOURCE_DIR "${CMAKE_BINARY_DIR}/cpuinfo" ) diff --git a/tensorflow/lite/tools/cmake/modules/gemmlowp.cmake b/tensorflow/lite/tools/cmake/modules/gemmlowp.cmake index ac296c0307f901..76d9705475b05b 100644 --- a/tensorflow/lite/tools/cmake/modules/gemmlowp.cmake +++ b/tensorflow/lite/tools/cmake/modules/gemmlowp.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( gemmlowp GIT_REPOSITORY https://github.com/google/gemmlowp # Sync with tensorflow/third_party/gemmlowp/workspace.bzl - GIT_TAG e844ffd17118c1e17d94e1ba4354c075a4577b88 + GIT_TAG 16e8662c34917be0065110bfcd9cc27d30f52fdf # It's not currently (cmake 3.17) possible to shallow clone with a GIT TAG # as cmake attempts to git checkout the commit hash after the clone # which doesn't work as it's a shallow clone hence a different commit hash. diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake index a6b36451cb819b..436be3901c4865 100644 --- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake +++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( xnnpack GIT_REPOSITORY https://github.com/google/XNNPACK # Sync with tensorflow/workspace2.bzl - GIT_TAG c7e7cde37615a81a529c326aa278bfab4cd6fe5a + GIT_TAG 0cbbe74a16e6ca11acf8484ccac85f620336dea4 GIT_PROGRESS TRUE PREFIX "${CMAKE_BINARY_DIR}" SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack" diff --git a/tensorflow/lite/tools/optimize/debugging/python/BUILD b/tensorflow/lite/tools/optimize/debugging/python/BUILD index e77895b569141f..529fdec107f5bd 100644 --- a/tensorflow/lite/tools/optimize/debugging/python/BUILD +++ b/tensorflow/lite/tools/optimize/debugging/python/BUILD @@ -29,6 +29,9 @@ py_strict_test( python_version = "PY3", deps = [ ":debugger", + "@absl_py//absl/testing:parameterized", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow:tensorflow_py", "//tensorflow/lite/python:convert", "//tensorflow/lite/python:lite", @@ -36,7 +39,5 @@ py_strict_test( "//tensorflow/python/framework:test_lib", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/trackable:autotrackable", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/lite/tools/optimize/python/BUILD b/tensorflow/lite/tools/optimize/python/BUILD index 2cba7d719c4d11..9c3527eb56b684 100644 --- a/tensorflow/lite/tools/optimize/python/BUILD +++ b/tensorflow/lite/tools/optimize/python/BUILD @@ -40,10 +40,11 @@ py_strict_test( srcs_version = "PY3", deps = [ ":modify_model_interface_lib", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow:tensorflow_py", "//tensorflow/python/framework:test_lib", 
"//tensorflow/python/platform:client_testlib", - "//third_party/py/numpy", ], ) diff --git a/tensorflow/lite/tools/optimize/sparsity/BUILD b/tensorflow/lite/tools/optimize/sparsity/BUILD index 13a95f0c517205..6a1a447e19e297 100644 --- a/tensorflow/lite/tools/optimize/sparsity/BUILD +++ b/tensorflow/lite/tools/optimize/sparsity/BUILD @@ -35,7 +35,8 @@ py_strict_test( python_version = "PY3", deps = [ ":format_converter_wrapper_pybind11", - "//third_party/py/numpy", "@absl_py//absl/testing:absltest", + #internal proto upb dep + "//third_party/py/numpy", ], ) diff --git a/tensorflow/lite/tools/signature/BUILD b/tensorflow/lite/tools/signature/BUILD index d418b826ce57ad..161dcd1554d04b 100644 --- a/tensorflow/lite/tools/signature/BUILD +++ b/tensorflow/lite/tools/signature/BUILD @@ -104,6 +104,7 @@ py_strict_test( visibility = ["//visibility:public"], deps = [ ":signature_def_utils", + #internal proto upb dep "//tensorflow:tensorflow_py", "//tensorflow/core:protos_all_py", ], diff --git a/tensorflow/lite/tools/tflite-android.Dockerfile b/tensorflow/lite/tools/tflite-android.Dockerfile index d1981b0224d2d4..3d84412ccb49dd 100644 --- a/tensorflow/lite/tools/tflite-android.Dockerfile +++ b/tensorflow/lite/tools/tflite-android.Dockerfile @@ -9,8 +9,8 @@ RUN apt-get update && \ # Install Android SDK. ENV ANDROID_SDK_FILENAME commandlinetools-linux-6858069_latest.zip ENV ANDROID_SDK_URL https://dl.google.com/android/repository/${ANDROID_SDK_FILENAME} -ENV ANDROID_API_LEVEL 23 -ENV ANDROID_NDK_API_LEVEL 21 +ENV ANDROID_API_LEVEL 30 +ENV ANDROID_NDK_API_LEVEL 30 # Build Tools Version liable to change. ENV ANDROID_BUILD_TOOLS_VERSION 31.0.0 ENV ANDROID_SDK_HOME ${ANDROID_DEV_HOME}/sdk @@ -23,7 +23,7 @@ RUN cd ${ANDROID_DEV_HOME} && \ rm ${ANDROID_SDK_FILENAME} # Install Android NDK. -ENV ANDROID_NDK_FILENAME android-ndk-r21e-linux-x86_64.zip +ENV ANDROID_NDK_FILENAME android-ndk-r25b-linux.zip ENV ANDROID_NDK_URL https://dl.google.com/android/repository/${ANDROID_NDK_FILENAME} ENV ANDROID_NDK_HOME ${ANDROID_DEV_HOME}/ndk ENV PATH ${PATH}:${ANDROID_NDK_HOME} diff --git a/tensorflow/lite/tools/verifier.h b/tensorflow/lite/tools/verifier.h index 93bc5433c80e4e..f90d77b558fea0 100644 --- a/tensorflow/lite/tools/verifier.h +++ b/tensorflow/lite/tools/verifier.h @@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #ifndef TENSORFLOW_LITE_TOOLS_VERIFIER_H_ #define TENSORFLOW_LITE_TOOLS_VERIFIER_H_ -#include "tensorflow/lite/core/tools/verifier.h" +/// For documentation, see third_party/tensorflow/lite/core/tools/verifier.h + +#include "tensorflow/lite/core/tools/verifier.h" // IWYU pragma: export namespace tflite { diff --git a/tensorflow/lite/tools/verifier_internal.h b/tensorflow/lite/tools/verifier_internal.h index a3f499bc1fd10f..88380466877e50 100644 --- a/tensorflow/lite/tools/verifier_internal.h +++ b/tensorflow/lite/tools/verifier_internal.h @@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ - #ifndef TENSORFLOW_LITE_TOOLS_VERIFIER_INTERNAL_H_ #define TENSORFLOW_LITE_TOOLS_VERIFIER_INTERNAL_H_ -#include "tensorflow/lite/core/tools/verifier_internal.h" +/// For documentation, see third_party/tensorflow/lite/core/tools/verifier_internal.h + +#include "tensorflow/lite/core/tools/verifier_internal.h" // IWYU pragma: export namespace tflite { namespace internal { diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index e05b6419f863db..bba285328d2527 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -721,6 +721,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { } case BuiltinOperator_GATHER_ND: + if (op_sig.inputs.at(0).type == kTfLiteBool) { + return 5; + } if (op_sig.inputs.at(1).type == kTfLiteInt16) { return 4; } diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc index 3d2e055894f978..5cff633f0ee0d0 100644 --- a/tensorflow/lite/tools/versioning/op_version_test.cc +++ b/tensorflow/lite/tools/versioning/op_version_test.cc @@ -1047,6 +1047,13 @@ TEST(OpVersionTest, VersioningGatherNdOperatorTest) { std::vector{kTfLiteInt32, kTfLiteInt16}), }; EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + + fake_op_sig = { + .op = BuiltinOperator_GATHER_ND, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteBool, kTfLiteInt16}), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); } TEST(OpVersionTest, VersioningDivTest) { OpSignature fake_op_sig = { diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index 47282cbf371e9a..d011a5d5438e46 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -143,6 +143,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_GATHER_ND, 2}, "2.3.0"}, {{BuiltinOperator_GATHER_ND, 3}, "2.5.0"}, {{BuiltinOperator_GATHER_ND, 4}, "2.13.0"}, + {{BuiltinOperator_GATHER_ND, 5}, "2.16.0"}, {{BuiltinOperator_HASHTABLE_LOOKUP, 1}, "1.5.0"}, {{BuiltinOperator_SVDF, 1}, "1.5.0"}, {{BuiltinOperator_SVDF, 2}, "1.14.0"}, @@ -439,7 +440,8 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_STABLEHLO_MULTIPLY, 1}, "2.16.0"}, {{BuiltinOperator_STABLEHLO_REDUCE_WINDOW, 1}, "2.16.0"}, {{BuiltinOperator_STABLEHLO_MAXIMUM, 1}, "2.16.0"}, - {{BuiltinOperator_STABLEHLO_MINIMUM, 1}, "2.16.0"}}); + {{BuiltinOperator_STABLEHLO_MINIMUM, 1}, "2.16.0"}, + {{BuiltinOperator_STABLEHLO_PAD, 1}, "2.16.0"}}); std::pair version_key = {op_code, op_version}; auto it = op_version_map->find(version_key); diff --git a/tensorflow/lite/tutorials/BUILD b/tensorflow/lite/tutorials/BUILD index 9c34628d29418b..77c275d74651b6 100644 --- a/tensorflow/lite/tutorials/BUILD +++ b/tensorflow/lite/tutorials/BUILD @@ -1,6 +1,6 @@ # Example Estimator model -load("//tensorflow:strict.default.bzl", "py_strict_binary") +load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -10,11 +10,18 @@ package( py_strict_binary( name = "mnist_tflite", - srcs = [ - "dataset.py", - "mnist_tflite.py", - ], + srcs = ["mnist_tflite.py"], python_version = "PY3", + deps = [ + ":dataset", + 
"//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) + +py_strict_library( + name = "dataset", + srcs = ["dataset.py"], deps = [ "//tensorflow:tensorflow_py", "//third_party/py/numpy", diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 68f19f8d488a63..0a3015106ef946 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -175,6 +175,8 @@ tf_staging/tensorflow/tools/toolchains/BUILD: tf_staging/tensorflow/tools/toolchains/clang6/BUILD: tf_staging/tensorflow/tools/toolchains/cpus/py/BUILD: tf_staging/tensorflow/tools/toolchains/cpus/py3/BUILD: +tf_staging/tensorflow/tools/toolchains/cross_compile/cc/BUILD: +tf_staging/tensorflow/tools/toolchains/cross_compile/config/BUILD: tf_staging/tensorflow/tools/toolchains/embedded/arm-linux/BUILD: tf_staging/tensorflow/tools/toolchains/java/BUILD: tf_staging/tensorflow/tools/toolchains/python/BUILD: @@ -236,7 +238,9 @@ tf_staging/third_party/gpus/crosstool/BUILD: tf_staging/third_party/gpus/crosstool/LICENSE: tf_staging/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl: tf_staging/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl: +tf_staging/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl: tf_staging/third_party/gpus/cuda/BUILD.tpl: +tf_staging/third_party/gpus/cuda/BUILD.windows.tpl: tf_staging/third_party/gpus/cuda/BUILD: tf_staging/third_party/gpus/cuda/LICENSE: tf_staging/third_party/gpus/cuda/build_defs.bzl.tpl: diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a14198ec061347..d4a4799c0af207 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -83,7 +83,6 @@ py_strict_library( "//tensorflow/python/ops:gradient_checker_v2", "//tensorflow/python/ops:stateful_random_ops", "//tensorflow/python/ops/structured:structured_ops", - "//tensorflow/python/tpu:tpu_estimator", "//tensorflow/python/tpu:tpu_noestimator", ], ) @@ -171,6 +170,7 @@ py_strict_library( "//tensorflow/python/framework:_test_metrics_util", "//tensorflow/python/framework:combinations", "//tensorflow/python/framework:config", + "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:extension_type", "//tensorflow/python/framework:flexible_dtypes", @@ -350,6 +350,8 @@ py_strict_library( ":tf2", "//tensorflow/core:protos_all_py", "//tensorflow/core/function/trace_type", + "//tensorflow/python/checkpoint/sharding:sharding_policies", + "//tensorflow/python/checkpoint/sharding:sharding_util", "//tensorflow/python/client", "//tensorflow/python/client:device_lib", "//tensorflow/python/client:timeline", @@ -392,6 +394,7 @@ py_strict_library( "//tensorflow/python/lib/io:python_io", "//tensorflow/python/lib/io:tf_record", "//tensorflow/python/module", + "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:audio_ops_gen", "//tensorflow/python/ops:bincount_ops", "//tensorflow/python/ops:bitwise_ops", @@ -448,6 +451,7 @@ py_strict_library( "//tensorflow/python/profiler:trace", "//tensorflow/python/saved_model", "//tensorflow/python/summary:summary_py", + "//tensorflow/python/summary:tb_summary", "//tensorflow/python/tpu:tpu_noestimator", "//tensorflow/python/training", "//tensorflow/python/training:quantize_training", @@ -463,15 +467,6 @@ py_strict_library( ], ) -# Necessary for the pywrap inclusion below. 
-tf_pybind_cc_library_wrapper( - name = "tfcompile_headers_lib", - compatible_with = [], - deps = [ - "//tensorflow/compiler/aot:tfcompile_lib", - ], -) - tf_python_pybind_extension( name = "_pywrap_tfcompile", srcs = ["tfcompile_wrapper.cc"], @@ -481,15 +476,13 @@ tf_python_pybind_extension( "//tensorflow:windows": [], }), enable_stub_generation = True, - features = ["-layering_check"], pytype_srcs = [ "_pywrap_tfcompile.pyi", ], static_deps = tf_python_pybind_static_deps(), deps = [ - ":tfcompile_headers_lib", "@pybind11", - "//third_party/python_runtime:headers", + "//tensorflow/compiler/aot:tfcompile_lib", "//tensorflow/python/lib/core:pybind11_lib", "//tensorflow/python/lib/core:pybind11_status", # The headers here cannot be brought in via cc_header_only_library @@ -776,7 +769,6 @@ pywrap_tensorflow_macro( "//tensorflow/cc/saved_model:fingerprinting_impl", "//tensorflow/cc/saved_model:loader_lite_impl", "//tensorflow/cc/saved_model:metrics_impl", - "//tensorflow/compiler/mlir/python:mlir", "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/compiler/tf2tensorrt:op_converter_registry_impl", @@ -848,7 +840,12 @@ pywrap_tensorflow_macro( "@local_tsl//tsl/profiler/rpc/client:profiler_client_impl", "@local_tsl//tsl/python/lib/core:numpy", "@local_xla//xla/stream_executor:stream_executor_impl", - ] + if_static([ + ] + select({ + "//tensorflow/compiler/mlir/python:disable_mlir_config": [], + "//conditions:default": [ + "//tensorflow/compiler/mlir/python:mlir", + ], + }) + if_static([ "//tensorflow/core/platform:tensor_float_32_utils", "//tensorflow/core/platform:enable_tf2_utils", ]) + if_google([ @@ -886,7 +883,6 @@ filegroup( "//tensorflow/cc/saved_model:metrics_impl", # SavedModel metrics "//tensorflow/compiler/jit:flags", # tfe "//tensorflow/compiler/jit:get_compiler_ir", # tfe - "//tensorflow/compiler/mlir/python:mlir", # mlir "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl", # quantization "//tensorflow/compiler/tf2xla:tf2xla_opset", # pywrap_xla_ops "//tensorflow/core:framework_internal_impl", # op_def_registry @@ -961,7 +957,12 @@ filegroup( "@local_tsl//tsl/python/lib/core:ml_dtypes_lib", # bfloat16, float8_e4m3fn, float8_e5m2 "@local_tsl//tsl/python/lib/core:numpy", # checkpoint_reader "@local_xla//xla/stream_executor", # stat_summarizer - ] + if_xla_available([ + ] + select({ + "//tensorflow/compiler/mlir/python:disable_mlir_config": [], + "//conditions:default": [ + "//tensorflow/compiler/mlir/python:mlir", # mlir + ], + }) + if_xla_available([ "//tensorflow/compiler/aot:tfcompile_lib", # tfcompile "@local_xla//xla:status_macros", # tfcompile "@local_xla//xla/hlo/ir:hlo", # tfcompile diff --git a/tensorflow/python/_pywrap_tfe.pyi b/tensorflow/python/_pywrap_tfe.pyi index 26d129cd2a8566..1385ae69244d58 100644 --- a/tensorflow/python/_pywrap_tfe.pyi +++ b/tensorflow/python/_pywrap_tfe.pyi @@ -179,6 +179,7 @@ def TFE_ClearScalarCache() -> object: ... def TFE_CollectiveOpsCheckPeerHealth(arg0: object, arg1: str, arg2: int) -> None: ... def TFE_ContextAddFunction(arg0: object, arg1: TF_Function) -> None: ... def TFE_ContextAddFunctionDef(arg0: object, arg1: str, arg2: int) -> None: ... +def TFE_ContextAddFunctionDefNoSerialization(ctx: object, function_def) -> None: ... def TFE_ContextCheckAlive(arg0: object, arg1: str) -> bool: ... def TFE_ContextClearCaches(arg0: object) -> None: ... 
def TFE_ContextClearExecutors(arg0: object) -> None: ... diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index 5624f7611f3c84..82177fb9002207 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -200,6 +200,7 @@ py_strict_test( ":asserts", ":functions", ":return_statements", + #internal proto upb dep "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:errors", @@ -214,6 +215,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":break_statements", + #internal proto upb dep "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/autograph/pyct:anno", "//tensorflow/python/platform:client_testlib", @@ -228,6 +230,7 @@ py_strict_test( deps = [ ":call_trees", ":functions", + #internal proto upb dep "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/platform:client_testlib", ], @@ -240,6 +243,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":conditional_expressions", + #internal proto upb dep "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/platform:client_testlib", ], @@ -252,6 +256,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":continue_statements", + #internal proto upb dep "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/framework:ops", "//tensorflow/python/platform:client_testlib", @@ -267,6 +272,8 @@ py_strict_test( ":break_statements", ":continue_statements", ":control_flow", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", @@ -276,7 +283,6 @@ py_strict_test( "//tensorflow/python/framework:tensor_util", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/util:nest", - "//third_party/py/numpy", ], ) @@ -287,6 +293,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":directives", + #internal proto upb dep "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/autograph/lang:directives", "//tensorflow/python/autograph/pyct:anno", @@ -301,6 +308,7 @@ py_strict_test( deps = [ ":functions", ":return_statements", + #internal proto upb dep "//tensorflow/python/autograph/core:ag_ctx", "//tensorflow/python/autograph/core:converter", "//tensorflow/python/autograph/core:test_lib", @@ -318,6 +326,7 @@ py_strict_test( deps = [ ":directives", ":lists", + #internal proto upb dep "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/autograph/lang:directives", "//tensorflow/python/autograph/lang:special_functions", @@ -336,6 +345,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":logical_expressions", + #internal proto upb dep "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:test_lib", @@ -351,6 +361,7 @@ py_strict_test( deps = [ ":functions", ":return_statements", + #internal proto upb dep "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/framework:ops", "//tensorflow/python/platform:client_testlib", @@ -365,6 +376,7 @@ py_strict_test( deps = [ ":directives", ":slices", + #internal proto upb dep "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/autograph/lang:directives", "//tensorflow/python/framework:constant_op", @@ -381,6 +393,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":variables", + #internal proto upb dep 
"//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/platform:client_testlib", ], diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD index 46983ab39f0a2b..d1d4ee16fe1761 100644 --- a/tensorflow/python/autograph/core/BUILD +++ b/tensorflow/python/autograph/core/BUILD @@ -91,6 +91,7 @@ py_strict_test( deps = [ ":converter", ":test_lib", + #internal proto upb dep "//tensorflow/python/autograph/pyct:anno", "//tensorflow/python/autograph/pyct:loader", "//tensorflow/python/autograph/pyct:parser", @@ -107,6 +108,7 @@ py_strict_test( deps = [ ":converter", ":function_wrappers", + #internal proto upb dep "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", "//tensorflow/python/ops:variables", diff --git a/tensorflow/python/autograph/lang/BUILD b/tensorflow/python/autograph/lang/BUILD index d9207ac75a1b87..f857454188571f 100644 --- a/tensorflow/python/autograph/lang/BUILD +++ b/tensorflow/python/autograph/lang/BUILD @@ -31,12 +31,13 @@ py_strict_test( python_version = "PY3", srcs_version = "PY3", deps = [ + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow/python/autograph/lang:special_functions", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:list_ops", "//tensorflow/python/platform:client_testlib", - "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD index d3ab48bbf2c245..25dd28737fce2e 100644 --- a/tensorflow/python/autograph/operators/BUILD +++ b/tensorflow/python/autograph/operators/BUILD @@ -159,6 +159,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":data_structures", + #internal proto upb dep "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:tensor", @@ -176,6 +177,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":conditional_expressions", + #internal proto upb dep "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:test_lib", @@ -194,6 +196,8 @@ py_strict_test( deps = [ ":control_flow", ":variables", + #internal proto upb dep + "//third_party/py/numpy", "//tensorflow/python/autograph/utils:ag_logging", "//tensorflow/python/autograph/utils:testing", "//tensorflow/python/data/ops:dataset_ops", @@ -209,7 +213,6 @@ py_strict_test( "//tensorflow/python/ops:random_ops", "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/platform:client_testlib", - "//third_party/py/numpy", ], ) @@ -220,6 +223,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":exceptions", + #internal proto upb dep "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:test_lib", @@ -234,6 +238,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":logical", + #internal proto upb dep "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:test_lib", "//tensorflow/python/platform:client_testlib", @@ -248,6 +253,7 @@ py_strict_test( deps = [ ":data_structures", ":py_builtins", + #internal proto upb dep "//tensorflow/python/autograph/core:converter", "//tensorflow/python/autograph/core:function_wrappers", "//tensorflow/python/data/ops:dataset_ops", @@ -270,6 +276,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":slices", + #internal proto upb dep "//tensorflow/python/framework:constant_op", 
"//tensorflow/python/ops:list_ops", "//tensorflow/python/platform:client_testlib", @@ -283,6 +290,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":variables", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", ], ) diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD index 949d841e00cc49..442823158b5f8e 100644 --- a/tensorflow/python/autograph/pyct/BUILD +++ b/tensorflow/python/autograph/pyct/BUILD @@ -193,6 +193,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":anno", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", ], ) @@ -209,8 +210,9 @@ py_strict_test( ":parser", ":pretty_printer", ":qual_names", - "//tensorflow/python/platform:client_testlib", "@pypi_gast//:pkg", + #internal proto upb dep + "//tensorflow/python/platform:client_testlib", ], ) @@ -221,6 +223,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":cache", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", ], ) @@ -233,8 +236,9 @@ py_strict_test( deps = [ ":cfg", ":parser", - "//tensorflow/python/platform:client_testlib", "@pypi_gast//:pkg", + #internal proto upb dep + "//tensorflow/python/platform:client_testlib", ], ) @@ -248,9 +252,10 @@ py_strict_test( ":loader", ":parser", ":pretty_printer", + "@pypi_gast//:pkg", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", "//tensorflow/python/util:tf_inspect", - "@pypi_gast//:pkg", ], ) @@ -262,6 +267,7 @@ py_strict_test( deps = [ ":error_utils", ":origin_info", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", ], ) @@ -273,6 +279,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":inspect_utils", + #internal proto upb dep "//tensorflow/python/autograph/pyct/testing:basic_definitions", "//tensorflow/python/autograph/pyct/testing:decorators", "//tensorflow/python/framework:constant_op", @@ -294,6 +301,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":naming", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", ], ) @@ -308,6 +316,7 @@ py_strict_test( ":inspect_utils", ":origin_info", ":parser", + #internal proto upb dep "//tensorflow/python/autograph/pyct/testing:basic_definitions", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/util:tf_inspect", @@ -324,8 +333,9 @@ py_strict_test( ":errors", ":parser", ":pretty_printer", - "//tensorflow/python/platform:client_testlib", "@pypi_gast//:pkg", + #internal proto upb dep + "//tensorflow/python/platform:client_testlib", ], ) @@ -336,6 +346,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":pretty_printer", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", ], ) @@ -349,6 +360,7 @@ py_strict_test( ":anno", ":parser", ":qual_names", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", ], ) @@ -363,9 +375,10 @@ py_strict_test( ":parser", ":qual_names", ":templates", - "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", "@pypi_gast//:pkg", + #internal proto upb dep + "//tensorflow/python/platform:client_testlib", ], ) @@ -379,8 +392,9 @@ py_strict_test( ":origin_info", ":parser", ":transformer", - "//tensorflow/python/platform:client_testlib", "@pypi_gast//:pkg", + #internal proto upb dep + "//tensorflow/python/platform:client_testlib", ], ) @@ -392,7 +406,8 @@ py_strict_test( deps = [ ":transformer", ":transpiler", - "//tensorflow/python/platform:client_testlib", "@pypi_gast//:pkg", + #internal proto upb dep + 
"//tensorflow/python/platform:client_testlib", ], ) diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py index fd8ddf046d29e9..3c4f0ac15919e6 100644 --- a/tensorflow/python/autograph/pyct/cfg.py +++ b/tensorflow/python/autograph/pyct/cfg.py @@ -780,6 +780,11 @@ def visit_ImportFrom(self, node): def visit_Expr(self, node): self._process_basic_statement(node) + def visit_NamedExpr(self, node): + # TODO(yileiyang): Add a test case once we have a newer astunparse version. + # NamedExpr was introduced in Python 3.8 and supported in gast 0.5.1+. + self._process_basic_statement(node) + def visit_Assign(self, node): self._process_basic_statement(node) diff --git a/tensorflow/python/autograph/pyct/common_transformers/BUILD b/tensorflow/python/autograph/pyct/common_transformers/BUILD index 2be00498cf7d4d..44160a7f3f22f2 100644 --- a/tensorflow/python/autograph/pyct/common_transformers/BUILD +++ b/tensorflow/python/autograph/pyct/common_transformers/BUILD @@ -28,10 +28,11 @@ py_strict_test( tags = ["no_oss"], deps = [ ":common_transformers", + "@pypi_gast//:pkg", + #internal proto upb dep "//tensorflow/python/autograph/pyct:loader", "//tensorflow/python/autograph/pyct:parser", "//tensorflow/python/autograph/pyct:transformer", "//tensorflow/python/platform:client_testlib", - "@pypi_gast//:pkg", ], ) diff --git a/tensorflow/python/autograph/pyct/static_analysis/BUILD b/tensorflow/python/autograph/pyct/static_analysis/BUILD index 4329523b0562de..7e5011fa2d9c16 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/BUILD +++ b/tensorflow/python/autograph/pyct/static_analysis/BUILD @@ -37,6 +37,7 @@ py_strict_test( ":activity", ":reaching_definitions", ":reaching_fndefs", + #internal proto upb dep "//tensorflow/python/autograph/pyct:anno", "//tensorflow/python/autograph/pyct:cfg", "//tensorflow/python/autograph/pyct:naming", @@ -101,13 +102,14 @@ py_strict_test( deps = [ ":activity", ":annos", + "@pypi_gast//:pkg", + #internal proto upb dep "//tensorflow/python/autograph/pyct:anno", "//tensorflow/python/autograph/pyct:naming", "//tensorflow/python/autograph/pyct:parser", "//tensorflow/python/autograph/pyct:qual_names", "//tensorflow/python/autograph/pyct:transformer", "//tensorflow/python/platform:client_testlib", - "@pypi_gast//:pkg", ], ) @@ -121,6 +123,7 @@ py_strict_test( ":activity", ":liveness", ":reaching_fndefs", + #internal proto upb dep "//tensorflow/python/autograph/pyct:anno", "//tensorflow/python/autograph/pyct:cfg", "//tensorflow/python/autograph/pyct:naming", @@ -139,6 +142,7 @@ py_strict_test( deps = [ ":activity", ":reaching_definitions", + #internal proto upb dep "//tensorflow/python/autograph/pyct:anno", "//tensorflow/python/autograph/pyct:cfg", "//tensorflow/python/autograph/pyct:naming", @@ -159,6 +163,7 @@ py_strict_test( ":reaching_definitions", ":reaching_fndefs", ":type_inference", + #internal proto upb dep "//tensorflow/python/autograph/pyct:anno", "//tensorflow/python/autograph/pyct:cfg", "//tensorflow/python/autograph/pyct:qual_names", diff --git a/tensorflow/python/autograph/pyct/testing/BUILD b/tensorflow/python/autograph/pyct/testing/BUILD index 21a6775b0fb539..51d186363ebb2d 100644 --- a/tensorflow/python/autograph/pyct/testing/BUILD +++ b/tensorflow/python/autograph/pyct/testing/BUILD @@ -45,7 +45,8 @@ py_strict_test( ], deps = [ ":codegen", - "//tensorflow/python/platform:client_testlib", + #internal proto upb dep "//third_party/py/numpy", + "//tensorflow/python/platform:client_testlib", ], ) diff --git 
a/tensorflow/python/autograph/utils/BUILD b/tensorflow/python/autograph/utils/BUILD index d758c28801c315..f5aad03ed8fd8c 100644 --- a/tensorflow/python/autograph/utils/BUILD +++ b/tensorflow/python/autograph/utils/BUILD @@ -101,6 +101,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":context_managers", + #internal proto upb dep "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/ops:tensor_array_ops", @@ -115,6 +116,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":misc", + #internal proto upb dep "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:test_lib", @@ -130,6 +132,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":tensor_list", + #internal proto upb dep "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", @@ -148,6 +151,7 @@ py_strict_test( srcs_version = "PY3", deps = [ ":tensors", + #internal proto upb dep "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/ops:list_ops", diff --git a/tensorflow/python/checkpoint/BUILD b/tensorflow/python/checkpoint/BUILD index 0c9d5c696b4c20..11c2986f8be6a8 100644 --- a/tensorflow/python/checkpoint/BUILD +++ b/tensorflow/python/checkpoint/BUILD @@ -5,6 +5,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test") load( "//tensorflow/tools/test:performance.bzl", + "tf_py_benchmark_test", "tf_py_logged_benchmark", ) @@ -99,8 +100,10 @@ py_strict_library( "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:object_identity", - "//tensorflow/python/util:tf_decorator", + "//tensorflow/python/util:tf_contextlib", "//tensorflow/python/util:tf_export", + "//tensorflow/python/util:tf_inspect", + "@absl_py//absl/logging", ], ) @@ -363,7 +366,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "benchmarks_test", srcs = ["benchmarks_test.py"], deps = [ @@ -389,6 +392,7 @@ py_strict_library( srcs = ["checkpoint_options.py"], srcs_version = "PY3", deps = [ + "//tensorflow/python/checkpoint/sharding:sharding_util", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", ], @@ -401,6 +405,8 @@ py_strict_library( deps = [ ":checkpoint_options", "//tensorflow/core:protos_all_py", + "//tensorflow/python/checkpoint/sharding:sharding_policies", + "//tensorflow/python/checkpoint/sharding:sharding_util", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", @@ -408,19 +414,18 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor", - "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:io_ops", "//tensorflow/python/ops:io_ops_gen", "//tensorflow/python/ops:string_ops", - "//tensorflow/python/ops:variables", "//tensorflow/python/saved_model/registration", "//tensorflow/python/trackable:base", "//tensorflow/python/trackable:trackable_utils", "//tensorflow/python/training/saving:saveable_object", "//tensorflow/python/training/saving:saveable_object_util", + "//tensorflow/python/types:core", 
"//tensorflow/python/util:nest", "//tensorflow/python/util:object_identity", ], @@ -446,10 +451,10 @@ cuda_py_strict_test( "//tensorflow/python/framework:ops", "//tensorflow/python/framework:test_lib", "//tensorflow/python/module", + "//tensorflow/python/ops:io_ops_gen", "//tensorflow/python/ops:resource_variable_ops", "//tensorflow/python/platform:gfile", "//tensorflow/python/training:server_lib", - "//tensorflow/python/training/saving:saveable_object", "//tensorflow/python/training/saving:saveable_object_util", ], ) diff --git a/tensorflow/python/checkpoint/checkpoint_options.py b/tensorflow/python/checkpoint/checkpoint_options.py index 662fdcc455c4a3..7a081b80377ce9 100644 --- a/tensorflow/python/checkpoint/checkpoint_options.py +++ b/tensorflow/python/checkpoint/checkpoint_options.py @@ -17,6 +17,7 @@ import copy import inspect +from tensorflow.python.checkpoint.sharding import sharding_util from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.tf_export import tf_export @@ -45,6 +46,7 @@ class CheckpointOptions(object): "experimental_enable_async_checkpoint", "experimental_write_callbacks", "enable_async", + "experimental_sharding_callback", ) @deprecated_args( @@ -56,6 +58,7 @@ def __init__( experimental_enable_async_checkpoint=False, experimental_write_callbacks=None, enable_async=False, + experimental_sharding_callback=None, ): """Creates an object that stores options for a Checkpoint. @@ -91,6 +94,13 @@ def __init__( writing runs in the background. Async checkpoint reduces TPU device idle cycles and speeds up model training process, while memory consumption may increase. + + experimental_sharding_callback: `tf.train.experimental.ShardingCallback`. + A pre-made or custom callback that determines how checkpoints are + sharded on disk. Pre-made callback options are + `tf.train.experimental.ShardByDevicePolicy` and + `tf.train.experimental.MaxShardSizePolicy`. You may also write a custom + callback, see `tf.train.experimental.ShardingCallback`. """ self.experimental_io_device = experimental_io_device self.enable_async = experimental_enable_async_checkpoint or enable_async @@ -100,6 +110,13 @@ def __init__( for callback in experimental_write_callbacks: assert len(inspect.signature(callback).parameters) <= 1 self.experimental_write_callbacks = experimental_write_callbacks + if experimental_sharding_callback is not None: + if not isinstance( + experimental_sharding_callback, sharding_util.ShardingCallback): + raise ValueError("The experimental_sharding_callback checkpoint option" + "must be of type ShardingCallback. 
The option provided" + f"was of type {type(experimental_sharding_callback)}.") + self.experimental_sharding_callback = experimental_sharding_callback def __copy__(self): # Only `experimental_write_callbacks` needs special treatment to Ensure that diff --git a/tensorflow/python/checkpoint/functional_saver.py b/tensorflow/python/checkpoint/functional_saver.py index bd2868013ed3ef..6c918d3bd969a6 100644 --- a/tensorflow/python/checkpoint/functional_saver.py +++ b/tensorflow/python/checkpoint/functional_saver.py @@ -15,10 +15,12 @@ """Saves and restore variables inside traced @tf.functions.""" import dataclasses -from typing import Callable, Dict, List +from typing import Callable, Mapping, MutableMapping, MutableSequence, Sequence from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.checkpoint import checkpoint_options +from tensorflow.python.checkpoint.sharding import sharding_policies +from tensorflow.python.checkpoint.sharding import sharding_util from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op @@ -26,181 +28,125 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor as tensor_lib -from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import string_ops -from tensorflow.python.ops import variables from tensorflow.python.saved_model import registration from tensorflow.python.trackable import base from tensorflow.python.trackable import trackable_utils from tensorflow.python.training.saving import saveable_object from tensorflow.python.training.saving import saveable_object_util +from tensorflow.python.types import core from tensorflow.python.util import nest from tensorflow.python.util import object_identity -@dataclasses.dataclass(frozen=True) -class ShardableTensor: - """Tensor wrapper containing data necessary for sharding.""" - _tensor_save_spec: saveable_object.SaveSpec - tensor: tensor_lib.Tensor - dtype: dtypes.DType - device: device_lib.DeviceSpec - name: str - shape: tensor_shape.TensorShape - slice_spec: variables.Variable.SaveSliceInfo - checkpoint_key: str - trackable: base.Trackable - - def __hash__(self): - return hash((self.name, self.dtype, str(self.device), self.checkpoint_key)) - - -@dataclasses.dataclass(frozen=True) -class ShardingCallback: - """Checkpoint sharding callback function, along with a text description.""" - callback: Callable[ - [List[ShardableTensor], ...], - List[Dict[str, Dict[tensor_spec.TensorSpec, saveable_object.SaveSpec]]]] - description: str - - def __hash__(self): - if hasattr(self.callback, "__name__"): - callback_hash = hash((self.callback.__module__, self.callback.__name__)) - else: - callback_hash = id(self.callback) - return hash((callback_hash, self.description)) - - -class ShardByDevicePolicy(ShardingCallback): - """Policy that splits tensors into shards based on their device spec.""" - - def __init__(self): - def device_callback_impl(shardable_tensors): - """Callback to split tensors into shards based on their device spec. - - Args: - shardable_tensors: A list of ShardableTensors. - - Returns: - List of shard dicts containing tensors. 
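# --- Editor's illustration (not part of the patch) --------------------------
# A minimal usage sketch of the new `experimental_sharding_callback` option
# added to CheckpointOptions above. The pre-made policy name comes from that
# docstring (`tf.train.experimental.ShardByDevicePolicy`); its no-argument
# constructor and the export path are assumptions of this sketch.
import tensorflow as tf

ckpt = tf.train.Checkpoint(v=tf.Variable(1.0))
options = tf.train.CheckpointOptions(
    experimental_sharding_callback=tf.train.experimental.ShardByDevicePolicy())
ckpt.write("/tmp/sharded_ckpt", options=options)
# -----------------------------------------------------------------------------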
- [ {checkpoint key: {slice_spec: tensor} } ] - """ - tensors_by_device = {} - - for shardable_tensor in shardable_tensors: - tensor = shardable_tensor.tensor - checkpoint_key = shardable_tensor.checkpoint_key - slice_spec = shardable_tensor.slice_spec - device = saveable_object_util.set_cpu0(shardable_tensor.device) +RegisteredSaversDict = Mapping[ + registration.RegisteredSaver, Mapping[str, base.Trackable]] +MappedCapturesCallable = Callable[ + [core.ConcreteFunction, Sequence[tensor_lib.Tensor]], tensor_lib.Tensor] - (tensors_by_device - .setdefault(device, {}) - .setdefault(checkpoint_key, {})[slice_spec]) = tensor - return list(tensors_by_device.values()) +def _single_shard_save( + file_prefix: tensor_lib.Tensor, + shard: sharding_util.TensorSliceDict, + task: device_lib.DeviceSpec, + options: "checkpoint_options.CheckpointOptions | None" = None, +) -> ops.Operation: + """Save the saveable objects to a checkpoint with `file_prefix`. - super().__init__( - device_callback_impl, - "Split tensors into shards based on their device spec.") - - def __call__(self, shardable_tensors): - return self.callback(shardable_tensors) # pylint: disable=no-value-for-parameter - - -class _SingleDeviceSaver(object): - """Saves and restores checkpoints from the current device.""" - - __slots__ = ["_tensor_slice_dict"] - - def __init__(self, tensor_slice_dict): - """Specify a list of `SaveableObject`s to save and restore. - - Args: - tensor_slice_dict: A dict mapping checkpoint key -> slice_spec -> tensor. - """ - self._tensor_slice_dict = tensor_slice_dict - - def save(self, file_prefix, options=None): - """Save the saveable objects to a checkpoint with `file_prefix`. + Args: + file_prefix: A string or scalar string Tensor containing the prefix to + save under. + shard: Dict containing tensors. {checkpoint key: {slice_spec: tensor} } + task: The device spec task of the tensors in the shard. + options: Optional `CheckpointOptions` object. - Args: - file_prefix: A string or scalar string Tensor containing the prefix to - save under. - options: Optional `CheckpointOptions` object. - Returns: - An `Operation`, or None when executing eagerly. - """ - options = options or checkpoint_options.CheckpointOptions() - tensor_names = [] - tensors = [] - slice_specs = [] - for checkpoint_key, tensor_slices in self._tensor_slice_dict.items(): - for slice_spec, tensor in tensor_slices.items(): - if isinstance(tensor, saveable_object.SaveSpec): - tensor_value = tensor.tensor - # A tensor value of `None` indicates that this SaveableObject gets - # recorded in the object graph, but that no value is saved in the - # checkpoint. - if tensor_value is not None: - tensor_names.append(tensor.name) - tensors.append(tensor_value) - slice_specs.append(tensor.slice_spec) - else: - tensor_names.append(checkpoint_key) - tensors.append(tensor) - slice_specs.append(slice_spec) - save_device = options.experimental_io_device or ( - len(tensors) and saveable_object_util.set_cpu0(tensors[0].device)) - save_device = save_device or "cpu:0" - with ops.device(save_device): - return io_ops.save_v2(file_prefix, tensor_names, slice_specs, tensors) - - def restore(self, file_prefix, options=None): - """Restore the saveable objects from a checkpoint with `file_prefix`. + Returns: + An `Operation`, or None when executing eagerly. 
+ """ + options = options or checkpoint_options.CheckpointOptions() + + tensor_names = [] + tensors = [] + slice_specs = [] + for checkpoint_key, tensor_slices in shard.items(): + for slice_spec, tensor in tensor_slices.items(): + # A tensor value of `None` indicates that this SaveableObject gets + # recorded in the object graph, but that no value is saved in the + # checkpoint. + if tensor is not None: + # See `MultiDeviceSaver._get_shards_by_task` for an explanation on the + # wrapped properties. + name = (tensor._wrapped_name # pylint: disable=protected-access + if hasattr(tensor, "_wrapped_name") + else checkpoint_key) + spec = (tensor._wrapped_slice_spec # pylint: disable=protected-access + if hasattr(tensor, "_wrapped_slice_spec") + else slice_spec) + + tensor_names.append(name) + tensors.append(tensor) + slice_specs.append(spec) + + save_device = options.experimental_io_device or (len(tensors) and task) + with ops.device(save_device or "CPU:0"): + return io_ops.save_v2(file_prefix, tensor_names, slice_specs, tensors) + + +def _single_shard_restore( + file_prefix: tensor_lib.Tensor, + shardable_tensors: Sequence[sharding_util.ShardableTensor], + options: "checkpoint_options.CheckpointOptions | None" = None +) -> sharding_util.TensorSliceDict: + """Restore the saveable objects from a checkpoint with `file_prefix`. - Args: - file_prefix: A string or scalar string Tensor containing the prefix for - files to read from. - options: Optional `CheckpointOptions` object. + Args: + file_prefix: A string or scalar string Tensor containing the prefix for + files to read from. + shardable_tensors: A list of ShardableTensors to restore. + options: Optional `CheckpointOptions` object. - Returns: - A restored tensor dict (maps checkpoint_key -> slice_spec -> tensor). - """ - options = options or checkpoint_options.CheckpointOptions() - tensor_names = [] - tensor_dtypes = [] - slice_specs = [] - - for checkpoint_key, tensor_slices in self._tensor_slice_dict.items(): - for slice_spec, tensor in tensor_slices.items(): - tensor_dtypes.append(tensor.dtype) - if isinstance(tensor, saveable_object.SaveSpec): - slice_specs.append(tensor.slice_spec) - tensor_names.append(tensor.name) - else: - slice_specs.append(slice_spec) - tensor_names.append(checkpoint_key) - - restore_device = options.experimental_io_device or "cpu:0" - with ops.device(restore_device): - restored_tensors = io_ops.restore_v2( - file_prefix, tensor_names, slice_specs, tensor_dtypes) - - restored_tensor_dict = {} - for checkpoint_key, tensor_slices in self._tensor_slice_dict.items(): - for slice_spec in tensor_slices: - restored_tensor = restored_tensors.pop(0) - restored_tensor_dict.setdefault(checkpoint_key, {})[slice_spec] = ( - restored_tensor) - return restored_tensor_dict - - -def sharded_filename(filename_tensor, shard, num_shards): + Returns: + A restored tensor dict (maps checkpoint_key -> slice_spec -> tensor). 
+ """ + options = options or checkpoint_options.CheckpointOptions() + + tensor_names = [] + tensor_dtypes = [] + slice_specs = [] + for shardable_tensor in shardable_tensors: + if shardable_tensor._tensor_save_spec: # pylint: disable=protected-access + name = shardable_tensor._tensor_save_spec.name # pylint: disable=protected-access + spec = shardable_tensor._tensor_save_spec.slice_spec # pylint: disable=protected-access + else: + name, spec = shardable_tensor.checkpoint_key, shardable_tensor.slice_spec + tensor_names.append(name) + slice_specs.append(spec) + tensor_dtypes.append(shardable_tensor.dtype) + + restore_device = options.experimental_io_device or "cpu:0" + with ops.device(restore_device): + restored_tensors = io_ops.restore_v2( + file_prefix, tensor_names, slice_specs, tensor_dtypes) + + restored_tensor_dict = {} + for shardable_tensor in shardable_tensors: + restored_tensor = restored_tensors.pop(0) + (restored_tensor_dict + .setdefault(shardable_tensor.checkpoint_key, {} + )[shardable_tensor.slice_spec]) = restored_tensor + return restored_tensor_dict + + +def sharded_filename( + filename_tensor: tensor_lib.Tensor, + shard: int, + num_shards: tensor_lib.Tensor +) -> tensor_lib.Tensor: """Append sharding information to a filename. Args: @@ -214,15 +160,22 @@ def sharded_filename(filename_tensor, shard, num_shards): return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards) -def registered_saver_filename(filename_tensor, saver_name): +def registered_saver_filename( + filename_tensor: tensor_lib.Tensor, + saver_name: registration.RegisteredSaver +) -> tensor_lib.Tensor: return string_ops.string_join( [filename_tensor, constant_op.constant(f"-{saver_name}")]) -def _get_mapped_registered_save_fn(fn, trackables, call_with_mapped_captures): +def _get_mapped_registered_save_fn( + fn: Callable[..., tensor_lib.Tensor], + trackables: Sequence[base.Trackable], + call_with_mapped_captures: MappedCapturesCallable +) -> Callable[[tensor_lib.Tensor], MappedCapturesCallable]: """Converts the function to a python or tf.function with a single file arg.""" - def save_fn(file_prefix): + def save_fn(file_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor: return fn(trackables=trackables, file_prefix=file_prefix) if call_with_mapped_captures is None: return save_fn @@ -231,17 +184,21 @@ def save_fn(file_prefix): concrete = tf_fn.get_concrete_function( file_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string)) - def save_fn_with_replaced_captures(file_prefix): + def save_fn_with_replaced_captures( + file_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor: return call_with_mapped_captures(concrete, [file_prefix]) return save_fn_with_replaced_captures -def _get_mapped_registered_restore_fn(fn, trackables, - call_with_mapped_captures): +def _get_mapped_registered_restore_fn( + fn: Callable[..., tensor_lib.Tensor], + trackables: Sequence[base.Trackable], + call_with_mapped_captures: MappedCapturesCallable +) -> Callable[..., tensor_lib.Tensor]: """Converts the function to a python or tf.function with a single file arg.""" - def restore_fn(merged_prefix): + def restore_fn(merged_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor: return fn(trackables=trackables, merged_prefix=merged_prefix) if call_with_mapped_captures is None: return restore_fn @@ -250,7 +207,8 @@ def restore_fn(merged_prefix): concrete = tf_fn.get_concrete_function( merged_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string)) - def restore_fn_with_replaced_captures(merged_prefix): + def 
restore_fn_with_replaced_captures( + merged_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor: return call_with_mapped_captures(concrete, [merged_prefix]) return restore_fn_with_replaced_captures @@ -259,7 +217,7 @@ def restore_fn_with_replaced_captures(merged_prefix): _restore_noop = lambda *args, **kwargs: None -class MultiDeviceSaver(object): +class MultiDeviceSaver: """Saves checkpoints directly from multiple devices. Note that this is a low-level utility which stores Tensors in the keys @@ -267,10 +225,12 @@ class MultiDeviceSaver(object): checkpointing are built on top of it. """ - def __init__(self, - serialized_tensors, - registered_savers=None, - call_with_mapped_captures=None): + def __init__( + self, + serialized_tensors: Mapping[ + base.Trackable, sharding_util.TensorSliceDict], + registered_savers: "RegisteredSaversDict | None" = None, + call_with_mapped_captures: "MappedCapturesCallable | None" = None): """Specify a list of `SaveableObject`s to save and restore. Args: @@ -284,24 +244,37 @@ def __init__(self, Trackable in the checkpoint. call_with_mapped_captures: TODO """ + self._shardable_tensors: MutableSequence[sharding_util.ShardableTensor] = [] # Keep these two data structures so that we can map restored tensors to # the Trackable restore functions. - self._keys_to_restore_fn = {} - self._restore_fn_to_keys = {} - - # Extract serialized tensors and separate by device. - tensors_by_device = {} # device -> checkpoint key -> (slice_spec ->) tensor - + self._keys_to_restore_fn: MutableMapping[ + sharding_util.TensorSlice, + Callable[Mapping[str, tensor_lib.Tensor]]] = {} + self._restore_fn_to_keys: MutableMapping[ + Callable[Mapping[str, tensor_lib.Tensor]], + MutableSequence[sharding_util.TensorSlice]] = {} + + unique_tasks = set() for obj, tensor_dict in serialized_tensors.items(): restore_fn = _restore_noop if obj is None else obj._restore_from_tensors - # Divide tensor_dict by device. - for checkpoint_key, maybe_tensor in tensor_dict.items(): - if not isinstance(maybe_tensor, dict): + # Divide tensor_dict by task. + for checkpoint_key, tensor_slice_dict in tensor_dict.items(): + if not isinstance(tensor_slice_dict, dict): # Make sure that maybe_tensor is structured as {slice_spec -> tensor}. 
- maybe_tensor = {"": maybe_tensor} + tensor_slice_dict = {"": tensor_slice_dict} + + for slice_spec, tensor_save_spec in tensor_slice_dict.items(): + tensor_value = None + if not isinstance(tensor_save_spec, saveable_object.SaveSpec): + tensor_value = tensor_save_spec + tensor_save_spec = saveable_object.SaveSpec( + tensor=tensor_value, + slice_spec=slice_spec, + name=checkpoint_key, + dtype=tensor_save_spec.dtype, + device=tensor_save_spec.device) - for slice_spec, tensor in maybe_tensor.items(): if (checkpoint_key, slice_spec) in self._keys_to_restore_fn: raise ValueError( "Recieved multiple tensors with the same checkpoint key and " @@ -312,13 +285,24 @@ def __init__(self, self._restore_fn_to_keys.setdefault(restore_fn, []).append( (checkpoint_key, slice_spec)) - host_device = saveable_object_util.set_cpu0(tensor.device) - (tensors_by_device - .setdefault(host_device, {}) - .setdefault(checkpoint_key, {})[slice_spec]) = tensor - self._single_device_savers = { - device: _SingleDeviceSaver(tensor_slice_dict) - for device, tensor_slice_dict in tensors_by_device.items()} + device = (device_lib.DeviceSpec.from_string(tensor_save_spec.device) + if isinstance(tensor_save_spec.device, str) + else tensor_save_spec.device) + self._shardable_tensors.append( + sharding_util.ShardableTensor( + _tensor_save_spec=tensor_save_spec, + tensor=tensor_value, + dtype=tensor_save_spec.dtype, + device=device, + name=tensor_save_spec.name, + shape=None, + slice_spec=slice_spec.strip(), + checkpoint_key=checkpoint_key, + trackable=obj)) + unique_tasks.add( + saveable_object_util.set_cpu0(device.to_string())) + + self._num_unique_tasks = len(unique_tasks) self._registered_savers = {} if registered_savers: @@ -332,8 +316,13 @@ def __init__(self, self._registered_savers[registered_name] = (save_fn, restore_fn) @classmethod - def from_saveables(cls, saveables, registered_savers=None, - call_with_mapped_captures=None): + def from_saveables( + cls, + saveables: Sequence[base.Trackable], + registered_savers: "RegisteredSaversDict | None" = None, + call_with_mapped_captures: "MappedCapturesCallable | None" = None + ) -> "MultiDeviceSaver": + """Constructs a MultiDeviceSaver from a list of `SaveableObject`s.""" serialized_tensors = object_identity.ObjectIdentityDictionary() for saveable in saveables: trackable = saveable_object_util.SaveableCompatibilityConverter( @@ -341,7 +330,7 @@ def from_saveables(cls, saveables, registered_savers=None, serialized_tensors[trackable] = trackable._serialize_to_tensors() # pylint: disable=protected-access return cls(serialized_tensors, registered_savers, call_with_mapped_captures) - def to_proto(self): + def to_proto(self) -> saver_pb2.SaverDef: """Serializes to a SaverDef referencing the current graph.""" filename_tensor = array_ops.placeholder( shape=[], dtype=dtypes.string, name="saver_filename") @@ -356,7 +345,7 @@ def to_proto(self): @def_function.function( input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),), autograph=False) - def _traced_save(self, file_prefix): + def _traced_save(self, file_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor: save_op = self.save(file_prefix) with ops.device("cpu:0"): with ops.control_dependencies([save_op]): @@ -365,13 +354,72 @@ def _traced_save(self, file_prefix): @def_function.function( input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),), autograph=False) - def _traced_restore(self, file_prefix): + def _traced_restore( + self, file_prefix: tensor_lib.Tensor) -> tensor_lib.Tensor: restore_ops = 
self.restore(file_prefix) with ops.device("cpu:0"): with ops.control_dependencies(restore_ops.values()): return array_ops.identity(file_prefix) - def save(self, file_prefix, options=None): + def _get_shards_by_task( + self, + sharding_callback: sharding_util.ShardingCallback + ) -> Sequence[sharding_util.TensorSliceDict]: + """Calls the sharding callback with shardable_tensors. + + Args: + sharding_callback: ShardingCallback. The callback function wrapper that + splits shardable_tensors into shards. + + Returns: + A list of shards. + """ + shardable_tensors_by_task = {} + for shardable_tensor in self._shardable_tensors: + tensor_val = shardable_tensor.tensor + tensor_shape = shardable_tensor.shape + save_spec = shardable_tensor._tensor_save_spec # pylint: disable=protected-access + with ops.device(shardable_tensor.device): + save_spec_tensor = save_spec.tensor + + if tensor_val is None and save_spec_tensor is None: + # A tensor value of `None` indicates that this SaveableObject gets + # recorded in the object graph, but that no value is saved in the + # checkpoint. + continue + elif save_spec_tensor is not None: + # Pull the tensor value from _tensor_save_spec. + tensor_val = save_spec_tensor + tensor_shape = save_spec_tensor.shape + + # Propagate the save spec name and/or slice spec when they are tensors. + # This makes sure properties like `layout` for dtensor names/slice specs + # are preserved during sharding. + if isinstance(save_spec.name, tensor_lib.Tensor): + tensor_val._wrapped_name = save_spec.name # pylint: disable=protected-access + if isinstance(shardable_tensor.slice_spec, tensor_lib.Tensor): + tensor_val._wrapped_slice_spec = save_spec.slice_spec # pylint: disable=protected-access + + task = device_lib.DeviceSpec.from_string( + saveable_object_util.set_cpu0(shardable_tensor.device.to_string())) + shardable_tensors_by_task.setdefault(task, []).append(dataclasses.replace( + shardable_tensor, + tensor=tensor_val, + shape=tensor_shape + )) + + sharding_callback = ( + sharding_callback or sharding_policies.ShardByTaskPolicy()) + shards_by_task = [ + (task, sharding_callback(shardable_tensors)) + for task, shardable_tensors in shardable_tensors_by_task.items()] + return shards_by_task + + def save( + self, + file_prefix: tensor_lib.Tensor, + options: "checkpoint_options.CheckpointOptions | None" = None + ) -> ops.Operation: """Save the saveable objects to a checkpoint with `file_prefix`. Args: @@ -423,7 +471,7 @@ def save(self, file_prefix, options=None): for saver_name in self._registered_savers } - def save_fn(): + def save_fn() -> ops.Operation: saved_prefixes = [] # Save with the registered savers. These run before default savers due to # the API contract. @@ -439,31 +487,31 @@ def save_fn(): f"string type tensors. Got {maybe_saved_prefixes}.") saved_prefixes.extend(flattened_saved_prefixes) - # (Default saver) Save with single device savers. 
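# --- Editor's illustration (not part of the patch) --------------------------
# A sketch of the data shape a sharding callback consumes and returns,
# mirroring the removed ShardByDevicePolicy logic above: the callback receives
# a list of ShardableTensors and returns a list of shards, each a dict of
# {checkpoint_key: {slice_spec: tensor}}. A real callback would subclass the
# ShardingCallback defined in checkpoint/sharding/sharding_util.py in this
# change; this plain function only illustrates the grouping step.
def group_tensors_by_device(shardable_tensors):
  shards_by_device = {}
  for st in shardable_tensors:
    (shards_by_device
     .setdefault(st.device, {})
     .setdefault(st.checkpoint_key, {}))[st.slice_spec] = st.tensor
  return list(shards_by_device.values())
# -----------------------------------------------------------------------------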
- num_shards = len(self._single_device_savers) + shards_by_task = self._get_shards_by_task( + options.experimental_sharding_callback) + num_shards_tensor = constant_op.constant( + sum([len(shards) for _, shards in shards_by_task]), name="num_shards") sharded_saves = [] - num_shards_tensor = constant_op.constant(num_shards, name="num_shards") - last_device = None - for shard, (device, saver) in enumerate( - sorted(self._single_device_savers.items())): - last_device = device - with ops.device(saveable_object_util.set_cpu0(device)): - shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, - num_shards_tensor) - saved_prefixes.append(shard_prefix) - with ops.device(device): - # _SingleDeviceSaver will use the CPU device when necessary, but - # initial read operations should be placed on the SaveableObject's - # device. - sharded_saves.append(saver.save(shard_prefix, options)) + + shard_idx = 0 + for task, shards in shards_by_task: + for shard in shards: + with ops.device(task): + shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard_idx, + num_shards_tensor) + shard_idx += 1 + saved_prefixes.append(shard_prefix) + sharded_saves.append( + _single_shard_save(shard_prefix, shard, task, options)) with ops.control_dependencies(sharded_saves): # Merge on the io_device if specified, otherwise co-locates the merge op # with the last device used. - merge_device = ( + tensor_device_spec = self._shardable_tensors[-1].device + merge_device_spec = ( options.experimental_io_device or - saveable_object_util.set_cpu0(last_device)) - with ops.device(merge_device): + saveable_object_util.set_cpu0(tensor_device_spec.to_string())) + with ops.device(merge_device_spec): # V2 format write path consists of a metadata merge step. Once # merged, attempts to delete the temporary directory, # "_temp". @@ -471,19 +519,23 @@ def save_fn(): saved_prefixes, file_prefix, delete_old_dirs=True) # Since this will causes a function re-trace on each save, limit this to the - # cases where it is needed: eager and when there are multiple tasks/single - # device savers. Note that the retrace is needed to ensure we pickup the - # latest values of options like experimental_io_device. - if context.executing_eagerly() and len(self._single_device_savers) > 1: + # cases where it is needed: eager and when there are multiple tasks. Note + # that the retrace is needed to ensure we pickup the latest values of + # options like experimental_io_device. + if context.executing_eagerly() and self._num_unique_tasks > 1: # Explicitly place the identity op on the first device. @def_function.function(jit_compile=False) - def tf_function_save(): + def tf_function_save() -> None: save_fn() tf_function_save() else: return save_fn() - def restore(self, file_prefix, options=None): + def restore( + self, + file_prefix: tensor_lib.Tensor, + options: "checkpoint_options.CheckpointOptions | None" = None + ) -> Mapping[str, ops.Operation]: """Restore the saveable objects from a checkpoint with `file_prefix`. Args: @@ -498,18 +550,17 @@ def restore(self, file_prefix, options=None): """ options = options or checkpoint_options.CheckpointOptions() - def restore_fn(): + def restore_fn() -> Mapping[str, ops.Operation]: restore_fn_inputs = {} restore_fn_input_count = { fn: len(keys) for fn, keys in self._restore_fn_to_keys.items()} restore_ops = {} - # Sort by device name to avoid propagating non-deterministic dictionary - # ordering in some Python versions. 
- for device, saver in sorted(self._single_device_savers.items()): - with ops.device(device): + if self._shardable_tensors: + with ops.device("CPU:0"): # Load values from checkpoint - restored_tensor_dict = saver.restore(file_prefix, options) + restored_tensor_dict = _single_shard_restore( + file_prefix, self._shardable_tensors, options) # Map restored tensors to the corresponding restore_fn, and see if all # inputs have all been loaded. Call `restore_fn` if that is the case. @@ -550,13 +601,12 @@ def restore_fn(): return restore_ops has_custom_device_saver = any([ - context.is_custom_device(d) for d in self._single_device_savers.keys() - ]) + context.is_custom_device(st.device.to_string()) + for st in self._shardable_tensors]) # Since this will cause a function re-trace on each restore, limit this to - # cases where it is needed: eager and when there are multiple tasks/single - # device savers or any single device saver is a custom device. Note that the - # retrace is needed to ensure we pickup the latest values of options like - # experimental_io_device. + # cases where it is needed: eager and when there are multiple tasks or any + # device_spec is a custom device. Note that the retrace is needed to ensure + # we pickup the latest values of options like experimental_io_device. # # We run in a function when there is a custom device saver because custom # devices, such as DTensor, usually do a sharded save and restore. @@ -564,10 +614,10 @@ def restore_fn(): # of variables we are restoring to. In practice, this means that custom # devices need the AssignVariableOps along with the Restore op within the # same graph to infer shapes and shard specs for Restore op. - if context.executing_eagerly() and (len(self._single_device_savers) > 1 or + if context.executing_eagerly() and (self._num_unique_tasks > 1 or has_custom_device_saver): @def_function.function(jit_compile=False, autograph=False) - def tf_function_restore(): + def tf_function_restore() -> Mapping[str, ops.Operation]: restore_fn() return {} diff --git a/tensorflow/python/checkpoint/functional_saver_test.py b/tensorflow/python/checkpoint/functional_saver_test.py index 3bac7428f2e030..954f8ed1c399b4 100644 --- a/tensorflow/python/checkpoint/functional_saver_test.py +++ b/tensorflow/python/checkpoint/functional_saver_test.py @@ -15,6 +15,7 @@ """Tests for the functional saver.""" import os +import time from tensorflow.python.checkpoint import checkpoint from tensorflow.python.checkpoint import checkpoint_options @@ -29,13 +30,12 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.module import module +from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import gfile from tensorflow.python.training import server_lib -from tensorflow.python.training.saving import saveable_object from tensorflow.python.training.saving import saveable_object_util - LOCALHOST = "/job:localhost/replica:0/task:0/device:CPU:0" @@ -53,35 +53,22 @@ def setUp(self): self.local_options = checkpoint_options.CheckpointOptions( experimental_io_device=LOCALHOST) - def _get_shardable_tensors(self, serialized_tensors): - shardable_tensors = [] - for obj, tensor_dict in serialized_tensors.items(): - # Divide tensor_dict by device. - for checkpoint_key, tensor_slice_dict in tensor_dict.items(): - if not isinstance(tensor_slice_dict, dict): - # Make sure that maybe_tensor is structured as {slice_spec -> tensor}. 
- tensor_slice_dict = {"": tensor_slice_dict} - for slice_spec, tensor_save_spec in tensor_slice_dict.items(): - if not isinstance(tensor_save_spec, saveable_object.SaveSpec): - tensor_save_spec = saveable_object.SaveSpec( - tensor=tensor_save_spec, - slice_spec=slice_spec, - name=checkpoint_key, - dtype=tensor_save_spec.dtype, - device=tensor_save_spec.device) - save_spec_tensor = tensor_save_spec.tensor - shardable_tensors.append( - functional_saver.ShardableTensor( - _tensor_save_spec=tensor_save_spec, - tensor=save_spec_tensor, - dtype=tensor_save_spec.dtype, - device=tensor_save_spec.device, - name=tensor_save_spec.name, - shape=save_spec_tensor.shape, - slice_spec=slice_spec, - checkpoint_key=checkpoint_key, - trackable=obj)) - return shardable_tensors + def _get_tensors_by_task(self, root): + serialized_tensors, _, _, _ = ( + checkpoint.TrackableSaver(graph_view.ObjectGraphView(root)) + ._gather_serialized_tensors(None)) + + tensors_by_task = {} + for tensor_dict in serialized_tensors.values(): + for checkpoint_key, maybe_tensor in tensor_dict.items(): + if not isinstance(maybe_tensor, dict): + maybe_tensor = {"": maybe_tensor} + for slice_spec, tensor in maybe_tensor.items(): + tensor_task = saveable_object_util.set_cpu0(tensor.device) + (tensors_by_task + .setdefault(tensor_task, {}) + .setdefault(checkpoint_key, {})[slice_spec]) = tensor + return tensors_by_task @test_util.run_in_graph_and_eager_modes def test_resource_variable(self): @@ -220,40 +207,49 @@ def test_checkpoint_multi_device_using_localhost(self): if op.type in ("SaveV2", "RestoreV2", "MergeV2Checkpoints"): self.assertEqual(LOCALHOST, op.device) - def test_ShardByDevicePolicy(self): + def test_single_task_save_singlehost_multidevice(self): root = module.Module() with ops.device("cpu:0"): - v0 = resource_variable_ops.ResourceVariable(0.0, name="v0") + v0 = resource_variable_ops.ResourceVariable(0.) with ops.device("cpu:1"): - v1 = resource_variable_ops.ResourceVariable(1.0, name="v1") + v1 = resource_variable_ops.ResourceVariable(1.) with ops.device("cpu:2"): - v2 = resource_variable_ops.ResourceVariable(2.0, name="v2") + v2 = resource_variable_ops.ResourceVariable(2.) 
root.v0 = v0 root.v1 = v1 root.v2 = v2 - serialized_tensors, _, _, _ = ( - checkpoint.TrackableSaver(graph_view.ObjectGraphView(root)) - ._gather_serialized_tensors(None)) - shardable_tensors = self._get_shardable_tensors(serialized_tensors) - callback = functional_saver.ShardByDevicePolicy() - shards = callback(shardable_tensors) + tensors_by_task = self._get_tensors_by_task(root) + var_names = [ + "v0/.ATTRIBUTES/VARIABLE_VALUE", + "v1/.ATTRIBUTES/VARIABLE_VALUE", + "v2/.ATTRIBUTES/VARIABLE_VALUE" + ] + vars_numpy = [v0.numpy(), v1.numpy(), v2.numpy()] + tmp_dir = self.get_temp_dir() + + for device in ["cpu:0", "cpu:1", "cpu:2"]: + for shard, (_, tensor_slice_dict) in enumerate( + sorted(tensors_by_task.items())[1:]): + with ops.device(device): + shard_prefix = gen_io_ops.sharded_filename( + os.path.join(tmp_dir, str(shard)), shard, 3) + functional_saver._single_task_save( + shard_prefix, tensor_slice_dict) - self.assertAllEqual( - [list(shard.keys()) for shard in shards], - [[ - "v0/.ATTRIBUTES/VARIABLE_VALUE", - "v1/.ATTRIBUTES/VARIABLE_VALUE", - "v2/.ATTRIBUTES/VARIABLE_VALUE", - "_CHECKPOINTABLE_OBJECT_GRAPH" - ]]) + start_time = time.time() + max_save_time = start_time + 5 # seconds + while not (gfile.ListDirectory(tmp_dir) or time.time() > max_save_time): + pass # eager execution is lovely + self.assertNotEmpty(gfile.ListDirectory(tmp_dir)) - self.assertEqual(shards[0]["v0/.ATTRIBUTES/VARIABLE_VALUE"][""].numpy(), - v0.numpy()) - self.assertEqual(shards[0]["v1/.ATTRIBUTES/VARIABLE_VALUE"][""].numpy(), - v1.numpy()) - self.assertEqual(shards[0]["v2/.ATTRIBUTES/VARIABLE_VALUE"][""].numpy(), - v2.numpy()) + with ops.device(device): + restored_dict = functional_saver._single_task_restore( + shard_prefix, tensor_slice_dict) + self.evaluate(restored_dict) + self.assertEqual( + restored_dict[var_names[shard]][""].numpy(), + vars_numpy[shard]) if __name__ == "__main__": diff --git a/tensorflow/python/checkpoint/sharding/BUILD b/tensorflow/python/checkpoint/sharding/BUILD new file mode 100644 index 00000000000000..412f8c4f12d050 --- /dev/null +++ b/tensorflow/python/checkpoint/sharding/BUILD @@ -0,0 +1,103 @@ +# Description: +# Utilities for sharding object-based checkpoints. 
+ +load("//tensorflow:strict.default.bzl", "py_strict_library") +load("//tensorflow:tensorflow.default.bzl", "tf_py_strict_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow:internal", + ], + licenses = ["notice"], +) + +py_strict_library( + name = "sharding_policies", + srcs = ["sharding_policies.py"], + srcs_version = "PY3", + deps = [ + ":sharding_util", + "//tensorflow/python/eager:context", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:string_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/trackable:base", + "//tensorflow/python/util:tf_export", + "@absl_py//absl/logging", + ], +) + +tf_py_strict_test( + name = "sharding_policies_test", + srcs = ["sharding_policies_test.py"], + srcs_version = "PY3", + deps = [ + ":sharding_policies", + ":sharding_util", + "//tensorflow/python/checkpoint", + "//tensorflow/python/checkpoint:checkpoint_options", + "//tensorflow/python/checkpoint:graph_view", + "//tensorflow/python/eager:remote", + "//tensorflow/python/eager:test", + "//tensorflow/python/framework:device", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor_shape", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/module", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:gfile", + "//tensorflow/python/training:server_lib", + "//tensorflow/python/training/saving:saveable_object", + "//tensorflow/python/training/saving:saveable_object_util", + "@absl_py//absl/logging", + ], +) + +py_strict_library( + name = "sharding_util", + srcs = ["sharding_util.py"], + srcs_version = "PY3", + deps = [ + "//tensorflow/python/framework:device", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/framework:tensor_shape", + "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/ops:variables", + "//tensorflow/python/trackable:base", + "//tensorflow/python/training/saving:saveable_object", + "//tensorflow/python/util:tf_export", + "@absl_py//absl/logging", + ], +) + +tf_py_strict_test( + name = "sharding_util_test", + srcs = ["sharding_util_test.py"], + srcs_version = "PY3", + deps = [ + ":sharding_policies", + ":sharding_util", + "//tensorflow/python/checkpoint", + "//tensorflow/python/checkpoint:graph_view", + "//tensorflow/python/eager:remote", + "//tensorflow/python/eager:test", + "//tensorflow/python/framework:device", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", + "//tensorflow/python/module", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/training:server_lib", + "//tensorflow/python/training/saving:saveable_object", + "//tensorflow/python/training/saving:saveable_object_util", + ], +) diff --git a/tensorflow/python/checkpoint/sharding/sharding_policies.py b/tensorflow/python/checkpoint/sharding/sharding_policies.py new file mode 100644 index 00000000000000..5ee731fd96d979 --- /dev/null +++ b/tensorflow/python/checkpoint/sharding/sharding_policies.py @@ -0,0 +1,322 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Checkpoint policies that determine how tensors are split into shards.""" + +import math +from typing import MutableSequence, Sequence + +from absl import logging + +from tensorflow.python.checkpoint.sharding import sharding_util +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.ops import variables +from tensorflow.python.trackable import base +from tensorflow.python.util import tf_export + + +@tf_export.tf_export("train.experimental.ShardByTaskPolicy") +class ShardByTaskPolicy(sharding_util.ShardingCallback): + """Policy that splits tensors into shards based on their device spec task.""" + + @property + def description(self) -> str: + return "Split tensors into shards based on their device spec task." + + def __call__( + self, + shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.TensorSliceDict]: + """Callback to split tensors into shards based on their device spec task. + + Args: + shardable_tensors: A list of ShardableTensors. + + Returns: + List of shard dicts containing tensors. + [ {checkpoint key: {slice_spec: tensor} } ] + """ + tensors_by_task = {} + + for shardable_tensor in shardable_tensors: + tensor = shardable_tensor.tensor + checkpoint_key = shardable_tensor.checkpoint_key + slice_spec = shardable_tensor.slice_spec + + (tensors_by_task + .setdefault(checkpoint_key, {})[slice_spec]) = tensor + + return [tensors_by_task] + + +_PartitionAxisAndSize = tuple[int, int] +_OffsetAndShape = tuple[Sequence[int], Sequence[int]] + + +@tf_export.tf_export("train.experimental.MaxShardSizePolicy") +class MaxShardSizePolicy(sharding_util.ShardingCallback): + """Policy that splits tensors into shards with a max shard size. + + Shards may exceed the max shard size if they contain 1. a single scalar/string + tensor that could not be sliced and exceeds the max shard size or 2. the + checkpoint object graph, whose size cannot be calculated when saving. + """ + + def __init__(self, max_shard_size: int): + self.max_shard_size = max_shard_size + + @property + def description(self) -> str: + return "Split tensors into shards with a max shard size." + + def _get_next_partition( + self, + shard_size_remaining: int, + shape: tensor_shape.TensorShape, + dtype_size: int, + num_elems: int + ) -> _PartitionAxisAndSize: + """Gets tensor partition with size closest to shard_size_remaining. + + Args: + shard_size_remaining: Size in bytes of the space remaining in the shard. + shape: Shape of the working tensor to partition in the remaining + shard space. + dtype_size: Size in bytes of the dtype of the working tensor. 
+ num_elems: Number of elements in the working tensor. + + Returns: + A tuple containing the axis of the next partition and that partition size. + """ + if shape.rank is None or shape.rank == 0: + return 0, math.inf + + # Find axis with minimum partitions. (aka axis with maximum partition size) + # (max partition size is as close as possible to the shard_size_remaining) + bytes_per_slice = num_elems // shape.dims[0].value * dtype_size + slices_per_shard = max( + 1, math.floor(shard_size_remaining / bytes_per_slice)) + min_parts = math.ceil(shape.dims[0].value / slices_per_shard) + min_axis = 0 + for axis in range(1, shape.rank): + bytes_per_slice = num_elems // shape.dims[axis].value * dtype_size + slices_per_shard = max( + 1, math.floor(shard_size_remaining / bytes_per_slice)) + axis_parts = math.ceil(shape.dims[axis].value / slices_per_shard) + partition_size = num_elems * dtype_size / axis_parts + if (axis_parts < min_parts and + partition_size < shard_size_remaining): + min_axis, min_parts = axis, int(axis_parts) + return min_axis, math.ceil(int(shape[min_axis]) / min_parts) + + def _add_partition( + self, + root_shardable_tensor: sharding_util.ShardableTensor, + dtype_size: int, + working_tensor_offset: Sequence[int], + part_axis_and_size: _PartitionAxisAndSize, + shard_size_remaining: int, + max_shard_size: int, + tensors_by_shard: MutableSequence[sharding_util.TensorSliceDict], + large_scalars: MutableSequence[sharding_util.TensorSliceDict], + ) -> tuple[tensor_lib.Tensor, _OffsetAndShape]: + """Adds the tensor partition to the shard, if possible. + + Args: + root_shardable_tensor: The full tensor being partitioned. + dtype_size: Size in bytes of the dtype of the working tensor. + working_tensor_offset: The offset of the working tensor in the full + tensor. + part_axis_and_size: A tuple containing the axis of the partition and that + partition size. + shard_size_remaining: Size in bytes of the space remaining in the shard. + max_shard_size: Max size in bytes allowed for a checkpoint shard. + tensors_by_shard: List of shard dicts containing tensors. + [ {checkpoint key: {slice_spec: tensor} } ] + large_scalars: List of shard dicts containing scalars too large to fit in + the max_shard_size. [ {checkpoint key: {slice_spec: tensor} } ] + + Returns: + A tuple containing the size of the slice that was added to the shard and + the offset & shape of the remaining portion of the tensor. + """ + root_tensor = root_shardable_tensor.tensor + root_tensor_shape = root_shardable_tensor.shape + checkpoint_key = root_shardable_tensor.checkpoint_key + + if root_tensor_shape.rank is None or root_tensor_shape.rank == 0: + return None, (None, None) + + min_axis, part_size = part_axis_and_size + + # Add what we can to the current shard. + slice_offset = working_tensor_offset + slice_shape = [root_tensor_shape[i] - slice_offset[i] + for i in range(root_tensor_shape.rank)] + slice_shape[min_axis] = part_size + slice_size_in_bytes = int(math.prod(slice_shape)) * dtype_size + with ops.device(root_shardable_tensor.device): + tensor_slice = array_ops.slice( + root_tensor, begin=slice_offset, size=slice_shape) + slice_spec = variables.Variable.SaveSliceInfo( + full_name=checkpoint_key, + full_shape=root_tensor_shape, + var_offset=slice_offset, + var_shape=slice_shape).spec.strip() + remaining_size = shard_size_remaining + if slice_size_in_bytes > max_shard_size: + logging.warning("Slice %s of tensor %s is a scalar of size %s bytes and " + "cannot be partitioned into a shard of max shard size %s " + "bytes. 
It will be added as an individual shard that " + "exceeds the max shard size.", slice_spec, checkpoint_key, + slice_size_in_bytes, max_shard_size) + large_scalars.append({checkpoint_key: {slice_spec: tensor_slice}}) + elif slice_size_in_bytes > shard_size_remaining: + # Smallest partition can't fit in the remaining shard space. Start fresh + # with a new shard. + return None, (None, None) + else: + if not tensors_by_shard or shard_size_remaining < 1: + tensors_by_shard.append({}) + remaining_size = max_shard_size + (tensors_by_shard[-1] + .setdefault(checkpoint_key, {})[slice_spec]) = tensor_slice + remaining_size -= slice_size_in_bytes + + # Get remaining portion of tensor to add to the next shard(s). + slice_offset[min_axis] += part_size + slice_shape = [root_tensor_shape[i] - slice_offset[i] + for i in range(root_tensor_shape.rank)] + + return (remaining_size, (slice_offset, slice_shape)) + + def __call__( + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.TensorSliceDict]: + """Callback to split tensors into shards with a max shard size. + + Args: + shardable_tensors: A list of ShardableTensors. + + Returns: + List of shard dicts containing tensors. + [ {checkpoint key: {slice_spec: tensor} } ] + """ + tensors_by_shard = [] + large_scalars = [] + + shard_size_remaining = self.max_shard_size + for shardable_tensor in shardable_tensors: + root_tensor = shardable_tensor.tensor + root_shape = shardable_tensor.shape + dtype = shardable_tensor.dtype + checkpoint_key = shardable_tensor.checkpoint_key + + dtype_size = dtypes.as_dtype(dtype).size + total_size = root_shape.num_elements() * dtype_size # in bytes + + # Calculate string tensor sizes. + if checkpoint_key == base.OBJECT_GRAPH_PROTO_KEY: + # In graph mode, the object graph is populated using feed_additions when + # the session is run. So, we can't calculate the size here. Fortunately, + # the serialized object graph string will never be that big, so we just + # place it in the current shard without worrying about its size. + total_size = dtype_size = 0 + elif dtype == dtypes.string: + if not context.executing_eagerly(): + with ops.device(shardable_tensor.device): + root_tensor = ops.get_default_session().run(root_tensor) + + if root_shape.rank is None or root_shape.rank == 0: + sizes = [string_ops.string_length(root_tensor, unit="BYTE")] + else: + sizes = [string_ops.string_length(elem, unit="BYTE") + for elem in root_tensor] + + if context.executing_eagerly(): + sizes = [size.numpy() for size in sizes] + else: + with ops.device(shardable_tensor.device): + sizes = ops.get_default_session().run(sizes) + + total_size = sum(sizes) + dtype_size = max(sizes) + + if (total_size > self.max_shard_size and + (root_shape.rank is None or root_shape.rank == 0)): + logging.warning("Tensor %s is a scalar of size %s bytes and cannot be " + "partitioned into a shard of max shard size %s bytes. " + "It will be added as an individual shard that exceeds " + "the max shard size.", + checkpoint_key, total_size, self.max_shard_size) + large_scalars.append( + {checkpoint_key: {shardable_tensor.slice_spec: root_tensor}}) + continue + + # Partition tensor and add partitions to shards. 
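Before the partitioning loop, it may help to see the slice-size arithmetic from `_get_next_partition` in isolation. The following standalone sketch (hypothetical helper name, single-axis case only) reproduces the split asserted for the float32 `[4]` variable in the 1-D `MaxShardSizePolicy` test later in this diff:

```
import math

def single_axis_partition_size(shard_bytes_remaining, dim_size, num_elems,
                               dtype_size):
  # Mirrors the per-axis arithmetic in _get_next_partition.
  bytes_per_slice = num_elems // dim_size * dtype_size
  slices_per_shard = max(1, math.floor(shard_bytes_remaining / bytes_per_slice))
  num_partitions = math.ceil(dim_size / slices_per_shard)
  return math.ceil(dim_size / num_partitions)

# A float32 vector of shape [4] is 16 bytes; with an 8-byte shard budget, each
# 4-byte element slice fits twice per shard, so the partition size is 2
# elements -- i.e. the [0., 1.] / [2., 3.] split asserted in the tests.
print(single_axis_partition_size(8, dim_size=4, num_elems=4, dtype_size=4))  # 2
```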
+ working_tensor = root_tensor + working_tensor_var_offset = [0] * root_shape.rank + working_tensor_shape = root_shape + working_tensor_size = total_size + while working_tensor_size > shard_size_remaining: + part_axis_and_size = self._get_next_partition( + shard_size_remaining=shard_size_remaining, + shape=working_tensor_shape, + dtype_size=dtype_size, + num_elems=working_tensor_shape.num_elements()) + + (remaining_size, + (remaining_offset, remaining_shape)) = self._add_partition( + root_shardable_tensor=shardable_tensor, + dtype_size=dtype_size, + working_tensor_offset=working_tensor_var_offset, + part_axis_and_size=part_axis_and_size, + shard_size_remaining=shard_size_remaining, + max_shard_size=self.max_shard_size, + tensors_by_shard=tensors_by_shard, + large_scalars=large_scalars) + + if remaining_size is None: + # Tensor partition couldn't fit in remaining shard space. Try again + # with the next full shard. + tensors_by_shard.append({}) + shard_size_remaining = self.max_shard_size + else: + working_tensor = array_ops.slice( + root_tensor, begin=remaining_offset, size=remaining_shape) + working_tensor_var_offset = remaining_offset + working_tensor_shape = working_tensor.shape + working_tensor_size = int(math.prod(remaining_shape)) * dtype_size + shard_size_remaining = remaining_size + + if working_tensor_shape.num_elements() > 0: + remaining_tensor_slice_spec = variables.Variable.SaveSliceInfo( + full_name=checkpoint_key, + full_shape=root_shape, + var_offset=working_tensor_var_offset, + var_shape=working_tensor_shape).spec.strip() + if not tensors_by_shard: + tensors_by_shard.append({}) + (tensors_by_shard[-1] + .setdefault(checkpoint_key, {}) + [remaining_tensor_slice_spec]) = working_tensor + shard_size_remaining -= working_tensor_size + + return tensors_by_shard + large_scalars diff --git a/tensorflow/python/checkpoint/sharding/sharding_policies_test.py b/tensorflow/python/checkpoint/sharding/sharding_policies_test.py new file mode 100644 index 00000000000000..133a0b923d6338 --- /dev/null +++ b/tensorflow/python/checkpoint/sharding/sharding_policies_test.py @@ -0,0 +1,697 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for checkpoint sharding policies.""" + +import random +import string + +from tensorflow.python.checkpoint import checkpoint +from tensorflow.python.checkpoint import checkpoint_options +from tensorflow.python.checkpoint import graph_view +from tensorflow.python.checkpoint.sharding import sharding_policies +from tensorflow.python.checkpoint.sharding import sharding_util +from tensorflow.python.eager import remote +from tensorflow.python.eager import test +from tensorflow.python.framework import device as device_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.module import module +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.training import server_lib +from tensorflow.python.training.saving import saveable_object +from tensorflow.python.training.saving import saveable_object_util + + +class ShardingPoliciesTest(test.TestCase): + + def _get_shardable_tensors_by_task(self, root): + serialized_tensors, _, _, _ = ( + checkpoint.TrackableSaver(graph_view.ObjectGraphView(root)) + ._gather_serialized_tensors(None)) + + shardable_tensors_by_task = {} + for obj, tensor_dict in serialized_tensors.items(): + # Divide tensor_dict by device. + for checkpoint_key, tensor_slice_dict in tensor_dict.items(): + if not isinstance(tensor_slice_dict, dict): + # Make sure that maybe_tensor is structured as {slice_spec -> tensor}. + tensor_slice_dict = {"": tensor_slice_dict} + for slice_spec, tensor_save_spec in tensor_slice_dict.items(): + if not isinstance(tensor_save_spec, saveable_object.SaveSpec): + tensor_save_spec = saveable_object.SaveSpec( + tensor=tensor_save_spec, + slice_spec=slice_spec, + name=checkpoint_key, + dtype=tensor_save_spec.dtype, + device=tensor_save_spec.device) + save_spec_tensor = tensor_save_spec.tensor + device = (device_lib.DeviceSpec.from_string(tensor_save_spec.device) + if isinstance(tensor_save_spec.device, str) + else tensor_save_spec.device) + task = device_lib.DeviceSpec.from_string( + saveable_object_util.set_cpu0(device.to_string())) + shardable_tensors_by_task.setdefault(task, []).append( + sharding_util.ShardableTensor( + _tensor_save_spec=tensor_save_spec, + tensor=save_spec_tensor, + dtype=tensor_save_spec.dtype, + device=device, + name=tensor_save_spec.name, + shape=save_spec_tensor.shape, + slice_spec=slice_spec, + checkpoint_key=checkpoint_key, + trackable=obj)) + return shardable_tensors_by_task.values() + + def test_ShardByTaskPolicy(self): + servers = [server_lib.Server.create_local_server() for _ in range(3)] + cluster_spec = server_lib.ClusterSpec({ + "worker": [s.target[len("grpc://"):] for s in servers]}) + remote.connect_to_cluster(cluster_spec) + root = module.Module() + with ops.device("/job:worker/task:0/cpu:0"): + v0 = resource_variable_ops.ResourceVariable(0.0, name="v0") + with ops.device("/job:worker/task:1/cpu:0"): + v1 = resource_variable_ops.ResourceVariable(1.0, name="v1") + with ops.device("/job:worker/task:2/cpu:0"): + v2 = resource_variable_ops.ResourceVariable(2.0, name="v2") + root.v0 = v0 + root.v1 = v1 + root.v2 = v2 + + shardable_tensors = self._get_shardable_tensors_by_task(root) + + callback = sharding_policies.ShardByTaskPolicy() + shards = [] + for 
tensors in shardable_tensors: + shards.extend(callback(tensors)) + + self.assertAllEqual( + [set(shard.keys()) for shard in shards], + [ + {"v0/.ATTRIBUTES/VARIABLE_VALUE"}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE"}, + {"v2/.ATTRIBUTES/VARIABLE_VALUE"}, + {"_CHECKPOINTABLE_OBJECT_GRAPH"} + ]) + + self.assertEqual( + self.evaluate(shards[0]["v0/.ATTRIBUTES/VARIABLE_VALUE"][""]), + v0.numpy()) + self.assertEqual( + self.evaluate(shards[1]["v1/.ATTRIBUTES/VARIABLE_VALUE"][""]), + v1.numpy()) + self.assertEqual( + self.evaluate(shards[2]["v2/.ATTRIBUTES/VARIABLE_VALUE"][""]), + v2.numpy()) + + def test_CheckpointOption_ShardByTaskPolicy(self): + servers = [server_lib.Server.create_local_server() for _ in range(3)] + cluster_spec = server_lib.ClusterSpec({ + "worker": [s.target[len("grpc://"):] for s in servers]}) + remote.connect_to_cluster(cluster_spec) + root = module.Module() + with ops.device("/job:worker/task:0/cpu:0"): + v0 = resource_variable_ops.ResourceVariable(0.0, name="v0") + self.evaluate(v0.initializer) + with ops.device("/job:worker/task:1/cpu:0"): + v1 = resource_variable_ops.ResourceVariable(1.0, name="v1") + self.evaluate(v1.initializer) + with ops.device("/job:worker/task:2/cpu:0"): + v2 = resource_variable_ops.ResourceVariable(2.0, name="v2") + self.evaluate(v2.initializer) + root.v0 = v0 + root.v1 = v1 + root.v2 = v2 + + tmp_dir = self.create_tempdir("ckpt") + ckpt = checkpoint.Checkpoint(root) + save_path = ckpt.save( + tmp_dir, options=checkpoint_options.CheckpointOptions( + experimental_sharding_callback=( + sharding_policies.ShardByTaskPolicy()))) + self.assertLen(gfile.Glob(save_path + ".data*"), 4) + ckpt.restore(save_path) + + @test_util.run_in_graph_and_eager_modes + def test_MaxShardSizePolicy_1D(self): + root = module.Module() + with ops.device("cpu:0"): + v0 = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0, 3.0], + name="v0", + dtype=dtypes.float32) + v1 = resource_variable_ops.ResourceVariable([[4], + [5], + [6], + [7]], + name="v1", + dtype=dtypes.int32) + self.evaluate(v0.initializer) + self.evaluate(v1.initializer) + root.v0 = v0 + root.v1 = v1 + + v0_name = "v0/.ATTRIBUTES/VARIABLE_VALUE" + v1_name = "v1/.ATTRIBUTES/VARIABLE_VALUE" + + class V0SaveSliceInfo(variables.Variable.SaveSliceInfo): + def __init__(self, var_offset, var_shape): + super().__init__( + full_name=v0_name, + full_shape=tensor_shape.TensorShape(dims=[4]), + var_offset=var_offset, + var_shape=var_shape) + + class V1SaveSliceInfo(variables.Variable.SaveSliceInfo): + def __init__(self, var_offset, var_shape): + super().__init__( + full_name=v1_name, + full_shape=tensor_shape.TensorShape(dims=[4, 1]), + var_offset=var_offset, + var_shape=var_shape) + + shardable_tensors = self._get_shardable_tensors_by_task(root) + + # Test sharding the v0 & v1 tensors with different max shard sizes. + + # max_shard_size: 4 bytes + # Each element of v0/v1 is a 32 bit/4 byte value, so each variable should be + # split into 4 shards. 
+ callback = sharding_policies.MaxShardSizePolicy(max_shard_size=4) + shards = [] + for tensors in shardable_tensors: + shards.extend(callback(tensors)) + + self.assertEqual( + [set(shard.keys()) for shard in shards], + [ + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE", "_CHECKPOINTABLE_OBJECT_GRAPH",} + ]) + + # V0 + slice_spec = V0SaveSliceInfo(var_offset=[0], var_shape=[1]).spec + self.assertEqual(self.evaluate(shards[0][v0_name][slice_spec]), 0.0) + + slice_spec = V0SaveSliceInfo(var_offset=[1], var_shape=[1]).spec + self.assertEqual(self.evaluate(shards[1][v0_name][slice_spec]), 1.0) + + slice_spec = V0SaveSliceInfo(var_offset=[2], var_shape=[1]).spec + self.assertEqual(self.evaluate(shards[2][v0_name][slice_spec]), 2.0) + + slice_spec = V0SaveSliceInfo(var_offset=[3], var_shape=[1]).spec + self.assertEqual(self.evaluate(shards[3][v0_name][slice_spec]), 3.0) + + # V1 + slice_spec = V1SaveSliceInfo(var_offset=[0, 0], var_shape=[1, 1]).spec + self.assertEqual(self.evaluate(shards[4][v1_name][slice_spec]), [4]) + + slice_spec = V1SaveSliceInfo(var_offset=[1, 0], var_shape=[1, 1]).spec + self.assertEqual(self.evaluate(shards[5][v1_name][slice_spec]), [5]) + + slice_spec = V1SaveSliceInfo(var_offset=[2, 0], var_shape=[1, 1]).spec + self.assertEqual(self.evaluate(shards[6][v1_name][slice_spec]), [6]) + + slice_spec = V1SaveSliceInfo(var_offset=[3, 0], var_shape=[1, 1]).spec + self.assertEqual(self.evaluate(shards[7][v1_name][slice_spec]), [7]) + + # max_shard_size: 8 bytes + # v0/v1 haven't changed, so they should now be split into 2 shards each. + callback = sharding_policies.MaxShardSizePolicy(max_shard_size=8) + shards = [] + for tensors in shardable_tensors: + shards.extend(callback(tensors)) + + self.assertEqual( + [set(shard.keys()) for shard in shards], + [ + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE", "_CHECKPOINTABLE_OBJECT_GRAPH",} + ]) + + # V0 + slice_spec = V0SaveSliceInfo(var_offset=[0], var_shape=[2]).spec + self.assertAllEqual( + self.evaluate(shards[0][v0_name][slice_spec]), [0.0, 1.0]) + + slice_spec = V0SaveSliceInfo(var_offset=[2], var_shape=[2]).spec + self.assertAllEqual( + self.evaluate(shards[1][v0_name][slice_spec]), [2.0, 3.0]) + + # V1 + slice_spec = V1SaveSliceInfo(var_offset=[0, 0], var_shape=[2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[2][v1_name][slice_spec]), [[4], [5]]) + + slice_spec = V1SaveSliceInfo(var_offset=[2, 0], var_shape=[2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[3][v1_name][slice_spec]), [[6], [7]]) + + # max_shard_size: 10 bytes + # 10 bytes is an uneven boundary for 4 byte elements. v0/v1 should be split + # into 2 shards each. 
+ callback = sharding_policies.MaxShardSizePolicy(max_shard_size=10) + shards = [] + for tensors in shardable_tensors: + shards.extend(callback(tensors)) + + self.assertEqual( + [set(shard.keys()) for shard in shards], + [ + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE", "_CHECKPOINTABLE_OBJECT_GRAPH",} + ]) + + # V0 + slice_spec = V0SaveSliceInfo(var_offset=[0], var_shape=[2]).spec + self.assertAllEqual( + self.evaluate(shards[0][v0_name][slice_spec]), [0.0, 1.0]) + + slice_spec = V0SaveSliceInfo(var_offset=[2], var_shape=[2]).spec + self.assertAllEqual( + self.evaluate(shards[1][v0_name][slice_spec]), [2.0, 3.0]) + + # V1 + slice_spec = V1SaveSliceInfo(var_offset=[0, 0], var_shape=[2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[2][v1_name][slice_spec]), [[4], [5]]) + + slice_spec = V1SaveSliceInfo(var_offset=[2, 0], var_shape=[2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[3][v1_name][slice_spec]), [[6], [7]]) + + # max_shard_size: 16 bytes + # 16 bytes the exact size of each variable, so they should get 1 shard each. + callback = sharding_policies.MaxShardSizePolicy(max_shard_size=16) + shards = [] + for tensors in shardable_tensors: + shards.extend(callback(tensors)) + + self.assertEqual( + [set(shard.keys()) for shard in shards], + [ + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE", "_CHECKPOINTABLE_OBJECT_GRAPH",} + ]) + + # V0 + slice_spec = V0SaveSliceInfo(var_offset=[0], var_shape=[4]).spec + self.assertAllEqual( + self.evaluate(shards[0][v0_name][slice_spec]), [0.0, 1.0, 2.0, 3.0]) + + # V1 + slice_spec = V1SaveSliceInfo(var_offset=[0, 0], var_shape=[4, 1]).spec + self.assertAllEqual( + self.evaluate(shards[1][v1_name][slice_spec]), [[4], [5], [6], [7]]) + + # max_shard_size: 18 bytes + # 18 bytes slightly larger than the size of each variable, but not large + # enough to fit another 4 byte element, so they should get 1 shard each. 
+ callback = sharding_policies.MaxShardSizePolicy(max_shard_size=18) + shards = [] + for tensors in shardable_tensors: + shards.extend(callback(tensors)) + + self.assertEqual( + [set(shard.keys()) for shard in shards], + [ + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE", "_CHECKPOINTABLE_OBJECT_GRAPH",} + ]) + + # V0 + slice_spec = V0SaveSliceInfo(var_offset=[0], var_shape=[4]).spec + self.assertAllEqual( + self.evaluate(shards[0][v0_name][slice_spec]), [0.0, 1.0, 2.0, 3.0]) + + # V1 + slice_spec = V1SaveSliceInfo(var_offset=[0, 0], var_shape=[4, 1]).spec + self.assertAllEqual( + self.evaluate(shards[1][v1_name][slice_spec]), [[4], [5], [6], [7]]) + + @test_util.run_in_graph_and_eager_modes + def test_MaxShardSizePolicy_2D(self): + root = module.Module() + with ops.device("cpu:0"): + v0 = resource_variable_ops.ResourceVariable([[0, 1], + [2, 3], + [4, 5]], + name="v0") + v1 = resource_variable_ops.ResourceVariable([[[6.0], [7.0]], + [[8.0], [9.0]], + [[10.0], [11.0]]], name="v1") + self.evaluate(v0.initializer) + self.evaluate(v1.initializer) + root.v0 = v0 + root.v1 = v1 + + v0_name = "v0/.ATTRIBUTES/VARIABLE_VALUE" + v1_name = "v1/.ATTRIBUTES/VARIABLE_VALUE" + + class V0SaveSliceInfo(variables.Variable.SaveSliceInfo): + def __init__(self, var_offset, var_shape): + super().__init__( + full_name=v0_name, + full_shape=tensor_shape.TensorShape(dims=[3, 2]), + var_offset=var_offset, + var_shape=var_shape) + + class V1SaveSliceInfo(variables.Variable.SaveSliceInfo): + def __init__(self, var_offset, var_shape): + super().__init__( + full_name=v1_name, + full_shape=tensor_shape.TensorShape(dims=[3, 2, 1]), + var_offset=var_offset, + var_shape=var_shape) + + shardable_tensors = self._get_shardable_tensors_by_task(root) + + # Test sharding the v0 & v1 tensors with different max shard sizes. + + # max_shard_size: 8 bytes + # Each element of v0/v1 is a 32 bit/4 byte value, so each variable should be + # split into 3 shards. + callback = sharding_policies.MaxShardSizePolicy(max_shard_size=8) + shards = [] + for tensors in shardable_tensors: + shards.extend(callback(tensors)) + + self.assertEqual( + [set(shard.keys()) for shard in shards], + [ + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE", "_CHECKPOINTABLE_OBJECT_GRAPH",} + ]) + + # V0 + slice_spec = V0SaveSliceInfo(var_offset=[0, 0], var_shape=[1, 2]).spec + self.assertAllEqual( + self.evaluate(shards[0][v0_name][slice_spec]), [[0, 1]]) + + slice_spec = V0SaveSliceInfo(var_offset=[1, 0], var_shape=[1, 2]).spec + self.assertAllEqual( + self.evaluate(shards[1][v0_name][slice_spec]), [[2, 3]]) + + slice_spec = V0SaveSliceInfo(var_offset=[2, 0], var_shape=[1, 2]).spec + self.assertAllEqual( + self.evaluate(shards[2][v0_name][slice_spec]), [[4, 5]]) + + # V1 + slice_spec = V1SaveSliceInfo(var_offset=[0, 0, 0], var_shape=[1, 2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[3][v1_name][slice_spec]), [[[6.0], [7.0]]]) + + slice_spec = V1SaveSliceInfo(var_offset=[1, 0, 0], var_shape=[1, 2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[4][v1_name][slice_spec]), [[[8.0], [9.0]]]) + + slice_spec = V1SaveSliceInfo(var_offset=[2, 0, 0], var_shape=[1, 2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[5][v1_name][slice_spec]), [[[10.0], [11.0]]]) + + # max_shard_size: 10 bytes + # 10 bytes is an uneven boundary for 4 byte elements. 
v0/v1 should be split + # into 3 shards each. + callback = sharding_policies.MaxShardSizePolicy(max_shard_size=10) + shards = [] + for tensors in shardable_tensors: + shards.extend(callback(tensors)) + + self.assertEqual( + [set(shard.keys()) for shard in shards], + [ + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE", "_CHECKPOINTABLE_OBJECT_GRAPH",} + ]) + + # V0 + slice_spec = V0SaveSliceInfo(var_offset=[0, 0], var_shape=[1, 2]).spec + self.assertAllEqual( + self.evaluate(shards[0][v0_name][slice_spec]), [[0, 1]]) + + slice_spec = V0SaveSliceInfo(var_offset=[1, 0], var_shape=[1, 2]).spec + self.assertAllEqual( + self.evaluate(shards[1][v0_name][slice_spec]), [[2, 3]]) + + slice_spec = V0SaveSliceInfo(var_offset=[2, 0], var_shape=[1, 2]).spec + self.assertAllEqual( + self.evaluate(shards[2][v0_name][slice_spec]), [[4, 5]]) + + # V1 + slice_spec = V1SaveSliceInfo(var_offset=[0, 0, 0], var_shape=[1, 2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[3][v1_name][slice_spec]), [[[6.0], [7.0]]]) + + slice_spec = V1SaveSliceInfo(var_offset=[1, 0, 0], var_shape=[1, 2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[4][v1_name][slice_spec]), [[[8.0], [9.0]]]) + + slice_spec = V1SaveSliceInfo(var_offset=[2, 0, 0], var_shape=[1, 2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[5][v1_name][slice_spec]), [[[10.0], [11.0]]]) + + # max_shard_size: 12 bytes + # 12 bytes is enough to fit 3 elements per variable in each shard, BUT that + # would require concurrent multidimensional tensor partitioning, which is + # not currently implemented for MaxShardSizePolicy. (When partitioning a + # tensor into a shard, we choose an axis to partition along. This can + # happen multiple times for a given tensor (in the case that the tensor + # spans multiple shards). In that case, multiple dimensions can be + # partitioned along (each time the tensor is partitioned, a new axis can be + # chosen), but not within a single iteration of adding a tensor partition to + # the shard.) So, v0/v1 should be split into 3 shards each. 
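The arithmetic behind the comment above, spelled out for v0 using the formulas in `_get_next_partition` (v1 works out identically):

```
# v0: shape [3, 2], int32 -> num_elems = 6, dtype_size = 4, fresh shard = 12 B.
# Axis 0: bytes_per_slice = 6 // 3 * 4 = 8; floor(12 / 8) = 1 row per shard,
#         so ceil(3 / 1) = 3 partitions of one row each.
# Axis 1: bytes_per_slice = 6 // 2 * 4 = 12; only ceil(2 / 1) = 2 partitions,
#         but a 12-byte partition is not strictly smaller than the 12 bytes
#         remaining, so axis 1 loses to axis 0.
# Result: one 8-byte row lands in each shard (4 bytes left unused per shard),
# giving the 3 shards per variable asserted below.
```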
+ callback = sharding_policies.MaxShardSizePolicy(max_shard_size=12) + shards = [] + for tensors in shardable_tensors: + shards.extend(callback(tensors)) + + self.assertEqual( + [set(shard.keys()) for shard in shards], + [ + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE", "_CHECKPOINTABLE_OBJECT_GRAPH",} + ]) + + # V0 + slice_spec = V0SaveSliceInfo(var_offset=[0, 0], var_shape=[1, 2]).spec + self.assertAllEqual( + self.evaluate(shards[0][v0_name][slice_spec]), [[0, 1]]) + + slice_spec = V0SaveSliceInfo(var_offset=[1, 0], var_shape=[1, 2]).spec + self.assertAllEqual( + self.evaluate(shards[1][v0_name][slice_spec]), [[2, 3]]) + + slice_spec = V0SaveSliceInfo(var_offset=[2, 0], var_shape=[1, 2]).spec + self.assertAllEqual( + self.evaluate(shards[2][v0_name][slice_spec]), [[4, 5]]) + + # V1 + slice_spec = V1SaveSliceInfo(var_offset=[0, 0, 0], var_shape=[1, 2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[3][v1_name][slice_spec]), [[[6.0], [7.0]]]) + + slice_spec = V1SaveSliceInfo(var_offset=[1, 0, 0], var_shape=[1, 2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[4][v1_name][slice_spec]), [[[8.0], [9.0]]]) + + slice_spec = V1SaveSliceInfo(var_offset=[2, 0, 0], var_shape=[1, 2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[5][v1_name][slice_spec]), [[[10.0], [11.0]]]) + + # max_shard_size: 16 bytes + # Each variable should be split into 1.5 shards. The middle shard will + # contain elements from both variables. + callback = sharding_policies.MaxShardSizePolicy(max_shard_size=16) + shards = [] + for tensors in shardable_tensors: + shards.extend(callback(tensors)) + + self.assertEqual( + [set(shard.keys()) for shard in shards], + [ + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE", "v1/.ATTRIBUTES/VARIABLE_VALUE"}, + {"v1/.ATTRIBUTES/VARIABLE_VALUE", "_CHECKPOINTABLE_OBJECT_GRAPH",} + ]) + + # V0 + slice_spec = V0SaveSliceInfo(var_offset=[0, 0], var_shape=[2, 2]).spec + self.assertAllEqual( + self.evaluate(shards[0][v0_name][slice_spec]), [[0, 1], [2, 3]]) + + slice_spec = V0SaveSliceInfo(var_offset=[2, 0], var_shape=[1, 2]).spec + self.assertAllEqual( + self.evaluate(shards[1][v0_name][slice_spec]), [[4, 5]]) + + # V1 + slice_spec = V1SaveSliceInfo(var_offset=[0, 0, 0], var_shape=[1, 2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[1][v1_name][slice_spec]), [[[6.0], [7.0]]]) + + slice_spec = V1SaveSliceInfo(var_offset=[1, 0, 0], var_shape=[2, 2, 1]).spec + self.assertAllEqual( + self.evaluate(shards[2][v1_name][slice_spec]), + [[[8.0], [9.0]], [[10.0], [11.0]]]) + + @test_util.run_in_graph_and_eager_modes + def test_MaxShardSizePolicy_Strings(self): + v_strings = [ + "".join(random.choices(string.ascii_uppercase + string.digits, k=10)) + for _ in range(4)] + + root = module.Module() + with ops.device("cpu:0"): + v0 = resource_variable_ops.ResourceVariable(v_strings, name="v0", + dtype=dtypes.string) + self.evaluate(v0.initializer) + root.v0 = v0 + + v0_name = "v0/.ATTRIBUTES/VARIABLE_VALUE" + + class V0SaveSliceInfo(variables.Variable.SaveSliceInfo): + def __init__(self, var_offset, var_shape): + super().__init__( + full_name=v0_name, + full_shape=tensor_shape.TensorShape(dims=[4]), + var_offset=var_offset, + var_shape=var_shape) + + shardable_tensors = self._get_shardable_tensors_by_task(root) + + # Test sharding the v0 & v1 tensors with different max shard sizes. 
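The byte counts this strings test relies on come from the `string_ops.string_length(..., unit="BYTE")` measurement in `MaxShardSizePolicy.__call__`; a minimal eager-mode sketch with an illustrative 10-character value:

```
import tensorflow as tf

v = tf.constant(["A1B2C3D4E5"])  # one 10-character ASCII string
sizes = [int(tf.strings.length(elem, unit="BYTE")) for elem in v]
print(sizes)  # [10] -- with max_shard_size=10, exactly one string fits per shard
```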
+ + # max_shard_size: 10 bytes + # Each string in v0 is 10 bytes, so there should be 1 string per shard. + callback = sharding_policies.MaxShardSizePolicy(max_shard_size=10) + shards = [] + for tensors in shardable_tensors: + shards.extend(callback(tensors)) + + self.assertEqual( + [set(shard.keys()) for shard in shards], + [ + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE", "_CHECKPOINTABLE_OBJECT_GRAPH",} + ]) + + slice_spec = V0SaveSliceInfo(var_offset=[0], var_shape=[1]).spec + self.assertAllEqual( + self.evaluate(shards[0][v0_name][slice_spec]), [v_strings[0]]) + + slice_spec = V0SaveSliceInfo(var_offset=[1], var_shape=[1]).spec + self.assertAllEqual( + self.evaluate(shards[1][v0_name][slice_spec]), [v_strings[1]]) + + slice_spec = V0SaveSliceInfo(var_offset=[2], var_shape=[1]).spec + self.assertAllEqual( + self.evaluate(shards[2][v0_name][slice_spec]), [v_strings[2]]) + + slice_spec = V0SaveSliceInfo(var_offset=[3], var_shape=[1]).spec + self.assertAllEqual( + self.evaluate(shards[3][v0_name][slice_spec]), [v_strings[3]]) + + @test_util.run_in_graph_and_eager_modes + def test_MaxShardSizePolicy_LargeScalar(self): + v_string = "".join(random.choices( + string.ascii_uppercase + string.digits, k=10)).encode("utf-8") + root = module.Module() + with ops.device("cpu:0"): + v0 = resource_variable_ops.ResourceVariable( + v_string, name="v0", dtype=dtypes.string) + self.evaluate(v0.initializer) + root.v0 = v0 + + v0_name = "v0/.ATTRIBUTES/VARIABLE_VALUE" + + shardable_tensors = self._get_shardable_tensors_by_task(root) + + # max_shard_size: 8 bytes + callback = sharding_policies.MaxShardSizePolicy(max_shard_size=8) + shards = [] + for tensors in shardable_tensors: + shards.extend(callback(tensors)) + + self.assertEqual( + [set(shard.keys()) for shard in shards], + [ + {"_CHECKPOINTABLE_OBJECT_GRAPH",}, + {"v0/.ATTRIBUTES/VARIABLE_VALUE",} + ]) + + tensor_val = (self.evaluate(shards[1][v0_name][""]) + if ops.context.executing_eagerly() + else shards[1][v0_name][""]) + self.assertEqual(tensor_val, v_string) + + @test_util.run_in_graph_and_eager_modes + def test_CheckpointOption_MaxShardSizePolicy(self): + root = module.Module() + with ops.device("cpu:0"): + v0 = resource_variable_ops.ResourceVariable([[0, 1], + [2, 3], + [4, 5]], + name="v0") + v1 = resource_variable_ops.ResourceVariable([[[6.0], [7.0]], + [[8.0], [9.0]], + [[10.0], [11.0]]], name="v1") + v2 = resource_variable_ops.ResourceVariable("test_string", name="v1") + self.evaluate(v0.initializer) + self.evaluate(v1.initializer) + self.evaluate(v2.initializer) + root.v0 = v0 + root.v1 = v1 + root.v2 = v2 + + tmp_dir = self.create_tempdir("ckpt") + ckpt = checkpoint.Checkpoint(root) + save_path = ckpt.save( + tmp_dir, options=checkpoint_options.CheckpointOptions( + experimental_sharding_callback=( + sharding_policies.MaxShardSizePolicy(max_shard_size=10)))) + self.assertLen(gfile.Glob(save_path + ".data*"), 8) + ckpt.restore(save_path) + + +if __name__ == "__main__": + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/checkpoint/sharding/sharding_util.py b/tensorflow/python/checkpoint/sharding/sharding_util.py new file mode 100644 index 00000000000000..322bba18dcfa84 --- /dev/null +++ b/tensorflow/python/checkpoint/sharding/sharding_util.py @@ -0,0 +1,263 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Data structures and utilities for checkpoint sharding.""" + +import abc +import dataclasses +import inspect +from typing import Hashable, MutableMapping, Sequence + +from tensorflow.python.framework import device as device_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor as tensor_lib +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops import variables +from tensorflow.python.trackable import base +from tensorflow.python.training.saving import saveable_object +from tensorflow.python.util import tf_export + + +TensorSlice = MutableMapping[tensor_spec.TensorSpec, tensor_lib.Tensor] +TensorSliceDict = MutableMapping[str, TensorSlice] + + +@tf_export.tf_export("train.experimental.ShardableTensor") +@dataclasses.dataclass(frozen=True) +class ShardableTensor: + """Tensor wrapper containing data necessary for sharding. + + The tensor representation used as inputs to pre-made and custom + `tf.train.experiemental.ShardingCallback`s, which can be specified using the + `experimental_sharding_callback` option in `tf.train.CheckpointOptions`. + + """ + _tensor_save_spec: saveable_object.SaveSpec + tensor: tensor_lib.Tensor + dtype: dtypes.DType + device: device_lib.DeviceSpec + name: str + shape: tensor_shape.TensorShape + slice_spec: variables.Variable.SaveSliceInfo + checkpoint_key: str + trackable: base.Trackable + + def __hash__(self) -> int: + return hash((self.name, self.dtype, str(self.device), self.checkpoint_key)) + + def __repr__(self) -> str: + return (f"\n{self.__class__.__name__}:\n" + f" _tensor_save_spec={self._tensor_save_spec!r}\n" + f" tensor={self.tensor!r}\n" + f" dtype={self.dtype!r}\n" + f" device={self.device!r}\n" + f" name={self.name!r}\n" + f" shape={self.shape!r}\n" + f" slice_spec={self.slice_spec!r}\n" + f" checkpoint_key={self.checkpoint_key!r}\n" + f" trackable={self.trackable!r}") + + +@tf_export.tf_export("train.experimental.ShardingCallback") +class ShardingCallback(abc.ABC): + """Checkpoint sharding callback function, along with a text description. + + A callback function wrapper that will be executed to determine how tensors + will be split into shards when the saver writes the checkpoint shards to disk. + + The callback takes a list of `tf.train.experimental.ShardableTensor`s as input + (as well as any kwargs defined by the `tf.train.experimental.ShardingCallback` + subclass), and organizes the input tensors into different shards. Tensors are + first organized by device task (see `tf.DeviceSpec`), then the callback will + be called for each collection of tensors. + + There are a few restrictions to keep in mind when creating a custom callback: + - Tensors must not be removed from the checkpoint. + - Tensors must not be reshaped. + - Tensor dtypes must not change. 
+ - Tensors within a shard must belong to the same task. + Validation checks will be performed after the callback function is executed to + ensure these restrictions aren't violated. + + Here's an example of a simple custom callback: + + ``` + # Place all tensors in a single shard. + class AllInOnePolicy(tf.train.experimental.ShardingCallback): + @property + def description(self): + return "Place all tensors in a single shard." + + def __call__(self, shardable_tensors): + tensors = {} + for shardable_tensor in shardable_tensors: + tensor = shardable_tensor.tensor_save_spec.tensor + checkpoint_key = shardable_tensor.checkpoint_key + slice_spec = shardable_tensor.slice_spec + + tensors.set_default(checkpoint_key, {})[slice_spec] = tensor + return [tensors] + + ckpt.save( + "path", + options=tf.train.CheckpointOptions( + experimental_sharding_callback=AllInOnePolicy())) + ``` + + The `description` attribute is used to identify the callback and to aid + debugging during saving and restoration. + + To take in kwargs, simply define the constructor and pass them in: + + ``` + class ParameterPolicy(tf.train.experimental.ShardingCallback): + def __init__(self, custom_param): + self.custom_param = custom_param + ... + + ckpt.save( + "path", + options=tf.train.CheckpointOptions( + experimental_sharding_callback=ParameterPolicy(custom_param=...))) + ``` + + """ + description: str + + @property + @abc.abstractmethod + def description(self) -> str: + pass + + @abc.abstractmethod + def __call__( + self, shardable_tensors: Sequence[ShardableTensor] + ) -> Sequence[TensorSliceDict]: + pass + + def __hash__(self) -> int: + hash_val = hash(self.description) + # vars() only includes user-defined attributes. + for attr_name, attr_val in vars(self).items(): + if not (inspect.ismethod(attr_val) or inspect.isfunction(attr_val)): + hash_val ^= hash(attr_name) + if isinstance(attr_val, Hashable): + hash_val ^= hash(attr_val) + return hash_val + + +def validate_shards( + shards: Sequence[TensorSliceDict], + shardable_tensors: Sequence[ShardableTensor], + callback_description: str +) -> None: + """Validates shards generated by the sharding_callback.""" + unseen_tensor_dict = {} + for shardable_tensor in shardable_tensors: + unseen_tensor_dict.setdefault( + shardable_tensor.checkpoint_key, {} + )[shardable_tensor.slice_spec] = shardable_tensor.tensor + seen_tensor_set = set() + + for shard_tensors in shards: + task_tensor = None + for checkpoint_key, tensor_slice_dict in shard_tensors.items(): + for slice_spec, shard_tensor in tensor_slice_dict.items(): + slice_spec = slice_spec.strip() + + # Validate uniqueness. + if (checkpoint_key, slice_spec) in seen_tensor_set: + raise RuntimeError( + "After executing the checkpoint sharding callback, multiple " + "tensors with the same checkpoint key and slice spec were " + "found:\n" + f" callback_description: {callback_description}\n" + f" checkpoint_key: {checkpoint_key}\n" + f" slice_spec: {slice_spec}\n") + + # Validate no added tensors. + if checkpoint_key not in unseen_tensor_dict: + raise RuntimeError( + "After executing the checkpoint sharding callback, a tensor " + "not originally in the object graph was found in the " + "checkpoint shards:\n" + f" callback_description: {callback_description}\n" + f" checkpoint_key: {checkpoint_key}\n" + f" slice_spec: {slice_spec}\n") + + # Validate no shape change. 
+ target_shape = unseen_tensor_dict[checkpoint_key][slice_spec].shape + if shard_tensor.shape != target_shape: + raise RuntimeError( + "After executing the checkpoint sharding callback, a tensor " + "was found with an altered shape:\n" + f" callback_description: {callback_description}\n" + f" checkpoint_key: {checkpoint_key}\n" + f" slice_spec: {slice_spec}\n" + f" original tensor_shape: {target_shape}\n" + f" new tensor_shape: {shard_tensor.shape}\n") + + # Validate no dtype change. + target_dtype = unseen_tensor_dict[checkpoint_key][slice_spec].dtype + if shard_tensor.dtype != target_dtype: + raise RuntimeError( + "After executing the checkpoint sharding callback, a tensor " + "was found with an altered dtype:\n" + f" callback_description: {callback_description}\n" + f" checkpoint_key: {checkpoint_key}\n" + f" slice_spec: {slice_spec}\n" + f" original tensor_dtype: {target_dtype}\n" + f" new tensor_dtype: {shard_tensor.dtype}\n") + + # Validate same task in shard. + if task_tensor is None: + task_tensor = ShardableTensor + task_tensor.device = shard_tensor.device + task_tensor.checkpoint_key = checkpoint_key + task_tensor.slice_spec = slice_spec + else: + task1 = device_lib.DeviceSpec.from_string(task_tensor.device).task + task2 = device_lib.DeviceSpec.from_string(shard_tensor.device).task + if task1 is not None and task2 is not None and task1 != task2: + raise RuntimeError( + "After executing the checkpoint sharding callback, tensors " + "with different tasks were found in the same shard:\n" + f" callback_description: {callback_description}\n" + " tensor #1:" + f" checkpoint_key: {task_tensor.checkpoint_key}\n" + f" slice_spec: {task_tensor.slice_spec}\n" + f" task: {task1}\n" + " tensor #2:" + f" checkpoint_key: {checkpoint_key}\n" + f" slice_spec: {slice_spec}\n" + f" task: {task2}\n") + + del unseen_tensor_dict[checkpoint_key][slice_spec] + if not unseen_tensor_dict[checkpoint_key]: + del unseen_tensor_dict[checkpoint_key] + seen_tensor_set.add((checkpoint_key, slice_spec)) + + # validate no tensor removal + if unseen_tensor_dict: + tensors_info = "" + for ckpt_key, slice_spec in unseen_tensor_dict.items(): + tensors_info += " tensor:\n" + tensors_info += f" checkpoint_key: {ckpt_key}\n" + tensors_info += f" slice_spec: {slice_spec}\n" + raise RuntimeError( + "After executing the checkpoint sharding callback, tensors in the " + "object graph were not found in the checkpoint shards:\n" + f" callback_description: {callback_description}\n" + f"{tensors_info}") diff --git a/tensorflow/python/checkpoint/sharding/sharding_util_test.py b/tensorflow/python/checkpoint/sharding/sharding_util_test.py new file mode 100644 index 00000000000000..1c5acbea791b78 --- /dev/null +++ b/tensorflow/python/checkpoint/sharding/sharding_util_test.py @@ -0,0 +1,382 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= +"""Tests for checkpoint sharding structures and utilities.""" + + +from typing import Sequence + +from tensorflow.python.checkpoint import checkpoint +from tensorflow.python.checkpoint import graph_view +from tensorflow.python.checkpoint.sharding import sharding_policies +from tensorflow.python.checkpoint.sharding import sharding_util +from tensorflow.python.eager import remote +from tensorflow.python.eager import test +from tensorflow.python.framework import device as device_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib +from tensorflow.python.module import module +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import server_lib +from tensorflow.python.training.saving import saveable_object +from tensorflow.python.training.saving import saveable_object_util + + +class ShardingUtilTest(test.TestCase): + + def _get_shardable_tensors_by_task(self, root): + serialized_tensors, _, _, _ = ( + checkpoint.TrackableSaver(graph_view.ObjectGraphView(root)) + ._gather_serialized_tensors(None)) + + shardable_tensors_by_task = {} + for obj, tensor_dict in serialized_tensors.items(): + for checkpoint_key, tensor_slice_dict in tensor_dict.items(): + if not isinstance(tensor_slice_dict, dict): + # Make sure that maybe_tensor is structured as {slice_spec -> tensor}. + tensor_slice_dict = {"": tensor_slice_dict} + for slice_spec, tensor_save_spec in tensor_slice_dict.items(): + if not isinstance(tensor_save_spec, saveable_object.SaveSpec): + tensor_save_spec = saveable_object.SaveSpec( + tensor=tensor_save_spec, + slice_spec=slice_spec, + name=checkpoint_key, + dtype=tensor_save_spec.dtype, + device=tensor_save_spec.device) + save_spec_tensor = tensor_save_spec.tensor + device = (device_lib.DeviceSpec.from_string(tensor_save_spec.device) + if isinstance(tensor_save_spec.device, str) + else tensor_save_spec.device) + task = device_lib.DeviceSpec.from_string( + saveable_object_util.set_cpu0(device.to_string())) + shardable_tensors_by_task.setdefault(task, []).append( + sharding_util.ShardableTensor( + _tensor_save_spec=tensor_save_spec, + tensor=save_spec_tensor, + dtype=tensor_save_spec.dtype, + device=device, + name=tensor_save_spec.name, + shape=save_spec_tensor.shape, + slice_spec=slice_spec.strip(), + checkpoint_key=checkpoint_key, + trackable=obj)) + return shardable_tensors_by_task.values() + + def test_hash_ShardingCallback(self): + class BlankCallback(sharding_util.ShardingCallback): + @property + def description(self): + return "" + + def __call__( + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.TensorSliceDict]: + pass + + self.assertEqual(hash(BlankCallback()), hash(BlankCallback())) + + class ValueCallback(sharding_util.ShardingCallback): + def __init__(self, val): + self.val = val + + @property + def description(self): + return "value callback" + + def __call__( + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.TensorSliceDict]: + pass + + self.assertEqual(hash(ValueCallback(1)), hash(ValueCallback(1))) + self.assertNotEqual(hash(ValueCallback(1)), hash(ValueCallback(2))) + + def test_validate_shards_correct(self): + root = module.Module() + with ops.device("cpu:0"): + v0 = 
resource_variable_ops.ResourceVariable(0.0, name="v0") + with ops.device("cpu:1"): + v1 = resource_variable_ops.ResourceVariable(1.0, name="v1") + with ops.device("cpu:2"): + v2 = resource_variable_ops.ResourceVariable(2.0, name="v2") + root.v0 = v0 + root.v1 = v1 + root.v2 = v2 + + shardable_tensors = self._get_shardable_tensors_by_task(root) + shardable_tensors_flat = [] + for tensors in shardable_tensors: + shardable_tensors_flat.extend(tensors) + + sharding_callback = sharding_policies.ShardByTaskPolicy() + shards = [] + for tensors in shardable_tensors: + shards.extend(sharding_callback(tensors)) + + sharding_util.validate_shards( + shards, shardable_tensors_flat, sharding_callback.description) + + self.assertEqual( + [list(shard.keys()) for shard in shards], + [[ + "v0/.ATTRIBUTES/VARIABLE_VALUE", + "v1/.ATTRIBUTES/VARIABLE_VALUE", + "v2/.ATTRIBUTES/VARIABLE_VALUE", + "_CHECKPOINTABLE_OBJECT_GRAPH" + ]]) + + self.assertEqual( + shards[0]["v0/.ATTRIBUTES/VARIABLE_VALUE"][""].numpy(), + v0.numpy()) + self.assertEqual( + shards[0]["v1/.ATTRIBUTES/VARIABLE_VALUE"][""].numpy(), + v1.numpy()) + self.assertEqual( + shards[0]["v2/.ATTRIBUTES/VARIABLE_VALUE"][""].numpy(), + v2.numpy()) + + def test_validate_shards_duplicate_tensor(self): + root = module.Module() + with ops.device("cpu:0"): + v0 = resource_variable_ops.ResourceVariable(0.0, name="v0") + with ops.device("cpu:1"): + v1 = resource_variable_ops.ResourceVariable(1.0, name="v1") + root.v0 = v0 + root.v1 = v1 + + class DuplicateTensorCallback(sharding_util.ShardingCallback): + @property + def description(self): + return "duplicate tensor callback" + + def __call__( + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.TensorSliceDict]: + tensor = shardable_tensors[0].tensor + checkpoint_key = shardable_tensors[0].checkpoint_key + slice_spec = shardable_tensors[0].slice_spec + shards = [ + {checkpoint_key: {slice_spec: tensor}}, + {checkpoint_key: {slice_spec: tensor}} + ] + return shards + + shardable_tensors = self._get_shardable_tensors_by_task(root) + shardable_tensors_flat = [] + for tensors in shardable_tensors: + shardable_tensors_flat.extend(tensors) + + sharding_callback = DuplicateTensorCallback() + shards = [] + for tensors in shardable_tensors: + shards.extend(sharding_callback(tensors)) + + with self.assertRaisesRegex(RuntimeError, + "multiple tensors with the same checkpoint " + "key and slice spec were found"): + sharding_util.validate_shards( + shards, shardable_tensors_flat, sharding_callback.description) + + def test_validate_shards_added_tensor(self): + root = module.Module() + with ops.device("cpu:0"): + v0 = resource_variable_ops.ResourceVariable(0.0, name="v0") + root.v0 = v0 + + class AddedTensorCallback(sharding_util.ShardingCallback): + @property + def description(self): + return "added tensor callback" + + def __call__( + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.TensorSliceDict]: + checkpoint_key = "ADDED_TENSOR_ABC123" + slice_spec = "" + tensor = tensor_lib.Tensor() + return [{checkpoint_key: {slice_spec: tensor}}] + + shardable_tensors = self._get_shardable_tensors_by_task(root) + shardable_tensors_flat = [] + for tensors in shardable_tensors: + shardable_tensors_flat.extend(tensors) + + sharding_callback = AddedTensorCallback() + shards = [] + for tensors in shardable_tensors: + shards.extend(sharding_callback(tensors)) + + with self.assertRaisesRegex(RuntimeError, + "a tensor not originally in the object 
graph"): + sharding_util.validate_shards( + shards, shardable_tensors_flat, sharding_callback.description) + + def test_validate_shards_shape_change(self): + root = module.Module() + with ops.device("cpu:0"): + v0 = resource_variable_ops.ResourceVariable([[0.0, 1.0]], name="v0") + root.v0 = v0 + + class ShapeChangeCallback(sharding_util.ShardingCallback): + @property + def description(self): + return "shape change callback" + + def __call__( + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.TensorSliceDict]: + shards = [] + for shardable_tensor in shardable_tensors: + tensor = shardable_tensor.tensor + checkpoint_key = shardable_tensor.checkpoint_key + slice_spec = shardable_tensor.slice_spec + if checkpoint_key == "v0/.ATTRIBUTES/VARIABLE_VALUE": + tensor = array_ops.transpose(tensor) + shards.append({checkpoint_key: {slice_spec: tensor}}) + return shards + + shardable_tensors = self._get_shardable_tensors_by_task(root) + shardable_tensors_flat = [] + for tensors in shardable_tensors: + shardable_tensors_flat.extend(tensors) + + sharding_callback = ShapeChangeCallback() + shards = [] + for tensors in shardable_tensors: + shards.extend(sharding_callback(tensors)) + + with self.assertRaisesRegex(RuntimeError, + "a tensor was found with an altered shape"): + sharding_util.validate_shards( + shards, shardable_tensors_flat, sharding_callback.description) + + def test_validate_shards_dtype_change(self): + root = module.Module() + with ops.device("cpu:0"): + v0 = resource_variable_ops.ResourceVariable(0.0, name="v0") + root.v0 = v0 + + class DtypeChangeCallback(sharding_util.ShardingCallback): + @property + def description(self): + return "dtype change callback" + + def __call__( + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.TensorSliceDict]: + shards = [] + for shardable_tensor in shardable_tensors: + tensor = shardable_tensor.tensor + checkpoint_key = shardable_tensor.checkpoint_key + slice_spec = shardable_tensor.slice_spec + if checkpoint_key == "v0/.ATTRIBUTES/VARIABLE_VALUE": + tensor = math_ops.cast(tensor, dtype=dtypes.int32) + shards.append({checkpoint_key: {slice_spec: tensor}}) + return shards + + shardable_tensors = self._get_shardable_tensors_by_task(root) + shardable_tensors_flat = [] + for tensors in shardable_tensors: + shardable_tensors_flat.extend(tensors) + + sharding_callback = DtypeChangeCallback() + shards = [] + for tensors in shardable_tensors: + shards.extend(sharding_callback(tensors)) + + with self.assertRaisesRegex(RuntimeError, + "a tensor was found with an altered dtype"): + sharding_util.validate_shards( + shards, shardable_tensors_flat, sharding_callback.description) + + def test_validate_shards_different_tasks(self): + servers = [server_lib.Server.create_local_server() for _ in range(3)] + cluster_spec = server_lib.ClusterSpec({ + "worker": [s.target[len("grpc://"):] for s in servers]}) + remote.connect_to_cluster(cluster_spec) + + root = module.Module() + with ops.device("/job:worker/task:0/cpu:0"): + v0 = resource_variable_ops.ResourceVariable(0.0, name="v0") + with ops.device("/job:worker/task:1/cpu:0"): + v1 = resource_variable_ops.ResourceVariable(0.0, name="v1") + root.v0 = v0 + root.v1 = v1 + + class DifferentTasksCallback(sharding_util.ShardingCallback): + @property + def description(self): + return "different tasks callback" + + def __call__( + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.TensorSliceDict]: + 
shard = {} + for shardable_tensor in shardable_tensors: + tensor = shardable_tensor.tensor + checkpoint_key = shardable_tensor.checkpoint_key + slice_spec = shardable_tensor.slice_spec + shard.setdefault(checkpoint_key, {})[slice_spec] = tensor + return [shard] + + shardable_tensors = self._get_shardable_tensors_by_task(root) + shardable_tensors_flat = [] + for tensors in shardable_tensors: + shardable_tensors_flat.extend(tensors) + + sharding_callback = DifferentTasksCallback() + shards = sharding_callback(shardable_tensors_flat) + + with self.assertRaisesRegex(RuntimeError, + "tensors with different tasks were found"): + sharding_util.validate_shards( + shards, shardable_tensors_flat, sharding_callback.description) + + def test_validate_shards_tensor_removal(self): + root = module.Module() + with ops.device("cpu:0"): + v0 = resource_variable_ops.ResourceVariable(0.0, name="v0") + root.v0 = v0 + + class TensorRemovalCallback(sharding_util.ShardingCallback): + @property + def description(self): + return "tensor removal callback" + + def __call__( + self, shardable_tensors: Sequence[sharding_util.ShardableTensor] + ) -> Sequence[sharding_util.TensorSliceDict]: + return [] + + shardable_tensors = self._get_shardable_tensors_by_task(root) + shardable_tensors_flat = [] + for tensors in shardable_tensors: + shardable_tensors_flat.extend(tensors) + + sharding_callback = TensorRemovalCallback() + shards = [] + for tensors in shardable_tensors: + shards.extend(sharding_callback(tensors)) + + with self.assertRaisesRegex(RuntimeError, + "tensors in the object graph were not found"): + sharding_util.validate_shards( + shards, shardable_tensors_flat, sharding_callback.description) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/client/BUILD b/tensorflow/python/client/BUILD index 782ffe58b43059..1f8f7e6b8b1d31 100644 --- a/tensorflow/python/client/BUILD +++ b/tensorflow/python/client/BUILD @@ -3,6 +3,10 @@ load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:tensorflow.bzl", "tf_cuda_library") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test", "tf_python_pybind_extension") load("//tensorflow/core/platform:build_config_root.bzl", "if_static") +load( + "//tensorflow/tools/test:performance.bzl", + "cuda_py_benchmark_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -108,6 +112,10 @@ tf_python_pybind_extension( tf_python_pybind_extension( name = "_pywrap_events_writer", srcs = ["events_writer_wrapper.cc"], + enable_stub_generation = True, + pytype_srcs = [ + "_pywrap_events_writer.pyi", + ], deps = [ "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_headers_for_pybind", @@ -433,6 +441,7 @@ tf_py_strict_test( python_version = "PY3", tags = [ "no_gpu", + "no_rocm", "no_windows", ], deps = [ @@ -494,7 +503,7 @@ cuda_py_strict_test( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "session_benchmark", srcs = ["session_benchmark.py"], grpc_enabled = True, diff --git a/tensorflow/python/tpu/tpu_config.py b/tensorflow/python/client/_pywrap_events_writer.pyi similarity index 52% rename from tensorflow/python/tpu/tpu_config.py rename to tensorflow/python/client/_pywrap_events_writer.pyi index eda3717520f7a8..92da35bcfe093b 100644 --- a/tensorflow/python/tpu/tpu_config.py +++ b/tensorflow/python/client/_pywrap_events_writer.pyi @@ -1,10 +1,10 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The TensorFlow Authors. 
All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,8 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Stub file to maintain backwards compatibility.""" -# pylint: disable=wildcard-import,unused-import -from tensorflow_estimator.python.estimator.tpu.tpu_config import * -# pylint: enable=wildcard-import,unused-import +class EventsWriter: + def __init__(self, arg0: str) -> None: ... + def Close(self) -> Status: ... + def FileName(self) -> str: ... + def Flush(self) -> Status: ... + def InitWithSuffix(self, arg0: str) -> Status: ... + def WriteEvent(self, arg0: object) -> None: ... + def _WriteSerializedEvent(self, arg0: str) -> None: ... + +class Status: + def __init__(self, *args, **kwargs) -> None: ... diff --git a/tensorflow/python/client/events_writer_wrapper.cc b/tensorflow/python/client/events_writer_wrapper.cc index 7e5720c4eef02d..661c845b3aac57 100644 --- a/tensorflow/python/client/events_writer_wrapper.cc +++ b/tensorflow/python/client/events_writer_wrapper.cc @@ -14,8 +14,10 @@ limitations under the License. ==============================================================================*/ #include "absl/strings/string_view.h" +#include "pybind11/attr.h" // from @pybind11 #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/events_writer.h" #include "tensorflow/python/lib/core/pybind11_absl.h" #include "tensorflow/python/lib/core/pybind11_proto.h" @@ -24,6 +26,7 @@ limitations under the License. namespace py = pybind11; PYBIND11_MODULE(_pywrap_events_writer, m) { + py::class_ Status(m, "Status", py::module_local()); py::class_ events_writer_class(m, "EventsWriter"); events_writer_class.def(py::init()) .def("InitWithSuffix", diff --git a/tensorflow/python/client/session_partial_run_test.py b/tensorflow/python/client/session_partial_run_test.py index 075d69e78bc400..79cedb5a2ffdd6 100644 --- a/tensorflow/python/client/session_partial_run_test.py +++ b/tensorflow/python/client/session_partial_run_test.py @@ -26,7 +26,6 @@ from tensorflow.python.platform import googletest from tensorflow.python.training import server_lib - class PartialRunTest(test_util.TensorFlowTestCase): def RunTestPartialRun(self, sess): diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc index 790629c96d2e4f..160416c4199102 100644 --- a/tensorflow/python/client/tf_session_wrapper.cc +++ b/tensorflow/python/client/tf_session_wrapper.cc @@ -138,9 +138,9 @@ pybind11::object method(pybind11::object type, Func&& function, // generation. The type is assumed to be a GC type (containing other types). 
// To add the required Python type fields, classes definitions must start with // -// TFObject_Head(classname) +// TFObject_Head(classname, TfObjectDataType) // -// Required attributes/methods: +// Required attributes/methods for TfObjectDataType type: // // Constructor(PyObject* args, PyObject* kw) // ~Destructor @@ -148,8 +148,10 @@ pybind11::object method(pybind11::object type, Func&& function, // Visit(visitproc visit, void* arg) // // Individual methods/attributes are added to the type later, as seen below. -template +template void MakeTfObjectType(PyObject** py_type) { + using TfObjectDataType = typename T::TfObjectDataType; + py::str name = py::str(T::kTypeName); py::str qualname = py::str(T::kTypeName); PyHeapTypeObject* heap_type = reinterpret_cast( @@ -162,11 +164,14 @@ void MakeTfObjectType(PyObject** py_type) { type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE; type->tp_name = T::kTypeName; - type->tp_basicsize = sizeof(T); + + // Allocation size for both Python object header and the TF data members. + type->tp_basicsize = sizeof(T) + sizeof(TfObjectDataType); type->tp_new = [](PyTypeObject* subtype, PyObject* args, PyObject* kwds) -> PyObject* { T* self = reinterpret_cast(subtype->tp_alloc(subtype, 0)); + TfObjectDataType* data = reinterpret_cast(&self[1]); if (!self) return nullptr; // PyType_GenericAlloc (the default implementation of tp_alloc) by default @@ -176,7 +181,7 @@ void MakeTfObjectType(PyObject** py_type) { // // We disable the GC here until initialization is finished. PyObject_GC_UnTrack(self); - new (self) T(args, kwds); + new (data) TfObjectDataType(args, kwds); self->dict = PyDict_New(); PyObject_GC_Track(self); @@ -193,9 +198,9 @@ void MakeTfObjectType(PyObject** py_type) { PyObject_ClearWeakRefs(self); T* o = reinterpret_cast(self); + TfObjectDataType* data = reinterpret_cast(&o[1]); Py_CLEAR(o->dict); - o->~T(); - + data->~TfObjectDataType(); tp->tp_free(self); Py_DECREF(tp); }; @@ -203,16 +208,18 @@ void MakeTfObjectType(PyObject** py_type) { type->tp_traverse = [](PyObject* self, visitproc visit, void* arg) { VLOG(3) << "Visit: " << T::kTypeName; T* o = reinterpret_cast(self); + TfObjectDataType* data = reinterpret_cast(&o[1]); Py_VISIT(Py_TYPE(self)); Py_VISIT(o->dict); - return o->Visit(visit, arg); + return data->Visit(visit, arg); }; type->tp_clear = [](PyObject* self) { VLOG(3) << "Clear: " << T::kTypeName; T* o = reinterpret_cast(self); + TfObjectDataType* data = reinterpret_cast(&o[1]); Py_CLEAR(o->dict); - o->Clear(); + data->Clear(); return 0; }; @@ -238,11 +245,13 @@ void MakeTfObjectType(PyObject** py_type) { *py_type = reinterpret_cast(type); } -#define TFObject_HEAD(typename) \ - PyObject_HEAD; \ - PyObject* dict = nullptr; \ - PyObject* weakrefs = nullptr; \ - static PyObject* py_type; \ +#define TFObject_HEAD(typename, datatypename) \ + using TfObjectDataType = datatypename; \ + PyObject_HEAD; \ + PyObject* dict = nullptr; \ + PyObject* weakrefs = nullptr; \ + TfObjectDataType data[0]; \ + static PyObject* py_type; \ static constexpr const char* kTypeName = #typename; struct PyGraph; @@ -272,7 +281,7 @@ PYBIND11_MAKE_OPAQUE(OpsByIdMap); PYBIND11_MAKE_OPAQUE(OpsByNameMap); // Convert the given handle to a TF object type. 
-template +template T* AsPyTfObject(py::handle handle) { if (handle.get_type() == T::py_type) { return reinterpret_cast(handle.ptr()); @@ -296,11 +305,15 @@ T* AsPyTfObject(py::handle handle) { py::cast(py::str(handle)))); } -template +template py::object AsPyObject(T* obj) { return py::reinterpret_borrow(reinterpret_cast(obj)); } +template +typename T::TfObjectDataType* AsPyTfObjectData(py::handle handle) { + return AsPyTfObject(handle)->data; +} // Reference counting helper for PyTfObjects. // // Similar to the pybind holder types, this manages the Python reference @@ -309,7 +322,7 @@ py::object AsPyObject(T* obj) { // As a special case to support Dismantle(), this allows setting our underlying // pointer to None when clearing the type. Direct access to attributes is not // allowed after this point. -template +template class tf_handle { public: tf_handle() : obj_(nullptr) {} @@ -402,9 +415,7 @@ struct TF_OperationDeleter { void operator()(TF_Operation* op) {} }; -struct PyGraph { - TFObject_HEAD(PyGraph); - +struct PyGraphData { TF_Graph* graph; // The C++ graph maintains an ID for every node, however our Python code has @@ -424,7 +435,7 @@ struct PyGraph { OpsByIdMap ops_by_id; OpsByNameMap ops_by_name; - PyGraph(PyObject* args, PyObject* kwds) { + PyGraphData(PyObject* args, PyObject* kwds) { graph = TF_NewGraph(); // By default shape inference functions are required, however this breaks @@ -433,7 +444,7 @@ struct PyGraph { graph->refiner.set_require_shape_inference_fns(false); } - ~PyGraph() { + ~PyGraphData() { Clear(); TF_DeleteGraph(graph); } @@ -462,22 +473,26 @@ struct PyGraph { } return 0; } +}; + +struct PyGraph { + TFObject_HEAD(PyGraph, PyGraphData); int64_t add_op(py::object obj); - py::list operations() { return op_list; } - int64_t num_operations() const { return op_list.size(); } + py::list operations() { return data->op_list; } + int64_t num_operations() const { return data->op_list.size(); } // Return operations that are part of the Graph, but do not yet have // OperationHandle's. This logic is only invoked when importing an existing // GraphDef into Python. It should be removed once all logic moves to C++. 
std::vector new_operations() { - tsl::mutex_lock l(graph->mu); + tsl::mutex_lock l(tf_graph()->mu); std::vector ops; // SUBTLE: `op_nodes` skips the SOURCE and SINK nodes - for (auto n : graph->graph.op_nodes()) { - if (ops_by_name.find(n->name()) == ops_by_name.end()) { + for (auto n : tf_graph()->graph.op_nodes()) { + if (data->ops_by_name.find(n->name()) == data->ops_by_name.end()) { ops.push_back(reinterpret_cast(n)); } } @@ -485,15 +500,15 @@ struct PyGraph { } py::object get_operation_by_name(const std::string& name) { - tsl::mutex_lock l(graph->mu); - auto it = ops_by_name.find(name); - if (it == ops_by_name.end()) { + tsl::mutex_lock l(tf_graph()->mu); + auto it = data->ops_by_name.find(name); + if (it == data->ops_by_name.end()) { throw py::key_error(); } return it->second; } - int version() const { return ops_by_id.size(); } + int version() const { return data->ops_by_id.size(); } py::bytes version_def() const { // Potential deadlock: @@ -509,8 +524,8 @@ struct PyGraph { std::string versions; { py::gil_scoped_release release; - tsl::mutex_lock l(graph->mu); - versions = graph->graph.versions().SerializeAsString(); + tsl::mutex_lock l(tf_graph()->mu); + versions = tf_graph()->graph.versions().SerializeAsString(); } pybind11::gil_scoped_acquire acquire; return py::bytes(versions); @@ -518,52 +533,52 @@ struct PyGraph { tsl::StatusOr _op_def_for_type( const std::string& kTypeName) const { - tsl::mutex_lock l(graph->mu); + tsl::mutex_lock l(tf_graph()->mu); const tensorflow::OpDef* op_def; TF_RETURN_IF_ERROR( - graph->graph.op_registry()->LookUpOpDef(kTypeName, &op_def)); + tf_graph()->graph.op_registry()->LookUpOpDef(kTypeName, &op_def)); return py::bytes(op_def->SerializeAsString()); } void add_control_input(tensorflow::Node* src, tensorflow::Node* dst) { - tsl::mutex_lock l(graph->mu); + tsl::mutex_lock l(tf_graph()->mu); - graph->graph.AddControlEdge(src, dst); + tf_graph()->graph.AddControlEdge(src, dst); record_mutation(*dst, "adding control edge"); } void remove_all_control_inputs(const tensorflow::Node& node) { - tsl::mutex_lock l(graph->mu); + tsl::mutex_lock l(tf_graph()->mu); std::vector control_edges; for (const tensorflow::Edge* edge : node.in_edges()) { if (!edge->IsControlEdge()) continue; control_edges.push_back(edge); } for (const tensorflow::Edge* edge : control_edges) { - graph->graph.RemoveControlEdge(edge); + tf_graph()->graph.RemoveControlEdge(edge); } } void record_mutation(const tensorflow::Node& node, const std::string& reason) - TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { - tensorflow::RecordMutation( - graph, reinterpret_cast(node), reason.c_str()); + TF_EXCLUSIVE_LOCKS_REQUIRED(tf_graph()->mu) { + tensorflow::RecordMutation(tf_graph(), + reinterpret_cast(node), + reason.c_str()); } - TF_Graph* tf_graph() { return graph; } + TF_Graph* tf_graph() const { return data->graph; } }; -struct PyOperation { - TFObject_HEAD(PyOperation); - +struct PyOperationData { TF_Operation* tf_op = nullptr; + py::list outputs; // N.B. initialized later by Python. 
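+  // `graph` is a back-reference to the owning PyGraph; `tensor_fn` is the
+  // Python callable that _init_outputs() invokes to build the Python object
+  // for each op output.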
tf_handle graph; py::function tensor_fn; - PyOperation(PyObject* args, PyObject* kwds) { + PyOperationData(PyObject* args, PyObject* kwds) { PyObject *py_op, *py_tensor_fn; if (!PyArg_ParseTuple(args, "OO", &py_op, &py_tensor_fn)) { return; @@ -572,90 +587,92 @@ struct PyOperation { tensor_fn = py::cast(py_tensor_fn); } - ~PyOperation() { Clear(); } + ~PyOperationData() { Clear(); } + + void Dismantle(PyOperation* py_op); void Clear() { Py_CLEAR(outputs.release().ptr()); graph.Clear(); } - void Dismantle(); - int Visit(visitproc visit, void* arg) { Py_VISIT(graph.ptr()); Py_VISIT(outputs.ptr()); return 0; } +}; + +struct PyOperation { + TFObject_HEAD(PyOperation, PyOperationData); + + TF_Operation* tf_op() const { return data->tf_op; } void _init_outputs() { - int num_outputs = TF_OperationNumOutputs(tf_op); + int num_outputs = TF_OperationNumOutputs(tf_op()); for (int i = 0; i < num_outputs; ++i) { - auto dtype = TF_OperationOutputType(TF_Output{tf_op, i}); - outputs.append(tensor_fn(AsPyObject(this), i, dtype)); + auto dtype = TF_OperationOutputType(TF_Output{tf_op(), i}); + data->outputs.append(data->tensor_fn(AsPyObject(this), i, dtype)); } } tsl::Status _add_outputs(py::list dtypes, py::list shapes); - const TF_Operation* op() { return tf_op; } - - TF_Output _tf_output(int idx) const { return TF_Output{tf_op, idx}; } - TF_Input _tf_input(int idx) const { return TF_Input{tf_op, idx}; } + TF_Output _tf_output(int idx) const { return TF_Output{tf_op(), idx}; } + TF_Input _tf_input(int idx) const { return TF_Input{tf_op(), idx}; } py::bytes node_def() { - return py::bytes(tf_op->node.def().SerializeAsString()); + return py::bytes(tf_op()->node.def().SerializeAsString()); } py::bytes op_def() const { - return py::bytes(tf_op->node.op_def().SerializeAsString()); + return py::bytes(tf_op()->node.op_def().SerializeAsString()); } - bool is_stateful() const { return tf_op->node.op_def().is_stateful(); } + bool is_stateful() const { return tf_op()->node.op_def().is_stateful(); } - const std::string& type() { return tf_op->node.type_string(); } + const std::string& type() { return tf_op()->node.type_string(); } void add_control_input(PyOperation* input) { - graph->add_control_input(&input->tf_op->node, &tf_op->node); + data->graph->add_control_input(&input->tf_op()->node, &tf_op()->node); } void add_control_inputs(py::iterable inputs); py::list control_inputs() { py::list output; - for (const auto* edge : tf_op->node.in_edges()) { + for (const auto* edge : tf_op()->node.in_edges()) { if (edge->IsControlEdge() && !edge->src()->IsSource()) { - output.append(graph->ops_by_id[edge->src()->id()]); + output.append(data->graph->data->ops_by_id[edge->src()->id()]); } } return output; } py::list control_outputs() { py::list output; - for (const auto* edge : tf_op->node.out_edges()) { + for (const auto* edge : tf_op()->node.out_edges()) { if (edge->IsControlEdge() && !edge->dst()->IsSink()) { - output.append(graph->ops_by_id[edge->dst()->id()]); + output.append(data->graph->data->ops_by_id[edge->dst()->id()]); } } return output; } void remove_all_control_inputs() { - graph->remove_all_control_inputs(tf_op->node); + data->graph->remove_all_control_inputs(tf_op()->node); } void set_device(const std::string& device) { - tsl::mutex_lock l(graph->graph->mu); - tf_op->node.set_requested_device(device); - graph->record_mutation(tf_op->node, "setting device"); + tsl::mutex_lock l(data->graph->tf_graph()->mu); + tf_op()->node.set_requested_device(device); + data->graph->record_mutation(tf_op()->node, "setting 
device"); } - const std::string& device() { return tf_op->node.requested_device(); } - const std::string& name() { return tf_op->node.name(); } + const std::string& device() { return tf_op()->node.requested_device(); } + const std::string& name() { return tf_op()->node.name(); } }; -struct PyTensor { - TFObject_HEAD(PyTensor); - +struct PyTensorData { py::object tf_output = py::none(); py::object name = py::none(); py::object dtype = py::none(); @@ -667,7 +684,7 @@ struct PyTensor { int value_index = -1; - PyTensor(PyObject* args, PyObject* kwds) { + PyTensorData(PyObject* args, PyObject* kwds) { PyObject *py_op, *py_index, *py_dtype, *py_uid; if (!PyArg_ParseTuple(args, "OOOO", &py_op, &py_index, &py_dtype, &py_uid)) { @@ -676,12 +693,13 @@ struct PyTensor { dtype = py::reinterpret_borrow(py_dtype); value_index = py::cast(py::handle(py_index)); op = py_op; - graph = op->graph; + graph = op->data->graph; name = py::str(absl::StrCat(op->name(), ":", value_index)); - tf_output = py::cast(TF_Output{op->tf_op, value_index}); + tf_output = py::cast(TF_Output{op->tf_op(), value_index}); uid = py::reinterpret_borrow(py_uid); } - ~PyTensor() { Clear(); } + + ~PyTensorData() { Clear(); } void Clear() { Py_CLEAR(tf_output.release().ptr()); @@ -703,14 +721,20 @@ struct PyTensor { Py_VISIT(uid.ptr()); return 0; } +}; + +struct PyTensor { + TFObject_HEAD(PyTensor, PyTensorData); + + int value_index() const { return data->value_index; } tsl::StatusOr shape() { tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); bool unknown_shape = false; auto dims = tensorflow::TF_GraphGetTensorShapeHelper( - graph->tf_graph(), TF_Output{op->tf_op, value_index}, status.get(), - &unknown_shape); + data->graph->tf_graph(), TF_Output{data->op->tf_op(), value_index()}, + status.get(), &unknown_shape); if (!status.get()->status.ok()) { return status.get()->status; } @@ -737,17 +761,17 @@ struct PyTensor { } } tensorflow::TF_GraphSetTensorShape_wrapper( - graph->tf_graph(), TF_Output{op->tf_op, value_index}, dims, - unknown_shape, status.get()); + data->graph->tf_graph(), TF_Output{data->op->tf_op(), value_index()}, + dims, unknown_shape, status.get()); return status.get()->status; } int64_t rank() { - tsl::mutex_lock l(graph->graph->mu); + tsl::mutex_lock l(data->graph->tf_graph()->mu); tensorflow::shape_inference::InferenceContext* ic = - graph->graph->refiner.GetContext(&op->tf_op->node); + data->graph->tf_graph()->refiner.GetContext(&data->op->tf_op()->node); - tensorflow::shape_inference::ShapeHandle shape = ic->output(value_index); + tensorflow::shape_inference::ShapeHandle shape = ic->output(value_index()); if (ic->RankKnown(shape)) { return ic->Rank(shape); } @@ -756,11 +780,11 @@ struct PyTensor { py::list consumers() { py::list out; - for (const auto* edge : op->tf_op->node.out_edges()) { - if (edge->src_output() != value_index) { + for (const auto* edge : data->op->tf_op()->node.out_edges()) { + if (edge->src_output() != value_index()) { continue; } - out.append(graph->ops_by_id[edge->dst()->id()]); + out.append(data->graph->data->ops_by_id[edge->dst()->id()]); } return out; } @@ -770,17 +794,17 @@ PyObject* PyOperation::py_type = nullptr; PyObject* PyTensor::py_type = nullptr; PyObject* PyGraph::py_type = nullptr; -void PyOperation::Dismantle() { +void PyOperationData::Dismantle(PyOperation* py_op) { outputs = py::list(); - PyDict_Clear(dict); graph.Destroy(); + PyDict_Clear(py_op->dict); } tsl::Status PyOperation::_add_outputs(py::list dtypes, py::list shapes) { - int orig_outputs = 
outputs.size(); + int orig_outputs = data->outputs.size(); for (int i = 0; i < dtypes.size(); ++i) { py::object tensor = - tensor_fn(AsPyObject(this), orig_outputs + i, dtypes[i]); + data->tensor_fn(AsPyObject(this), orig_outputs + i, dtypes[i]); // The passed in `shapes` may be TensorShapes, convert them to lists if // needed. @@ -799,24 +823,25 @@ tsl::Status PyOperation::_add_outputs(py::list dtypes, py::list shapes) { } TF_RETURN_IF_ERROR( AsPyTfObject(tensor)->set_shape(dims, unknown_shape)); - outputs.append(tensor); + data->outputs.append(tensor); } return tsl::OkStatus(); } void PyOperation::add_control_inputs(py::iterable inputs) { - tsl::mutex_lock l(graph->tf_graph()->mu); + tsl::mutex_lock l(data->graph->tf_graph()->mu); for (py::handle input : inputs) { auto* input_handle = py::cast(input); - graph->tf_graph()->graph.AddControlEdge(&input_handle->tf_op->node, - &tf_op->node); + data->graph->tf_graph()->graph.AddControlEdge(&input_handle->tf_op()->node, + &tf_op()->node); } - graph->record_mutation(tf_op->node, "adding control input"); + data->graph->record_mutation(tf_op()->node, "adding control input"); } -void PyGraph::Dismantle() { +void PyGraphData::Dismantle() { for (auto& op : op_list) { - AsPyTfObject(op.ptr())->Dismantle(); + AsPyTfObjectData(op.ptr())->Dismantle( + AsPyTfObject(op.ptr())); } op_list = py::list(); ops_by_id.clear(); @@ -825,10 +850,10 @@ void PyGraph::Dismantle() { int64_t PyGraph::add_op(py::object obj) { PyOperation* op_handle = AsPyTfObject(obj); - int64_t op_id = op_handle->tf_op->node.id(); - op_list.append(obj); - ops_by_id[op_id] = obj; - ops_by_name[op_handle->name()] = obj; + int64_t op_id = op_handle->tf_op()->node.id(); + data->op_list.append(obj); + data->ops_by_id[op_id] = obj; + data->ops_by_name[op_handle->name()] = obj; return op_id; } @@ -848,7 +873,7 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { m.attr("PyGraph") = c_graph; c_graph.attr("__module__") = module_name; c_graph.attr("Dismantle") = method(c_graph, [](py::handle handle) { - AsPyTfObject(handle)->Dismantle(); + AsPyTfObjectData(handle)->Dismantle(); }); c_graph.attr("_version_def") = property_readonly([](py::handle handle) { return AsPyTfObject(handle)->version_def(); @@ -861,10 +886,10 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { return AsPyTfObject(handle)->_op_def_for_type(type); }); c_graph.attr("_nodes_by_name") = property_readonly([](py::handle handle) { - return AsPyTfObject(handle)->ops_by_name; + return AsPyTfObjectData(handle)->ops_by_name; }); c_graph.attr("_nodes_by_id") = property_readonly([](py::handle handle) { - return AsPyTfObject(handle)->ops_by_id; + return AsPyTfObjectData(handle)->ops_by_id; }); c_graph.attr("_get_operation_by_name") = method(c_graph, [](py::handle handle, std::string name) { @@ -919,18 +944,18 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { return AsPyTfObject(handle)->remove_all_control_inputs(); }); c_op.attr("outputs") = property_readonly([](py::handle handle) { - return AsPyTfObject(handle)->outputs; + return AsPyTfObjectData(handle)->outputs; }); c_op.attr("graph") = property( [](py::handle handle) { - return AsPyTfObject(handle)->graph.borrow(); + return AsPyTfObjectData(handle)->graph.borrow(); }, [](py::handle handle, py::handle graph) { auto op = AsPyTfObject(handle); - op->graph = graph.ptr(); + op->data->graph = graph.ptr(); }); c_op.attr("_c_op") = property_readonly([](py::handle handle) { - return AsPyTfObject(handle)->tf_op; + return AsPyTfObject(handle)->tf_op(); }); c_op.attr("_is_stateful") = property_readonly([](py::handle 
handle) { return AsPyTfObject(handle)->is_stateful(); @@ -983,7 +1008,7 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { m.attr("PyTensor") = c_tensor; c_tensor.attr("__module__") = module_name; c_tensor.attr("device") = property_readonly([](py::handle handle) { - return AsPyTfObject(handle)->op->device(); + return AsPyTfObjectData(handle)->op->device(); }); c_tensor.attr("ndim") = property_readonly([](py::handle handle) { return AsPyTfObject(handle)->rank(); @@ -995,40 +1020,44 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { return AsPyTfObject(handle)->shape(); }); c_tensor.attr("_dtype") = property_readonly([](py::handle handle) { - return AsPyTfObject(handle)->dtype; + return AsPyTfObjectData(handle)->dtype; }); c_tensor.attr("_name") = property( - [](py::handle handle) { return AsPyTfObject(handle)->name; }, + [](py::handle handle) { + return AsPyTfObjectData(handle)->name; + }, [](py::handle handle, py::object name) { - AsPyTfObject(handle)->name = name; + AsPyTfObjectData(handle)->name = name; }); c_tensor.attr("_shape_val") = property( [](py::handle handle) { auto py_tensor = AsPyTfObject(handle); - return py_tensor->shape_val; + return py_tensor->data->shape_val; }, [](py::handle handle, py::object shape) { - AsPyTfObject(handle)->shape_val = shape; + AsPyTfObjectData(handle)->shape_val = shape; }); c_tensor.attr("_id") = property( - [](py::handle handle) { return AsPyTfObject(handle)->uid; }, + [](py::handle handle) { + return AsPyTfObjectData(handle)->uid; + }, [](py::handle handle, py::object uid) { - AsPyTfObject(handle)->uid = uid; + AsPyTfObjectData(handle)->uid = uid; }); c_tensor.attr("graph") = property_readonly([](py::handle handle) -> py::handle { - auto& graph = AsPyTfObject(handle)->graph; + auto& graph = AsPyTfObjectData(handle)->graph; if (graph.ptr() != nullptr) { return graph.borrow(); } return py::none(); }); c_tensor.attr("_as_tf_output") = method(c_tensor, [](py::handle handle) { - return AsPyTfObject(handle)->tf_output; + return AsPyTfObjectData(handle)->tf_output; }); c_tensor.attr("_op") = property_readonly([](py::handle handle) -> py::handle { - auto& op = AsPyTfObject(handle)->op; + auto& op = AsPyTfObjectData(handle)->op; if (op.ptr() != nullptr) { return op.borrow(); } @@ -1036,7 +1065,7 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { }); c_tensor.attr("op") = property_readonly([](py::handle handle) -> py::handle { - auto& op = AsPyTfObject(handle)->op; + auto& op = AsPyTfObjectData(handle)->op; if (op.ptr() != nullptr) { return op.borrow(); } @@ -1048,7 +1077,7 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { return AsPyTfObject(handle)->set_shape(shape, unknown_shape); }); c_tensor.attr("value_index") = property_readonly([](py::handle handle) { - return AsPyTfObject(handle)->value_index; + return AsPyTfObject(handle)->value_index(); }); c_tensor.attr("consumers") = method(c_tensor, [](py::handle handle) { return AsPyTfObject(handle)->consumers(); diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD index 68bab012e8bf28..8765961c533f7c 100644 --- a/tensorflow/python/compat/BUILD +++ b/tensorflow/python/compat/BUILD @@ -13,14 +13,9 @@ py_strict_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:tf2", - "//tensorflow/python/data/experimental/ops:counter", - "//tensorflow/python/data/experimental/ops:interleave_ops", - "//tensorflow/python/data/experimental/ops:random_ops", - "//tensorflow/python/data/experimental/ops:readers", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:readers", 
"//tensorflow/python/eager:monitoring", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:registry", "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:control_flow_v2_toggles", diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 73e0ad94e2e434..fd9132a7448210 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 6) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 12, 14) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/compat/v2_compat.py b/tensorflow/python/compat/v2_compat.py index cef625b1dc355a..5820e477eb2e5f 100644 --- a/tensorflow/python/compat/v2_compat.py +++ b/tensorflow/python/compat/v2_compat.py @@ -15,19 +15,13 @@ """Switching v2 features on and off.""" from tensorflow.python import tf2 -from tensorflow.python.data.experimental.ops import counter -from tensorflow.python.data.experimental.ops import interleave_ops -from tensorflow.python.data.experimental.ops import random_ops -from tensorflow.python.data.experimental.ops import readers as exp_readers -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import readers from tensorflow.python.eager import monitoring from tensorflow.python.framework import ops +from tensorflow.python.framework import registry from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import resource_variables_toggle - from tensorflow.python.util.tf_export import tf_export # Metrics to track the status of v2_behavior @@ -35,6 +29,12 @@ "/tensorflow/version/v2_behavior", "whether v2_behavior is enabled or disabled", "status") +_DATA_V2_CALLBACKS = registry.Registry("data_v2_callbacks") + + +def register_data_v2_callback(data_v2_func): + _DATA_V2_CALLBACKS.register(data_v2_func, data_v2_func.__module__) + @tf_export(v1=["enable_v2_behavior"]) def enable_v2_behavior(): @@ -65,19 +65,9 @@ def enable_v2_behavior(): # Enables TensorArrayV2 and control flow V2. control_flow_v2_toggles.enable_control_flow_v2() # Make sure internal uses of tf.data symbols map to V2 versions. 
- dataset_ops.Dataset = dataset_ops.DatasetV2 - readers.FixedLengthRecordDataset = readers.FixedLengthRecordDatasetV2 - readers.TFRecordDataset = readers.TFRecordDatasetV2 - readers.TextLineDataset = readers.TextLineDatasetV2 - counter.Counter = counter.CounterV2 - interleave_ops.choose_from_datasets = interleave_ops.choose_from_datasets_v2 - interleave_ops.sample_from_datasets = interleave_ops.sample_from_datasets_v2 - random_ops.RandomDataset = random_ops.RandomDatasetV2 - exp_readers.CsvDataset = exp_readers.CsvDatasetV2 - exp_readers.SqlDataset = exp_readers.SqlDatasetV2 - exp_readers.make_batched_features_dataset = ( - exp_readers.make_batched_features_dataset_v2) - exp_readers.make_csv_dataset = exp_readers.make_csv_dataset_v2 + for v2_enabler_name in _DATA_V2_CALLBACKS.list(): + v2_enabler = _DATA_V2_CALLBACKS.lookup(v2_enabler_name) + v2_enabler() @tf_export(v1=["disable_v2_behavior"]) @@ -110,16 +100,6 @@ def disable_v2_behavior(): # Disables TensorArrayV2 and control flow V2. control_flow_v2_toggles.disable_control_flow_v2() # Make sure internal uses of tf.data symbols map to V1 versions. - dataset_ops.Dataset = dataset_ops.DatasetV1 - readers.FixedLengthRecordDataset = readers.FixedLengthRecordDatasetV1 - readers.TFRecordDataset = readers.TFRecordDatasetV1 - readers.TextLineDataset = readers.TextLineDatasetV1 - counter.Counter = counter.CounterV1 - interleave_ops.choose_from_datasets = interleave_ops.choose_from_datasets_v1 - interleave_ops.sample_from_datasets = interleave_ops.sample_from_datasets_v1 - random_ops.RandomDataset = random_ops.RandomDatasetV1 - exp_readers.CsvDataset = exp_readers.CsvDatasetV1 - exp_readers.SqlDataset = exp_readers.SqlDatasetV1 - exp_readers.make_batched_features_dataset = ( - exp_readers.make_batched_features_dataset_v1) - exp_readers.make_csv_dataset = exp_readers.make_csv_dataset_v1 + for v2_disabler_name in _DATA_V2_CALLBACKS.list(): + v2_disabler = _DATA_V2_CALLBACKS.lookup(v2_disabler_name) + v2_disabler() diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD index 9fbdcf56b1a023..0dc49c5a56f3fb 100644 --- a/tensorflow/python/compiler/tensorrt/BUILD +++ b/tensorflow/python/compiler/tensorrt/BUILD @@ -32,12 +32,10 @@ py_strict_library( py_strict_library( name = "trt_convert_py", - srcs = [ - "trt_convert.py", - "utils.py", - ], + srcs = ["trt_convert.py"], srcs_version = "PY3", deps = [ + ":utils", "//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils", "//tensorflow/compiler/tf2tensorrt:trt_ops_loader", "//tensorflow/core:protos_all_py", @@ -69,19 +67,30 @@ py_strict_library( "//tensorflow/python/util:nest", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", - "@pypi_packaging//:pkg", "@six_archive//:six", ], ) +py_strict_library( + name = "utils", + srcs = ["utils.py"], + deps = [ + "//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/framework:dtypes", + "@pypi_packaging//:pkg", + ], +) + py_strict_library( name = "tf_trt_integration_test_base", - srcs = ["//tensorflow/python/compiler/tensorrt/test:tf_trt_integration_test_base_srcs"], srcs_version = "PY3", deps = [ ":trt_convert_py", + ":utils", "//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils", "//tensorflow/core:protos_all_py", + "//tensorflow/python/compiler/tensorrt/test:tf_trt_integration_test_base_srcs", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:config", "//tensorflow/python/framework:graph_io", @@ -121,6 +130,8 @@ cuda_py_strict_test( 
"no_oss", "no_pip", "nomac", + # TODO(b/303453873): Re-enable tests once TensorRT has been updated + "notap", ], xla_enable_strict_auto_jit = False, deps = [ diff --git a/tensorflow/python/compiler/tensorrt/test/BUILD b/tensorflow/python/compiler/tensorrt/test/BUILD index 4bea640efd6015..15499cbdf79c39 100644 --- a/tensorflow/python/compiler/tensorrt/test/BUILD +++ b/tensorflow/python/compiler/tensorrt/test/BUILD @@ -30,6 +30,7 @@ py_strict_library( "//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils", "//tensorflow/core:protos_all_py", "//tensorflow/python/compiler/tensorrt:trt_convert_py", + "//tensorflow/python/compiler/tensorrt:utils", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:config", "//tensorflow/python/framework:graph_io", @@ -93,6 +94,8 @@ base_tags = [ "no_rocm", "no_windows", "nomac", + # TODO(b/303453873): Re-enable tests once TensorRT has been updated + "notap", ] cuda_py_strict_test( @@ -106,7 +109,7 @@ cuda_py_strict_test( xla_enable_strict_auto_jit = False, deps = [ ":tf_trt_integration_test_base_srcs", - "//tensorflow/python/compiler/tensorrt:trt_convert_py", + "//tensorflow/python/compiler/tensorrt:utils", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", @@ -143,7 +146,7 @@ cuda_py_strict_test( xla_enable_strict_auto_jit = False, deps = [ ":tf_trt_integration_test_base_srcs", - "//tensorflow/python/compiler/tensorrt:trt_convert_py", + "//tensorflow/python/compiler/tensorrt:utils", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/ops:array_ops", @@ -243,7 +246,7 @@ cuda_py_strict_test( xla_enable_strict_auto_jit = False, deps = [ ":tf_trt_integration_test_base_srcs", - "//tensorflow/python/compiler/tensorrt:trt_convert_py", + "//tensorflow/python/compiler/tensorrt:utils", "//tensorflow/python/framework:dtypes", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/python/data/benchmarks/BUILD b/tensorflow/python/data/benchmarks/BUILD index be65e17f9784e4..d023a58baf0ef8 100644 --- a/tensorflow/python/data/benchmarks/BUILD +++ b/tensorflow/python/data/benchmarks/BUILD @@ -1,5 +1,8 @@ load("//tensorflow:strict.default.bzl", "py_strict_library") -load("//tensorflow:tensorflow.default.bzl", "tf_py_strict_test") +load( + "//tensorflow/tools/test:performance.bzl", + "tf_py_benchmark_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -22,7 +25,7 @@ py_strict_library( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "meta_benchmark", srcs = ["meta_benchmark.py"], deps = [ @@ -36,7 +39,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "batch_benchmark", srcs = ["batch_benchmark.py"], deps = [ @@ -49,7 +52,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "filter_benchmark", srcs = ["filter_benchmark.py"], deps = [ @@ -59,7 +62,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "from_tensor_slices_benchmark", srcs = ["from_tensor_slices_benchmark.py"], deps = [ @@ -74,7 +77,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "interleave_benchmark", srcs = ["interleave_benchmark.py"], deps = [ @@ -85,7 +88,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "list_files_benchmark", srcs = ["list_files_benchmark.py"], deps = [ @@ -94,7 +97,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( 
+tf_py_benchmark_test( name = "map_benchmark", srcs = ["map_benchmark.py"], deps = [ @@ -109,7 +112,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "prefetch_benchmark", srcs = ["prefetch_benchmark.py"], deps = [ @@ -118,7 +121,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "range_benchmark", srcs = ["range_benchmark.py"], deps = [ diff --git a/tensorflow/python/data/experimental/benchmarks/BUILD b/tensorflow/python/data/experimental/benchmarks/BUILD index a9eef9c7ad6e91..e61a15a2ae88f1 100644 --- a/tensorflow/python/data/experimental/benchmarks/BUILD +++ b/tensorflow/python/data/experimental/benchmarks/BUILD @@ -1,5 +1,8 @@ load("//tensorflow:strict.default.bzl", "py_strict_binary") -load("//tensorflow:tensorflow.default.bzl", "tf_py_strict_test") +load( + "//tensorflow/tools/test:performance.bzl", + "tf_py_benchmark_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -22,7 +25,7 @@ py_strict_binary( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "autotune_benchmark", srcs = ["autotune_benchmark.py"], deps = [ @@ -34,7 +37,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "parameter_value_benchmark", srcs = ["parameter_value_benchmark.py"], deps = [ @@ -47,7 +50,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "csv_dataset_benchmark", srcs = ["csv_dataset_benchmark.py"], tags = ["no_pip"], @@ -61,7 +64,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "map_and_batch_benchmark", srcs = ["map_and_batch_benchmark.py"], deps = [ @@ -77,7 +80,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "map_defun_benchmark", srcs = ["map_defun_benchmark.py"], deps = [ @@ -92,7 +95,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "matching_files_benchmark", size = "small", srcs = ["matching_files_benchmark.py"], @@ -102,7 +105,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "optimize_benchmark", srcs = ["optimize_benchmark.py"], deps = [ @@ -113,7 +116,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "rejection_resample_benchmark", srcs = ["rejection_resample_benchmark.py"], tags = ["no_pip"], @@ -126,7 +129,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "snapshot_dataset_benchmark", srcs = ["snapshot_dataset_benchmark.py"], deps = [ @@ -138,7 +141,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "unbatch_benchmark", srcs = ["unbatch_benchmark.py"], deps = [ diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD index 93f59e64bc6ccb..32c958e4185d83 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD @@ -90,6 +90,8 @@ tf_py_strict_test( "//tensorflow/python/framework:combinations", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:check_ops", "//tensorflow/python/ops:math_ops", "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py 
b/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py index 03f795a9212d84..2a5848fb87befb 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_fusion_test.py @@ -24,6 +24,8 @@ from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -110,8 +112,66 @@ def testMapFusion(self, functions, num_parallel_calls, deterministic): r = function(r) expected_output.append(r) - if num_parallel_calls is None or deterministic in [None, True]: - self.assertDatasetProduces(dataset, expected_output=expected_output) + nondeterministic_ordering = ( + num_parallel_calls is not None and deterministic is False # pylint: disable=g-bool-id-comparison + ) + self.assertDatasetProduces( + dataset, + expected_output=expected_output, + assert_items_equal=nondeterministic_ordering, + ) + + @combinations.generate(test_base.default_test_combinations()) + def testMapFusionLongMapChain(self): + n = 5 + dataset = dataset_ops.Dataset.range(n) + dataset = dataset.apply( + testing.assert_next(["ParallelMap", "MemoryCacheImpl"]) + ) + + k = 50 + for _ in range(k): + dataset = dataset.map( + lambda x: 2 * x, + num_parallel_calls=dataset_ops.AUTOTUNE, + ) + + dataset = dataset.cache() + options = options_lib.Options() + options.experimental_optimization.apply_default_optimizations = False + options.experimental_optimization.map_fusion = True + dataset = dataset.with_options(options) + + self.assertDatasetProduces( + dataset, + expected_output=[x * 2**k for x in range(n)], + assert_items_equal=True, + ) + + @combinations.generate(test_base.default_test_combinations()) + def testControlInputs(self): + def f(x): + with ops.control_dependencies([check_ops.assert_type(x, dtypes.int64)]): + return 2 * x + + n = 5 + dataset = dataset_ops.Dataset.range(n) + dataset = dataset.apply( + testing.assert_next(["ParallelMap", "MemoryCacheImpl"]) + ) + dataset = dataset.map(f, num_parallel_calls=dataset_ops.AUTOTUNE) + dataset = dataset.map(f, num_parallel_calls=dataset_ops.AUTOTUNE) + + dataset = dataset.cache() + options = options_lib.Options() + options.experimental_optimization.apply_default_optimizations = False + options.experimental_optimization.map_fusion = True + dataset = dataset.with_options(options) + self.assertDatasetProduces( + dataset, + expected_output=[x * 4 for x in range(n)], + assert_items_equal=True, + ) @combinations.generate( combinations.times( diff --git a/tensorflow/python/data/experimental/kernel_tests/service/BUILD b/tensorflow/python/data/experimental/kernel_tests/service/BUILD index cfac30fe0dbb47..f43db5fd64b852 100644 --- a/tensorflow/python/data/experimental/kernel_tests/service/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/service/BUILD @@ -243,6 +243,27 @@ tf_py_strict_test( ], ) +tf_py_strict_test( + name = "distributed_save_load_test", + size = "medium", + srcs = ["distributed_save_load_test.py"], + shard_count = 8, + deps = [ + ":test_base", + "//tensorflow/python/data/experimental/ops:data_service_ops", + "//tensorflow/python/data/experimental/ops:distributed_save_op", + "//tensorflow/python/data/kernel_tests:checkpoint_test_base", + "//tensorflow/python/data/kernel_tests:test_base", 
+ "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/framework:combinations", + "//tensorflow/python/framework:errors", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/platform:test", + "@absl_py//absl/testing:parameterized", + ], +) + tf_py_strict_test( name = "distributed_save_ft_test", size = "medium", diff --git a/tensorflow/python/data/experimental/kernel_tests/service/distributed_save_load_test.py b/tensorflow/python/data/experimental/kernel_tests/service/distributed_save_load_test.py new file mode 100644 index 00000000000000..a7e21a5d939321 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/service/distributed_save_load_test.py @@ -0,0 +1,254 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for distributed save/load with the new load algorithm.""" + +import os +import shutil +import tempfile +import threading +import time +from typing import Callable, Optional + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.service import test_base as data_service_test_base +from tensorflow.python.data.experimental.ops import data_service_ops +from tensorflow.python.data.experimental.ops import distributed_save_op +from tensorflow.python.data.kernel_tests import checkpoint_test_base +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import load_op +from tensorflow.python.framework import combinations +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + + +class TestSnapshot: + """Test data for snapshots.""" + + def __init__(self): + temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir()) + self.path = os.path.join( + tempfile.mkdtemp(dir=temp_dir), "distributed_save_load_test") + + def __del__(self): + shutil.rmtree(self.path) + + +class DistributedSaveLoadTest( + data_service_test_base.TestBase, parameterized.TestCase): + """Tests for distributed save/load with the new load algorithm. + + TODO(b/297930782): Add fault tolerance tests. 
+ """ + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + num_workers=[1, 3], + num_elements=[0, 10], + num_repetitions=[1, 3], + compression=[None, "AUTO", "GZIP"]))) + def test_save_load( + self, + num_workers: int, + num_elements: int, + num_repetitions: int, + compression: Optional[str]): + test_snapshot = TestSnapshot() + cluster = data_service_test_base.TestCluster(num_workers=num_workers) + dataset = dataset_ops.Dataset.range(num_elements) + dataset = dataset.repeat(num_repetitions) + self.evaluate( + distributed_save_op.distributed_save( + dataset, test_snapshot.path, cluster.dispatcher_address())) + + # Unlike the old load op, v2 does not need to wait for snapshot to finish. + dataset = load_op._load_distributed_snapshot_v2(test_snapshot.path) + self.assertDatasetProduces( + dataset, + list(range(num_elements)) * num_repetitions, + assert_items_equal=True) + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(num_workers=[1, 3]))) + def test_concurrent_save_load(self, num_workers: int): + test_snapshot = TestSnapshot() + cluster = data_service_test_base.TestCluster(num_workers=num_workers) + + def load_thread_fn(): + dataset = load_op._load_distributed_snapshot_v2(test_snapshot.path) + self.assertDatasetProduces( + dataset, list(range(10)), assert_items_equal=True) + load_thread = threading.Thread(target=load_thread_fn, name="load_thread") + load_thread.start() + + def save_thread_fn(): + time.sleep(5) + dataset = dataset_ops.Dataset.range(10) + self.evaluate( + distributed_save_op.distributed_save( + dataset, test_snapshot.path, cluster.dispatcher_address())) + save_thread = threading.Thread(target=save_thread_fn, name="save_thread") + save_thread.start() + save_thread.join() + load_thread.join() + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(num_workers=[1, 3], num_elements=[0, 10]))) + def test_distributed_load(self, num_workers: int, num_elements: int): + test_snapshot = TestSnapshot() + cluster = data_service_test_base.TestCluster(num_workers=num_workers) + dataset = dataset_ops.Dataset.range(num_elements) + self.evaluate( + distributed_save_op.distributed_save( + dataset, test_snapshot.path, cluster.dispatcher_address())) + + dataset = load_op._load_distributed_snapshot_v2(test_snapshot.path) + # TODO(b/297930782): Support dynamic sharding. 
+ dataset = dataset.apply( + data_service_ops.distribute( + data_service_ops.ShardingPolicy.OFF, cluster.dispatcher_address())) + self.assertDatasetProduces( + dataset, + list(range(num_elements)) * num_workers, + assert_items_equal=True) + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(num_workers=[1, 3]))) + def test_save_before_sample(self, num_workers: int): + num_elements = 10 + num_datasets = 3 + test_snapshot = TestSnapshot() + cluster = data_service_test_base.TestCluster(num_workers=num_workers) + datasets = [ + dataset_ops.Dataset.range(num_elements) for i in range(num_datasets)] + for i, dataset in enumerate(datasets): + self.evaluate( + distributed_save_op.distributed_save( + dataset, + os.path.join(test_snapshot.path, f"dataset_{i}"), + cluster.dispatcher_address())) + + loaded_datasets = [] + for i in range(len(datasets)): + loaded_datasets.append( + load_op._load_distributed_snapshot_v2( + os.path.join(test_snapshot.path, f"dataset_{i}"))) + dataset = dataset_ops.Dataset.sample_from_datasets( + loaded_datasets, + weights=[1.0] * num_datasets, + stop_on_empty_dataset=False) + self.assertDatasetProduces( + dataset, + list(range(num_elements)) * num_datasets, + assert_items_equal=True) + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(num_workers=[1, 3], num_repetitions=[1, 3]))) + def test_save_after_sample(self, num_workers: int, num_repetitions: int): + num_elements = 10 + num_datasets = 3 + test_snapshot = TestSnapshot() + cluster = data_service_test_base.TestCluster(num_workers=num_workers) + datasets = [ + dataset_ops.Dataset.range(num_elements) for i in range(num_datasets)] + if num_repetitions > 1: + datasets = [dataset.repeat(num_repetitions) for dataset in datasets] + dataset = dataset_ops.Dataset.sample_from_datasets( + datasets, weights=[1.0] * num_datasets, stop_on_empty_dataset=False) + self.evaluate( + distributed_save_op.distributed_save( + dataset, test_snapshot.path, cluster.dispatcher_address())) + + dataset = load_op._load_distributed_snapshot_v2(test_snapshot.path) + self.assertDatasetProduces( + dataset, + list(range(num_elements)) * num_datasets * num_repetitions, + assert_items_equal=True) + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(num_workers=[1, 3]))) + def test_enumerate(self, num_workers: int): + test_snapshot = TestSnapshot() + cluster = data_service_test_base.TestCluster(num_workers) + dataset = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"]) + dataset = dataset.repeat(3) + dataset = dataset.enumerate() + self.evaluate( + distributed_save_op.distributed_save( + dataset, test_snapshot.path, cluster.dispatcher_address())) + + dataset = load_op._load_distributed_snapshot_v2(test_snapshot.path) + indexes, elements = map(list, zip(*self.getDatasetOutput(dataset))) + if num_workers == 1: + self.assertCountEqual(indexes, list(range(9))) + self.assertCountEqual(elements, [b"a", b"b", b"c"] * 3) + + @combinations.generate(test_base.default_test_combinations()) + def test_worker_failure(self): + test_snapshot = TestSnapshot() + cluster = data_service_test_base.TestCluster(num_workers=1) + components = np.array([1.0, 2.0, 3.0, np.nan, 5.0]).astype(np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices(components) + dataset = dataset.map(lambda x: array_ops.check_numerics(x, "message")) + self.evaluate( + distributed_save_op.distributed_save( + 
dataset, test_snapshot.path, cluster.dispatcher_address())) + + with self.assertRaises(errors.InvalidArgumentError): + dataset = load_op._load_distributed_snapshot_v2(test_snapshot.path) + self.getDatasetOutput(dataset) + + +class SaveLoadCheckpointTest( + data_service_test_base.TestBase, + checkpoint_test_base.CheckpointTestBase, + parameterized.TestCase): + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + checkpoint_test_base.default_test_combinations())) + def test_save_load_checkpoint(self, verify_fn: Callable[..., None]): + test_snapshot = TestSnapshot() + cluster = data_service_test_base.TestCluster(num_workers=1) + dataset = dataset_ops.Dataset.range(10) + self.evaluate( + distributed_save_op.distributed_save( + dataset, test_snapshot.path, cluster.dispatcher_address())) + + def _build_ds() -> dataset_ops.Dataset: + return load_op._load_distributed_snapshot_v2(test_snapshot.path) + + verify_fn(self, _build_ds, num_outputs=10) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD index c3e153a4e775e6..cc604a2afebe22 100644 --- a/tensorflow/python/data/experimental/ops/BUILD +++ b/tensorflow/python/data/experimental/ops/BUILD @@ -56,6 +56,7 @@ py_strict_library( srcs_version = "PY3", deps = [ "//tensorflow/python:tf2", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/framework:dtypes", "//tensorflow/python/util:deprecation", @@ -194,6 +195,7 @@ py_strict_library( srcs_version = "PY3", deps = [ "//tensorflow/python:tf2", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", "//tensorflow/python/util:deprecation", @@ -349,6 +351,7 @@ py_strict_library( srcs_version = "PY3", deps = [ "//tensorflow/python:tf2", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", @@ -365,6 +368,7 @@ py_strict_library( ":error_ops", ":parsing_ops", "//tensorflow/python:tf2", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:options", "//tensorflow/python/data/ops:readers", diff --git a/tensorflow/python/data/experimental/ops/counter.py b/tensorflow/python/data/experimental/ops/counter.py index 2a8eaaae76afaa..e9dc2b49a0ea0d 100644 --- a/tensorflow/python/data/experimental/ops/counter.py +++ b/tensorflow/python/data/experimental/ops/counter.py @@ -14,6 +14,7 @@ # ============================================================================== """The Counter Dataset.""" from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.util import deprecation @@ -70,3 +71,14 @@ def CounterV1(start=0, step=1, dtype=dtypes.int64): Counter = CounterV2 else: Counter = CounterV1 + + +def _tf2_callback(): # pylint: disable=invalid-name + global Counter + if tf2.enabled(): + Counter = CounterV2 + else: + Counter = CounterV1 + + +v2_compat.register_data_v2_callback(_tf2_callback) diff --git a/tensorflow/python/data/experimental/ops/data_service_ops.py b/tensorflow/python/data/experimental/ops/data_service_ops.py index baf0862379cd4f..961bca376a9662 100644 --- a/tensorflow/python/data/experimental/ops/data_service_ops.py +++ 
b/tensorflow/python/data/experimental/ops/data_service_ops.py @@ -16,6 +16,7 @@ import enum import functools +from typing import Callable from tensorflow.core.protobuf import data_service_pb2 from tensorflow.python import tf2 @@ -435,17 +436,19 @@ def _parse_service(service) -> tuple[str, str]: return (protocol, address) -def _distribute(processing_mode, - service, - job_name=None, - consumer_index=None, - num_consumers=None, - max_outstanding_requests=None, - task_refresh_interval_hint_ms=None, - data_transfer_protocol=None, - compression="AUTO", - cross_trainer_cache=None, - target_workers="AUTO") -> dataset_ops.Dataset: +def _distribute( + processing_mode, + service, + job_name=None, + consumer_index=None, + num_consumers=None, + max_outstanding_requests=None, + task_refresh_interval_hint_ms=None, + data_transfer_protocol=None, + compression="AUTO", + cross_trainer_cache=None, + target_workers="AUTO", +) -> Callable[[dataset_ops.Dataset], dataset_ops.Dataset]: """A transformation that moves dataset processing to the tf.data service. This transformation is similar to `distribute`, but supports additional @@ -529,16 +532,18 @@ def _apply_fn(dataset) -> dataset_ops.Dataset: # pylint: disable=missing-docstr @tf_export("data.experimental.service.distribute") -def distribute(processing_mode, - service, - job_name=None, - consumer_index=None, - num_consumers=None, - max_outstanding_requests=None, - data_transfer_protocol=None, - compression="AUTO", - cross_trainer_cache=None, - target_workers="AUTO") -> dataset_ops.Dataset: +def distribute( + processing_mode, + service, + job_name=None, + consumer_index=None, + num_consumers=None, + max_outstanding_requests=None, + data_transfer_protocol=None, + compression="AUTO", + cross_trainer_cache=None, + target_workers="AUTO", +) -> Callable[[dataset_ops.Dataset], dataset_ops.Dataset]: """A transformation that moves dataset processing to the tf.data service.
When you iterate over a dataset containing the `distribute` transformation, diff --git a/tensorflow/python/data/experimental/ops/interleave_ops.py b/tensorflow/python/data/experimental/ops/interleave_ops.py index 4cf61f9d5c7f9b..7f1d97d6a0e90e 100644 --- a/tensorflow/python/data/experimental/ops/interleave_ops.py +++ b/tensorflow/python/data/experimental/ops/interleave_ops.py @@ -14,6 +14,7 @@ # ============================================================================== """Non-deterministic dataset transformations.""" from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.util import deprecation @@ -245,3 +246,16 @@ def choose_from_datasets_v1(datasets, else: choose_from_datasets = choose_from_datasets_v1 sample_from_datasets = sample_from_datasets_v1 + + +def _tf2_callback(): + global choose_from_datasets, sample_from_datasets + if tf2.enabled(): + choose_from_datasets = choose_from_datasets_v2 + sample_from_datasets = sample_from_datasets_v2 + else: + choose_from_datasets = choose_from_datasets_v1 + sample_from_datasets = sample_from_datasets_v1 + + +v2_compat.register_data_v2_callback(_tf2_callback) diff --git a/tensorflow/python/data/experimental/ops/random_ops.py b/tensorflow/python/data/experimental/ops/random_ops.py index 8e951ea962c3d9..a88f14a8063b42 100644 --- a/tensorflow/python/data/experimental/ops/random_ops.py +++ b/tensorflow/python/data/experimental/ops/random_ops.py @@ -16,6 +16,7 @@ import functools from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import random_op from tensorflow.python.util import deprecation @@ -44,3 +45,14 @@ def __init__(self, seed=None): RandomDataset = RandomDatasetV2 else: RandomDataset = RandomDatasetV1 + + +def _tf2_callback(): + global RandomDataset + if tf2.enabled(): + RandomDataset = RandomDatasetV2 + else: + RandomDataset = RandomDatasetV1 + + +v2_compat.register_data_v2_callback(_tf2_callback) diff --git a/tensorflow/python/data/experimental/ops/readers.py b/tensorflow/python/data/experimental/ops/readers.py index 1ae47f4c9c8e70..75a4a9c39ffa50 100644 --- a/tensorflow/python/data/experimental/ops/readers.py +++ b/tensorflow/python/data/experimental/ops/readers.py @@ -21,6 +21,7 @@ import numpy as np from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat from tensorflow.python.data.experimental.ops import error_ops from tensorflow.python.data.experimental.ops import parsing_ops from tensorflow.python.data.ops import dataset_ops @@ -1220,3 +1221,20 @@ def __init__(self, driver_name, data_source_name, query, output_types): SqlDataset = SqlDatasetV1 make_batched_features_dataset = make_batched_features_dataset_v1 make_csv_dataset = make_csv_dataset_v1 + + +def _tf2_callback(): + global CsvDataset, SqlDataset, make_batched_features_dataset, make_csv_dataset + if tf2.enabled(): + CsvDataset = CsvDatasetV2 + SqlDataset = SqlDatasetV2 + make_batched_features_dataset = make_batched_features_dataset_v2 + make_csv_dataset = make_csv_dataset_v2 + else: + CsvDataset = CsvDatasetV1 + SqlDataset = SqlDatasetV1 + make_batched_features_dataset = make_batched_features_dataset_v1 + make_csv_dataset = make_csv_dataset_v1 + + +v2_compat.register_data_v2_callback(_tf2_callback) diff --git a/tensorflow/python/data/kernel_tests/BUILD 
b/tensorflow/python/data/kernel_tests/BUILD index dc07984864b6b4..4182a927caf562 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -33,7 +33,7 @@ tf_py_strict_test( name = "batch_test", size = "medium", srcs = ["batch_test.py"], - shard_count = 4, + shard_count = 8, deps = [ ":checkpoint_test_base", ":test_base", @@ -509,6 +509,7 @@ tf_py_strict_test( deps = [ ":checkpoint_test_base", ":test_base", + "//tensorflow/core:protos_all_py", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/framework:combinations", "//tensorflow/python/framework:constant_op", diff --git a/tensorflow/python/data/kernel_tests/flat_map_test.py b/tensorflow/python/data/kernel_tests/flat_map_test.py index 29c9cf72aea840..4a3becfd753faa 100644 --- a/tensorflow/python/data/kernel_tests/flat_map_test.py +++ b/tensorflow/python/data/kernel_tests/flat_map_test.py @@ -352,6 +352,47 @@ def _build_ds(): verify_fn(self, _build_ds, num_outputs=20) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + checkpoint_test_base.default_test_combinations(), + combinations.combine(symbolic_checkpoint=[True], + num_skips=[3, 4]), + ) + ) + def testWithSkip(self, verify_fn, symbolic_checkpoint, num_skips): + """Test `.flat_map().skip()` checkpointing behavior. + + `SkipInternal` and `GetNextInternal` are separate functions + but with slightly different implementations. + Therefore, we should test this op's behavior when used with `.skip()`. + + Args: + verify_fn: Verify the correctness of this dataset's checkpointing. + symbolic_checkpoint: Whether symbolic checkpointing is turned on. + num_skips: `.skip(num_skips)` + """ + + def build_dataset(): + def my_map(x): + if x == 0: + return dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3]) + elif x == 1: + return dataset_ops.Dataset.from_tensor_slices([4, 5, 6, 7]) + else: + return dataset_ops.Dataset.from_tensor_slices([8, 9, 10, 11]) + + indices = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) + dataset = indices.flat_map(my_map) + # Skip some elements + dataset = dataset.skip(num_skips) + + options = options_lib.Options() + options.experimental_symbolic_checkpoint = symbolic_checkpoint + return dataset.with_options(options) + + verify_fn(self, build_dataset, num_outputs=3 * 4 - num_skips) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/group_by_window_test.py b/tensorflow/python/data/kernel_tests/group_by_window_test.py index 36ba1659bbf981..461967f528f22c 100644 --- a/tensorflow/python/data/kernel_tests/group_by_window_test.py +++ b/tensorflow/python/data/kernel_tests/group_by_window_test.py @@ -16,6 +16,7 @@ from absl.testing import parameterized import numpy as np +from tensorflow.core.lib.core import error_codes_pb2 from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops @@ -277,26 +278,6 @@ def testEmpty(self): "Window size must be greater than zero, but got 0."): print(self.evaluate(get_next())) - @combinations.generate(test_base.default_test_combinations()) - def testReduceFuncError(self): - components = np.random.randint(100, size=(200,)).astype(np.int64) - - def reduce_func(_, xs): - # Introduce an incorrect padded shape that cannot (currently) be - # detected at graph construction time.
- return xs.padded_batch( - 4, - padded_shapes=(tensor_shape.TensorShape([]), - constant_op.constant([5], dtype=dtypes.int64) * -1)) - - dataset = dataset_ops.Dataset.from_tensor_slices(components) - dataset = dataset.map(lambda x: (x, ops.convert_to_tensor([x * x]))) - dataset = dataset.group_by_window( - key_func=lambda x, _: x % 2, reduce_func=reduce_func, window_size=32) - get_next = self.getNext(dataset) - with self.assertRaises(errors.InvalidArgumentError): - self.evaluate(get_next()) - @combinations.generate(test_base.default_test_combinations()) def testConsumeWindowDatasetMoreThanOnce(self): components = np.random.randint(50, size=(200,)).astype(np.int64) @@ -399,5 +380,73 @@ def test(self): verify_exhausted=False) +class GroupByWindowErrorMessageTest( + test_base.DatasetTestBase, parameterized.TestCase +): + + @combinations.generate(test_base.default_test_combinations()) + def testReduceFuncError(self): + components = np.random.randint(100, size=(200,)).astype(np.int64) + + def my_reduce_func(_, window_dataset): + # Introduce an incorrect padded shape that cannot (currently) be + # detected at graph construction time. + return window_dataset.padded_batch( + 4, + padded_shapes=( + tensor_shape.TensorShape([]), + constant_op.constant([5], dtype=dtypes.int64) * -1, + ), + ) + + dataset = dataset_ops.Dataset.from_tensor_slices(components) + dataset = dataset.map(lambda x: (x, ops.convert_to_tensor([x * x]))) + dataset = dataset.group_by_window( + key_func=lambda x, _: x % 2, reduce_func=my_reduce_func, window_size=32 + ) + get_next = self.getNext(dataset) + with self.assertRaises(errors.InternalError) as error: + self.evaluate(get_next()) + + msg = str(error.exception) + self.assertIn(error_codes_pb2.Code.Name(errors.INVALID_ARGUMENT), msg) + self.assertIn( + my_reduce_func.__name__, + msg, + "{} should show up in the error message".format( + my_reduce_func.__name__ + ), + ) + + @combinations.generate(test_base.default_test_combinations()) + def testPropagateUserDefinedFunctionErrorMessage(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0]) + + def a_cool_user_defined_reduce_func(unused_key, window_dataset): + it = iter(window_dataset) + l = [next(it) for _ in range(2)] # This causes OutOfRange error + return dataset_ops.Dataset.from_tensor_slices(l) + + dataset = dataset.group_by_window( + key_func=lambda x: 0, + window_size=2, + reduce_func=a_cool_user_defined_reduce_func, + ) + + get_next = self.getNext(dataset) + with self.assertRaisesRegex( + errors.InternalError, + ".*{}.*".format(a_cool_user_defined_reduce_func.__name__), + msg=( + "The name of user-defined-function should show up in the error" + " message" + ), + ): + # Loop over the dataset + with self.assertRaises(errors.OutOfRangeError): + while True: + self.evaluate(get_next()) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/placement_test.py b/tensorflow/python/data/kernel_tests/placement_test.py index 6c9efc53f2486a..35f929737c6ff6 100644 --- a/tensorflow/python/data/kernel_tests/placement_test.py +++ b/tensorflow/python/data/kernel_tests/placement_test.py @@ -198,7 +198,7 @@ def create_iter(): create_iter() @combinations.generate(test_base.graph_only_combinations()) - @test_util.run_gpu_only() + @test_util.run_gpu_only def testIteratorOnDeviceGraphModeOneShotIterator(self): self.skipTest("TODO(b/169429285): tf.data.Dataset.make_one_shot_iterator " "does not support GPU placement.") @@ -230,7 +230,7 @@ def testIteratorOnDeviceGraphModeOneShotIterator(self): 
self.assertIn(b"GPU:0", self.evaluate(has_value_device)) @combinations.generate(test_base.graph_only_combinations()) - @test_util.run_gpu_only() + @test_util.run_gpu_only def testIteratorOnDeviceGraphModeInitializableIterator(self): dataset = dataset_ops.Dataset.range(10) dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0")) @@ -259,7 +259,7 @@ def testIteratorOnDeviceGraphModeInitializableIterator(self): self.assertIn(b"GPU:0", self.evaluate(has_value_device)) @combinations.generate(test_base.eager_only_combinations()) - @test_util.run_gpu_only() + @test_util.run_gpu_only def testIterDatasetEagerModeWithExplicitDevice(self): @def_function.function @@ -274,7 +274,7 @@ def comp(): self.assertEqual(result.numpy(), 45) @combinations.generate(test_base.eager_only_combinations()) - @test_util.run_gpu_only() + @test_util.run_gpu_only def testFunctionInliningColocation(self): @def_function.function diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index b7706ef00699c5..b5d55728b53b47 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -98,6 +98,7 @@ py_strict_library( "//tensorflow/python/autograph/operators:py_builtins", "//tensorflow/python/checkpoint", "//tensorflow/python/checkpoint:checkpoint_management", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/experimental/ops:take_while_ops", "//tensorflow/python/data/experimental/service:_pywrap_snapshot_utils", "//tensorflow/python/data/util:convert", @@ -114,6 +115,7 @@ py_strict_library( "//tensorflow/python/framework:composite_tensor", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", "//tensorflow/python/framework:function", "//tensorflow/python/framework:none_tensor", "//tensorflow/python/framework:ops", @@ -192,7 +194,6 @@ py_strict_library( "//tensorflow/python/saved_model:nested_structure_coder", "//tensorflow/python/trackable:base", "//tensorflow/python/training:saver", - "//tensorflow/python/util:_pywrap_utils", "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:nest", @@ -268,6 +269,7 @@ py_strict_library( ":dataset_ops", ":structured_function", "//tensorflow/python:tf2", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/util:convert", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 08ea8693d1cbbc..358b316ea0bd3b 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -28,6 +28,7 @@ from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import struct_pb2 from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_autograph from tensorflow.python.data.ops import debug_mode from tensorflow.python.data.ops import iterator_ops @@ -4212,6 +4213,17 @@ def with_options(self, options, name=None) -> "DatasetV1Adapter": Dataset = DatasetV1 +def _tf2_callback(): + global Dataset + if tf2.enabled(): + Dataset = DatasetV2 + else: + Dataset = DatasetV1 + + +v2_compat.register_data_v2_callback(_tf2_callback) + + class DatasetV1Adapter(DatasetV1): """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API.""" diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 
8c09060ab85976..6db3abca84c880 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -39,7 +39,6 @@ from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.trackable import base as trackable from tensorflow.python.training.saver import BaseSaverBuilder -from tensorflow.python.util import _pywrap_utils from tensorflow.python.util import deprecation from tensorflow.python.util.compat import collections_abc from tensorflow.python.util.tf_export import tf_export @@ -1013,5 +1012,4 @@ def get_next_as_optional(iterator): return iterator.get_next_as_optional() -_pywrap_utils.RegisterType("OwnedIterator", OwnedIterator) iterator_autograph.register_overrides() diff --git a/tensorflow/python/data/ops/load_op.py b/tensorflow/python/data/ops/load_op.py index bb25e08feb7060..bec48f81349ccd 100644 --- a/tensorflow/python/data/ops/load_op.py +++ b/tensorflow/python/data/ops/load_op.py @@ -15,6 +15,8 @@ """Implementation of LoadDataset in Python.""" import multiprocessing import os +import time +from typing import Optional from google.protobuf import message from google.protobuf import text_format @@ -22,6 +24,9 @@ from tensorflow.python.data.experimental.service import _pywrap_snapshot_utils from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import structured_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops from tensorflow.python.platform import gfile # TODO(b/238903802): Use TypeSpec serialization methods directly. @@ -31,22 +36,6 @@ def _load(path, element_spec, compression, reader_func): """Loads dataset from tf.data snapshot.""" - def _get_distributed_snapshot_metadata(): - """Reads the distributed snapshot metadata. - - Returns: - DistributedSnapshotMetadata if the snapshot is a distributed snapshot. - Returns None if it is a non-distributed snapshot. - """ - try: - with gfile.GFile( - _pywrap_snapshot_utils.TF_DATA_SnapshotMetadataFilePath(path), "r" - ) as f: - return text_format.ParseLines( - f, snapshot_pb2.DistributedSnapshotMetadata()) - except (text_format.ParseError, message.DecodeError, UnicodeDecodeError): - return None - if reader_func is None: reader_func = lambda datasets: datasets.interleave( # pylint:disable=g-long-lambda lambda x: x, @@ -59,7 +48,7 @@ def _get_distributed_snapshot_metadata(): encoded_spec = f.read() element_spec = _parse_element_spec(encoded_spec) - distributed_snapshot_metadata = _get_distributed_snapshot_metadata() + distributed_snapshot_metadata = _load_distributed_snapshot_metadata(path) if distributed_snapshot_metadata: _validate_snapshot( path, distributed_snapshot_metadata, element_spec, compression) @@ -68,6 +57,32 @@ def _get_distributed_snapshot_metadata(): return _LoadDataset(path, element_spec, compression, reader_func) +def _load_distributed_snapshot_metadata( + path: str, +) -> Optional[snapshot_pb2.DistributedSnapshotMetadata]: + """Reads the distributed snapshot metadata. + + Args: + path: Base path of the snapshot. + + Returns: + DistributedSnapshotMetadata if the snapshot is a distributed snapshot. + Returns None if it is a non-distributed snapshot. 
+ """ + try: + with gfile.GFile( + _pywrap_snapshot_utils.TF_DATA_SnapshotMetadataFilePath(path), "r" + ) as f: + return text_format.ParseLines( + f, snapshot_pb2.DistributedSnapshotMetadata()) + except ( + errors.NotFoundError, + text_format.ParseError, + message.DecodeError, + UnicodeDecodeError): + return None + + def _load_distributed_snapshot(path, metadata, reader_func): """Loads a distributed snapshot.""" @@ -83,6 +98,46 @@ def _load_distributed_snapshot(path, metadata, reader_func): return reader_func(dataset) +def _load_distributed_snapshot_v2( + path: str, reader_func=None +) -> dataset_ops.Dataset: + """Load a distributed snapshot using the updated loading algorithm. + + The new version allows the load job to read the snapshot while it is being + written. + + TODO(b/297930782): Merge this into `_load` when it's ready. Currently, this is + for testing only. + + Args: + path: Base path of the snapshot. + reader_func: Optional. A function to control how to read data from shards. + If present, the function will be traced and executed as graph computation. + + Returns: + The loaded dataset. + """ + + if not reader_func: + reader_func = lambda datasets: datasets.interleave( # pylint:disable=g-long-lambda + lambda x: x, + cycle_length=multiprocessing.cpu_count(), + num_parallel_calls=dataset_ops.AUTOTUNE) + + metadata = _load_distributed_snapshot_metadata(path) + while not metadata: + time.sleep(2) + metadata = _load_distributed_snapshot_metadata(path) + + dataset = _ListSnapshotChunksDataset(path) + dataset = dataset.map( + lambda chunk_file: _SnapshotChunkDataset( # pylint:disable=g-long-lambda + chunk_file, + element_spec=_parse_element_spec(metadata.element_spec), + compression=metadata.compression)) + return reader_func(dataset) + + class _LoadDataset(dataset_ops.DatasetSource): """A dataset that loads previously saved dataset.""" @@ -127,6 +182,25 @@ def element_spec(self): return self._element_spec +class _ListSnapshotChunksDataset(dataset_ops.DatasetSource): + """A dataset for listing snapshot chunk files. + + It supports listing partially written snapshots. When a snapshot is being + written, it returns the currently available chunk files. + """ + + def __init__(self, snapshot_path: str): + self._snapshot_path = snapshot_path + variant_tensor = ged_ops.list_snapshot_chunks_dataset( + snapshot_path, **self._flat_structure + ) + super().__init__(variant_tensor) + + @property + def element_spec(self) -> tensor_spec.TensorSpec: + return tensor_spec.TensorSpec([], dtypes.string) + + def _validate_snapshot(path, metadata, element_spec, compression): """Validates a tf.data distributed snapshot. 
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py index 347b7a5c272973..566abb7b66eceb 100644 --- a/tensorflow/python/data/ops/readers.py +++ b/tensorflow/python/data/ops/readers.py @@ -16,6 +16,7 @@ import os from tensorflow.python import tf2 +from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import from_tensor_slices_op from tensorflow.python.data.ops import structured_function @@ -705,3 +706,18 @@ def _filenames(self, value): FixedLengthRecordDataset = FixedLengthRecordDatasetV1 TFRecordDataset = TFRecordDatasetV1 TextLineDataset = TextLineDatasetV1 + + +def _tf2_callback(): + global FixedLengthRecordDataset, TFRecordDataset, TextLineDataset + if tf2.enabled(): + FixedLengthRecordDataset = FixedLengthRecordDatasetV2 + TFRecordDataset = TFRecordDatasetV2 + TextLineDataset = TextLineDatasetV2 + else: + FixedLengthRecordDataset = FixedLengthRecordDatasetV1 + TFRecordDataset = TFRecordDatasetV1 + TextLineDataset = TextLineDatasetV1 + + +v2_compat.register_data_v2_callback(_tf2_callback) diff --git a/tensorflow/python/debug/lib/debug_events_reader.py b/tensorflow/python/debug/lib/debug_events_reader.py index 706823b799b14a..2b38a4ca4d34ff 100644 --- a/tensorflow/python/debug/lib/debug_events_reader.py +++ b/tensorflow/python/debug/lib/debug_events_reader.py @@ -109,8 +109,15 @@ def _load_metadata_files(self): wall_times.append(debug_event.wall_time) run_ids.append(debug_event.debug_metadata.tfdbg_run_id) tensorflow_versions.append( - debug_event.debug_metadata.tensorflow_version) + debug_event.debug_metadata.tensorflow_version + ) file_versions.append(debug_event.debug_metadata.file_version) + except Exception as e: + raise errors.DataLossError( + None, + None, + "Error reading tfdbg metadata from paths %s" % metadata_paths, + ) from e finally: reader.close() self._starting_wall_time = wall_times[0] diff --git a/tensorflow/python/distribute/integration_test/saved_model_test.py b/tensorflow/python/distribute/integration_test/saved_model_test.py index 0b7677f9d8cf13..aa0215387e2f6b 100644 --- a/tensorflow/python/distribute/integration_test/saved_model_test.py +++ b/tensorflow/python/distribute/integration_test/saved_model_test.py @@ -40,7 +40,6 @@ from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import test -from tensorflow.python.framework import errors_impl from tensorflow.python.ops import lookup_ops _sixteen_worker_pool = strategy_combinations._deferred_pool_runner( @@ -684,16 +683,15 @@ def test_sharded_variable(self): self.assertAllEqual(self.load_and_run_v1(model_dir, {"x": 1}), [6, 6, 6, 6]) - def test_load_with_partitioner_raises_error(self): + def test_load_with_partitioner_works(self): model = self.Model() model_dir = self.get_temp_dir() tf.saved_model.save(model, model_dir) strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( self.cluster_resolver, tf1.fixed_size_partitioner(2)) - with self.assertRaises(errors_impl.InvalidArgumentError): - with strategy.scope(): - tf.saved_model.load(model_dir) + with strategy.scope(): + tf.saved_model.load(model_dir) if __name__ == "__main__": diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py index 69b22392903a03..a07df8e337cb0c 100644 --- a/tensorflow/python/distribute/multi_process_runner.py +++ b/tensorflow/python/distribute/multi_process_runner.py @@ -929,10 +929,13 @@ 
def shutdown(self): if self._runner is not None: try: self._runner.join() + except unittest.SkipTest: + raise except Exception as e: # pylint: disable=broad-except - logging.error( + logging.exception( 'Ignoring exception when shutting down MultiProcessPoolRunner: %s', - e) + e, + ) self._runner = None def _start(self): diff --git a/tensorflow/python/distribute/sharded_variable.py b/tensorflow/python/distribute/sharded_variable.py index 12c9ed9aa3ed10..4f4e0a5cbf3eaa 100644 --- a/tensorflow/python/distribute/sharded_variable.py +++ b/tensorflow/python/distribute/sharded_variable.py @@ -438,7 +438,7 @@ def __getitem__(self, slice_spec): ) for i in range(len(self._variables)): if i == len(self._variables) - 1 or ( - s > self._var_offsets[i][0] and s < self._var_offsets[i + 1][0] + s >= self._var_offsets[i][0] and s < self._var_offsets[i + 1][0] ): return self._variables[i][ (s - self._var_offsets[i][0],) + slice_spec[1:] diff --git a/tensorflow/python/distribute/sharded_variable_test.py b/tensorflow/python/distribute/sharded_variable_test.py index 4c83bc49b328db..797b9d066b45dc 100644 --- a/tensorflow/python/distribute/sharded_variable_test.py +++ b/tensorflow/python/distribute/sharded_variable_test.py @@ -569,14 +569,20 @@ def safe_sparse_lookup(): self.assertAllClose(safe_sparse_lookup(), [[1., 2.], [0., 0.], [3., 4.]]) def test_slicing(self): + data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], + [15, 16]] v = [ - variables_lib.Variable([[1, 2], [3, 4], [5, 6]]), - variables_lib.Variable([[7, 8], [9, 10], [11, 12]]), - variables_lib.Variable([[13, 14], [15, 16]]) + variables_lib.Variable(data[:3]), + variables_lib.Variable(data[3:6]), + variables_lib.Variable(data[6:]) ] sv = sharded_variable.ShardedVariable(v) empty = v[0][0:0] + # Test cases: all individual indices + for ix in range(len(data)): + self.assertAllEqual(sv[ix].numpy(), data[ix]) + # Test cases: positive step self.assertAllEqual(sv[:], array_ops.concat(v, axis=0)) self.assertAllEqual(sv[:2], [[1, 2], [3, 4]]) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 3c8f7c80a0d593..2923868b1280de 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -9,6 +9,7 @@ load( load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_strict_test") load( "//tensorflow/tools/test:performance.bzl", + "cuda_py_benchmark_test", "tf_py_logged_benchmark", ) @@ -113,6 +114,8 @@ cuda_py_strict_test( deps = [ ":pywrap_tensor_test_util", ":test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:test_lib", "//third_party/py/numpy", ], ) @@ -805,7 +808,7 @@ py_strict_library( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "benchmarks_test", srcs = ["benchmarks_test.py"], python_version = "PY3", @@ -896,7 +899,7 @@ tf_xla_py_strict_test( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "remote_benchmarks_test", srcs = ["remote_benchmarks_test.py"], python_version = "PY3", diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index c4dc6d228c9bf3..a81fb37b013616 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -447,7 +447,7 @@ def testTapeNoOpGradient2By2(self): self.assertAllEqual(dy_dy.numpy(), constant_op.constant(1.0, shape=[2, 2]).numpy()) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testTapeNoOpGradientMultiTarget2By2(self): a_2_by_2 = constant_op.constant(2.0, 
shape=[2, 2]) with backprop.GradientTape(persistent=True) as tape: @@ -1648,7 +1648,7 @@ def grad_fn(x): self.assertIn('gradient_tape/my_scope/', op.name) self.assertEqual(num_sin_ops_found, 2) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testRecomputeGradWithDifferentShape(self): if sys.version_info.major == 3 and sys.version_info.minor in (11, 12): # TODO(b/264947738) @@ -1681,7 +1681,7 @@ def outer_dict(x): self.assertAllEqual(y[1], 2.0) @parameterized.parameters([(True), (False)]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testRecomputeGradWithNestedFunctionAndWhileLoop(self, reduce_retracing): if sys.version_info.major == 3 and sys.version_info.minor in (11, 12): # TODO(b/264947738) diff --git a/tensorflow/python/eager/benchmarks/BUILD b/tensorflow/python/eager/benchmarks/BUILD index 58878b11343e05..d50d428ed712f8 100644 --- a/tensorflow/python/eager/benchmarks/BUILD +++ b/tensorflow/python/eager/benchmarks/BUILD @@ -1,4 +1,7 @@ -load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test") +load( + "//tensorflow/tools/test:performance.bzl", + "cuda_py_benchmark_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -6,7 +9,7 @@ package( licenses = ["notice"], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "kpi_benchmark_test", size = "medium", srcs = ["kpi_benchmark_test.py"], diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index aeda5a61594fd7..18f32cc1604186 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -305,7 +305,6 @@ class LogicalDevice( placement. device_type: String declaring the type of device such as "CPU" or "GPU". """ - pass @tf_export("config.LogicalDeviceConfiguration", @@ -688,6 +687,10 @@ def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): # Clear all the caches in case there are remote tensors in them. self._clear_caches() + # Also clear the device parsing cache since it caches the resolution of + # partial device names, which may become different due to the set_server_def + # call as we may have defined different devices. + _device_parsing_cache.clear() def update_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): """Update a server_def on the context. @@ -1378,9 +1381,13 @@ def add_function_def(self, fdef): fdef: A FunctionDef protocol buffer message. """ self.ensure_initialized() - fdef_string = fdef.SerializeToString() - pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string, - len(fdef_string)) + if is_oss: + fdef_string = fdef.SerializeToString() + pywrap_tfe.TFE_ContextAddFunctionDef( + self._handle, fdef_string, len(fdef_string) + ) + else: + pywrap_tfe.TFE_ContextAddFunctionDefNoSerialization(self._handle, fdef) def get_function_def(self, name): """Get a function definition from the context. 
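A recurring change in the test diffs above and below is purely syntactic: @test_util.assert_no_new_pyobjects_executing_eagerly becomes @test_util.assert_no_new_pyobjects_executing_eagerly(). The parentheses are significant if the helper is used as a decorator factory, which the patch suggests: calling it returns the actual decorator, whereas applying the bare factory with @ would pass the test method in as the factory's first argument instead of wrapping it. A rough, self-contained illustration of the pattern with a hypothetical leak-checking factory (none of the names below are TensorFlow APIs):

import functools
import gc


def assert_no_new_objects(warmup_iters=2):
  """Hypothetical decorator factory; not TensorFlow's implementation."""
  def decorator(test_fn):
    @functools.wraps(test_fn)
    def wrapper(*args, **kwargs):
      for _ in range(warmup_iters):  # warm up caches so they are not counted
        test_fn(*args, **kwargs)
      gc.collect()
      before = len(gc.get_objects())
      test_fn(*args, **kwargs)
      gc.collect()
      print("new objects:", len(gc.get_objects()) - before)
    return wrapper
  return decorator


# The call returns the decorator; writing @assert_no_new_objects without the
# parentheses would instead pass sample_test in as warmup_iters.
@assert_no_new_objects()
def sample_test():
  _ = [i * i for i in range(1000)]


sample_test()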
diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py index 70f6e0e90877b5..82cac2d18a53c8 100644 --- a/tensorflow/python/eager/forwardprop_test.py +++ b/tensorflow/python/eager/forwardprop_test.py @@ -336,7 +336,7 @@ def testJVPFunctionUsedByAccumulatorForOps(self): finally: pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFunctionCacheLimited(self): # Every time this loop is executed, it will create a slightly larger Tensor # and push it through Add's gradient. @@ -357,7 +357,7 @@ def testVariableUnwatchedZero(self): self.assertIsNone(acc.jvp(v)) self.assertAllClose([[0.]], acc.jvp(v, unconnected_gradients="zero")) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFunctionReturnsResource(self): v = variables.Variable([[1.]]) x = constant_op.constant(1.) @@ -371,7 +371,7 @@ def f(a): y, _ = f(x) self.assertAllClose(2., acc.jvp(y)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testMultipleWatchesAdd(self): x = constant_op.constant(-2.) with self.assertRaisesRegex(ValueError, "multiple times"): @@ -387,7 +387,7 @@ def testMultipleWatchesAdd(self): self.assertAllClose(24., acc.jvp(x)) self.assertAllClose(24. * 3., acc.jvp(y)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testReenter(self): x = constant_op.constant(-2.) with forwardprop.ForwardAccumulator(x, 1.5) as acc: @@ -403,7 +403,7 @@ def testReenter(self): yy = y * y self.assertAllClose(6. * -8. * 2., acc.jvp(yy)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDeadTensorsJVPCleared(self): x = array_ops.ones([100]) x_weak = weakref.ref(x) @@ -424,14 +424,14 @@ def testDeadTensorsJVPCleared(self): self.assertIsNone(derived_tensor_weak()) self.assertIsNone(derived_tensor_grad_weak()) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testJVPManual(self): primal, tangent = _jvp(math_ops.sin, (constant_op.constant(0.1),), (constant_op.constant(0.2),)) self.assertAllClose(math_ops.sin(0.1), primal) self.assertAllClose(math_ops.cos(0.1) * 0.2, tangent) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNumericHigherOrder(self): def f(x): @@ -448,7 +448,7 @@ def f(x): satol=1e-3, ) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNumericHigherOrderFloat64(self): def f(x): @@ -462,7 +462,7 @@ def f(x): [constant_op.constant([[2.0, 3.0], [1.0, 4.0]], dtype=dtypes.float64)], order=3) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testCustomGradient(self): @custom_gradient.custom_gradient @@ -475,7 +475,7 @@ def grad(dy): _test_gradients(self, f, [constant_op.constant([1., 2.])], order=3) - # TODO(allenl): investigate why assert_no_new_pyobjects_executing_eagerly + # TODO(allenl): investigate why assert_no_new_pyobjects_executing_eagerly() # fails around this test? 
def testExceptionCustomGradientRecomputeGradForward(self): @@ -563,7 +563,7 @@ def grad(dy): ("Order{}".format(order), order, expected) for order, expected in enumerate(_X11_35_DERIVATIVES) ]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testHigherOrderPureForward(self, order, expected): def _forwardgrad(f): @@ -606,7 +606,7 @@ def f(x): self.assertAllClose(3.5 * 2.5 * 1.1**1.5, outer_jvp) self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out))) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testJVPPacking(self): two = constant_op.constant(2.) primal_in = constant_op.constant(1.) @@ -688,7 +688,7 @@ def _expected(mat, tangent): self.assertAllClose(_expected(m1, tangent1), acc.jvp(result1)) self.assertAllClose(_expected(m2, tangent2), acc.jvp(result2)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testHVPMemory(self): def fun(x): @@ -698,7 +698,7 @@ def fun(x): tangents = constant_op.constant([3., 4., 5.]) _hvp(fun, (primals,), (tangents,)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testHVPCorrectness(self): def fun(x): @@ -725,7 +725,7 @@ def fun(x): self.assertAllClose(backback_hvp, forwardback_hvp_eager) self.assertAllClose(backback_hvp, forwardback_hvp_function) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testShouldRecordAndStopRecord(self): c = constant_op.constant(1.) c_tangent = constant_op.constant(2.) @@ -747,7 +747,7 @@ def testShouldRecordAndStopRecord(self): self.assertIsNone(acc.jvp(d)) self.assertIsNone(tape.gradient(d, c)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testRecordingSelectively(self): c = constant_op.constant(1.) c_tangent = constant_op.constant(2.) @@ -774,7 +774,7 @@ def testRecordingSelectively(self): self.assertIsNone(tape.gradient(d, c)) self.assertAllClose(3., tape.gradient(e, c)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testOpWithNoTrainableOutputs(self): if sys.version_info.major == 3 and sys.version_info.minor in (11, 12): # TODO(b/264947738) @@ -847,7 +847,7 @@ def testBackwardOverForward(self, forward_prop_first): self.assertTrue(record.should_record_backprop((acc.jvp(d),))) self.assertAllClose(-.1 * math_ops.cos(1.), tape.gradient(acc.jvp(d), c)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testRecordingWithJVPIndices(self): c = constant_op.constant(1.) with forwardprop.ForwardAccumulator(c, 10.) as acc: @@ -861,7 +861,7 @@ def testRecordingWithJVPIndices(self): None, (((0, 1),),)) self.assertAllClose(3., acc.jvp(d)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testSpecialForwardFunctionUsed(self): c = constant_op.constant(1.) d = constant_op.constant(2.) 
@@ -875,7 +875,7 @@ def testSpecialForwardFunctionUsed(self): lambda x: [x]) self.assertAllClose(-20., acc.jvp(e)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testVariableWatched(self): if sys.version_info.major == 3 and sys.version_info.minor in (11, 12): # TODO(b/264947738) @@ -1015,25 +1015,25 @@ def _fprop_cond(k, y): class ControlFlowTests(test.TestCase): - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testOfFunctionWhile(self): y = constant_op.constant(1.) with forwardprop.ForwardAccumulator(y, 1.) as acc: self.assertAllClose(10., acc.jvp(_has_loop(constant_op.constant(5), y))) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testOfFunctionCond(self): y = constant_op.constant(1.) with forwardprop.ForwardAccumulator(y, 1.) as acc: self.assertAllClose(3., acc.jvp(_has_cond(constant_op.constant(5), y))) self.assertAllClose(0., acc.jvp(_has_cond(constant_op.constant(0), y))) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testInFunctionWhile(self): self.assertAllClose( 10., _fprop_while(constant_op.constant(5), constant_op.constant(1.))) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testInFunctionCond(self): self.assertAllClose( 3., _fprop_cond(constant_op.constant(5), constant_op.constant(1.))) diff --git a/tensorflow/python/eager/memory_tests/memory_test.py b/tensorflow/python/eager/memory_tests/memory_test.py index ee5104ef27b343..3503058b0012cd 100644 --- a/tensorflow/python/eager/memory_tests/memory_test.py +++ b/tensorflow/python/eager/memory_tests/memory_test.py @@ -61,7 +61,7 @@ def graph(x): memory_test_util.assert_no_leak( f, num_iters=1000, increase_threshold_absolute_mb=30) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNestedFunctionsDeleted(self): @def_function.function diff --git a/tensorflow/python/eager/polymorphic_function/BUILD b/tensorflow/python/eager/polymorphic_function/BUILD index e98fe42f55917e..2160b5a8f91c2d 100644 --- a/tensorflow/python/eager/polymorphic_function/BUILD +++ b/tensorflow/python/eager/polymorphic_function/BUILD @@ -118,7 +118,6 @@ py_strict_library( "//tensorflow/python/profiler:trace", "//tensorflow/python/trackable:base", "//tensorflow/python/types:core", - "//tensorflow/python/util:_pywrap_utils", "//tensorflow/python/util:compat", "//tensorflow/python/util:nest", "//tensorflow/python/util:object_identity", diff --git a/tensorflow/python/eager/polymorphic_function/concrete_function.py b/tensorflow/python/eager/polymorphic_function/concrete_function.py index 3bbc4deeca4aaf..a68acdd94d40e0 100644 --- a/tensorflow/python/eager/polymorphic_function/concrete_function.py +++ b/tensorflow/python/eager/polymorphic_function/concrete_function.py @@ -47,7 +47,6 @@ from tensorflow.python.profiler import trace from tensorflow.python.trackable import base as trackable from tensorflow.python.types import core -from tensorflow.python.util import _pywrap_utils from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util import object_identity @@ -1735,11 +1734,6 @@ def _export_to_saved_model_graph(self, object_map, tensor_map, return [] -_pywrap_utils.RegisterType("Tensor", tensor_lib.Tensor) 
-_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor) -_pywrap_utils.RegisterType("IndexedSlices", indexed_slices.IndexedSlices) - - class ConcreteFunctionGarbageCollector: """Cleans up reference cycles when a `ConcreteFunction` goes out of scope.""" diff --git a/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py b/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py index 64aab16798ebf4..663562a347b59f 100644 --- a/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py +++ b/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py @@ -3833,7 +3833,7 @@ def testMethodReferenceCycles(self): # function itself is not involved in a reference cycle. self.assertIs(None, weak_fn()) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testErrorMessageWhenGraphTensorIsPassedToEager(self): @polymorphic_function.function diff --git a/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py b/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py index 96ea55beeb8077..42d8091ed960d5 100644 --- a/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py +++ b/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py @@ -385,7 +385,7 @@ def sum_gather(): expected = self.evaluate(sum_gather()) self.assertAllEqual(expected, self.evaluate(defined())) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testCallOptionsMemory(self): @compiled_fn def model(x): diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index dc6db8b6e78962..306e275c63905c 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -1095,6 +1095,15 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) { PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType"); return nullptr; } +#if PY_VERSION_HEX >= 0x030B0000 + // Py_TPFLAGS_MANAGED_DICT is turned on by PyType_FromSpecWithBases by + // default. It tells Python that the class's __dict__ should be managed by VM, + // but EagerTensor sets a `tp_dictoffset` (below) to explicitly manage the + // dict. 
See: + // - https://docs.python.org/3/c-api/typeobj.html#c.Py_TPFLAGS_MANAGED_DICT + // - https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_dictoffset + EagerTensorType->tp_flags &= ~Py_TPFLAGS_MANAGED_DICT; +#endif EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict); EagerTensorType->tp_as_buffer = &EagerTensor_as_buffer; #else diff --git a/tensorflow/python/eager/pywrap_tensor_test.py b/tensorflow/python/eager/pywrap_tensor_test.py index a684a80658fa10..c1539b24f802c3 100644 --- a/tensorflow/python/eager/pywrap_tensor_test.py +++ b/tensorflow/python/eager/pywrap_tensor_test.py @@ -17,6 +17,18 @@ import numpy as np from tensorflow.python.eager import pywrap_tensor_test_util as util from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util + + +class MyPythonObject: + pass + + +def my_layer(x): + y = x**2 + y.my_dynamic_attribute = MyPythonObject() + return y class PywrapTensorTest(test.TestCase): @@ -26,6 +38,14 @@ def testGetScalarOne(self): self.assertIsInstance(result, np.ndarray) self.assertAllEqual(result, 1.0) + @test_util.assert_no_new_pyobjects_executing_eagerly() + def test_no_leak(self): + x = constant_op.constant([1, 2, 3]) + layer = my_layer(x) + for _ in range(int(1e2)): + layer = my_layer(x) + self.assertIsNotNone(layer) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index ea5e6006b9fa24..532d7f1555521f 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -86,7 +86,7 @@ def testNumpyValue(self): t = _create_tensor(values) self.assertAllEqual(values, t) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNumpyDtypeSurvivesThroughTensorConversion(self): scalar_creators = [np.int32, np.int64, np.float32, np.float64] conversion_functions = [ops.convert_to_tensor, constant_op.constant] @@ -359,7 +359,7 @@ def testConvertToTensorAllowsOverflow(self): _ = ops.convert_to_tensor(123456789, dtype=dtypes.uint8) @test_util.run_in_graph_and_eager_modes - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testConvertToTensorNumpyZeroDim(self): for np_type, dtype in [(np.int32, dtypes.int32), (np.half, dtypes.half), (np.float32, dtypes.float32)]: @@ -370,7 +370,7 @@ def testConvertToTensorNumpyZeroDim(self): self.assertAllEqual(x, [65, 16]) @test_util.run_in_graph_and_eager_modes - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testConvertToTensorNumpyScalar(self): x = ops.convert_to_tensor([ np.array(321, dtype=np.int64).item(), @@ -422,19 +422,19 @@ def testMemoryviewIsReadonly(self): t = constant_op.constant([0.0]) self.assertTrue(memoryview(t).readonly) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testMemoryviewScalar(self): t = constant_op.constant(42.0) self.assertAllEqual( np.array(memoryview(t)), np.array(42.0, dtype=np.float32)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testMemoryviewEmpty(self): t = constant_op.constant([], dtype=np.float32) self.assertAllEqual(np.array(memoryview(t)), np.array([])) @test_util.run_gpu_only - @test_util.assert_no_new_pyobjects_executing_eagerly + 
@test_util.assert_no_new_pyobjects_executing_eagerly() def testMemoryviewCopyToCPU(self): with ops.device("/device:GPU:0"): t = constant_op.constant([0.0]) @@ -620,7 +620,7 @@ def testSliceDimOutOfRange(self): "but tensor at index 2 has rank 0"): pywrap_tfe.TFE_Py_TensorShapeSlice([t2, t1, t3], 0) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testTensorDir(self): t = array_ops.ones(1) t.test_attr = "Test" @@ -639,7 +639,7 @@ def testNonRectangularPackAsConstant(self): with self.assertRaisesRegex(ValueError, "non-rectangular Python sequence"): constant_op.constant(l) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFloatAndIntAreConvertibleToComplex(self): a = [[1., 1], [1j, 2j]] np_value = np.array(a, dtype=np.complex128) diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index 65228aeb2bbe19..5a641aba2da70f 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -224,7 +224,7 @@ def __init__(self, fn_graph, variable_holder, attrs=None, signature=None): _lift_unlifted_variables(fn_graph, variable_holder) # We call __init__ after lifting variables so that the function's signature # properly reflects the new captured inputs. - for f in fn_graph.as_graph_def().library.function: + for f in fn_graph.as_graph_def(use_pybind11_proto=True).library.function: context.context().add_function_def(f) self._signature = signature function_type = function_type_lib.from_structured_signature( diff --git a/tensorflow/python/flags_pybind.pyi b/tensorflow/python/flags_pybind.pyi index ad94fa3e713c7a..fbf7124eac2f0a 100644 --- a/tensorflow/python/flags_pybind.pyi +++ b/tensorflow/python/flags_pybind.pyi @@ -19,6 +19,7 @@ class Flag: def value(self) -> bool: ... class Flags: + enable_aggressive_constant_replication: Flag enable_nested_function_shape_inference: Flag enable_quantized_dtypes_training: Flag graph_building_optimization: Flag diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 43dd8056fd6cfd..a697162b4b91cf 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -19,8 +19,12 @@ load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_lib_deps", "tf_proto_library", "tf_protos_grappler") # @unused load("//tensorflow/core/platform:build_config_root.bzl", "if_static", "tf_additional_xla_deps_py") load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_strict_test") +load( + "//tensorflow/tools/test:performance.bzl", + "cuda_py_benchmark_test", +) -visibility = tf_python_framework_friends() +visibility = tf_python_framework_friends() # buildifier: disable=package-on-top package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -233,10 +237,11 @@ tf_cc_test( ], ) +# Do not depend on this rule! Depend on the fine-grained sub-targets instead. py_strict_library( name = "for_generated_wrappers", - deprecation = "Depending on this target can cause build dependency cycles. Depend on the fine-grained sub-targets instead.", srcs_version = "PY3", + tags = ["avoid_dep"], visibility = ["//visibility:public"], deps = [ ":byte_swap_tensor", @@ -254,13 +259,14 @@ py_strict_library( ], ) -# What is needed for tf_gen_op_wrapper_py. 
This is the same as -# "for_generated_wrappers" minus the "function" dep. This is to avoid -# circular dependencies, as "function" uses generated op wrappers. +# This rule should only be depended on by tf_gen_op_wrapper_py. +# Do not depend on this rule! Depend on the fine-grained sub-targets instead. +# This is the same as "for_generated_wrappers" minus the "function" dep. +# This is to avoid circular dependencies, as "function" uses generated op wrappers. py_strict_library( name = "for_generated_wrappers_v2", - deprecation = "Depending on this target can cause build dependency cycles. Depend on the fine-grained sub-targets instead.", srcs_version = "PY3", + tags = ["avoid_dep"], visibility = ["//visibility:public"], deps = [ ":byte_swap_tensor", @@ -296,81 +302,6 @@ py_strict_library( ], ) -py_strict_library( - name = "framework", - deprecation = "This target has been split. Depend on the sub-targets instead.", - srcs_version = "PY3", - visibility = visibility + ["//tensorflow:internal"], - deps = [ - ":_errors_test_helper", - ":_pywrap_python_api_dispatcher", - ":_pywrap_python_api_info", - ":_pywrap_python_api_parameter_converter", - ":_pywrap_python_op_gen", - ":byte_swap_tensor", - ":c_api_util", - ":composite_tensor", - ":config", - ":cpp_shape_inference_proto_py", - ":device", - ":dtypes", - ":error_interpolation", - ":errors", - ":fast_tensor_util", - ":for_generated_wrappers", - ":framework_lib", - ":function", - ":graph_io", - ":graph_util", - ":importer", - ":indexed_slices", - ":load_library", - ":meta_graph", - ":op_def_registry", - ":ops", - ":random_seed", - ":sparse_tensor", - ":tensor", - ":tensor_conversion_registry", - ":tensor_shape", - ":tensor_spec", - ":tensor_util", - ":type_spec", - ":versions", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:_pywrap_py_exception_registry", - "//tensorflow/python:_pywrap_quantize_training", - "//tensorflow/python:pywrap_mlir", - "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python:pywrap_tfe", - "//tensorflow/python:tf2", - "//tensorflow/python/client:_pywrap_debug_events_writer", - "//tensorflow/python/client:_pywrap_events_writer", - "//tensorflow/python/client:pywrap_tf_session", - "//tensorflow/python/eager:context", - "//tensorflow/python/lib/core:_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed. 
- "//tensorflow/python/lib/io:file_io", - "//tensorflow/python/ops:control_flow_util", - "//tensorflow/python/platform:_pywrap_stacktrace_handler", - "//tensorflow/python/platform:tf_logging", - "//tensorflow/python/util:_pywrap_checkpoint_reader", - "//tensorflow/python/util:_pywrap_kernel_registry", - "//tensorflow/python/util:_pywrap_nest", - "//tensorflow/python/util:_pywrap_stat_summarizer", - "//tensorflow/python/util:_pywrap_tfprof", - "//tensorflow/python/util:_pywrap_transform_graph", - "//tensorflow/python/util:_pywrap_util_port", - "//tensorflow/python/util:_pywrap_utils", - "//tensorflow/python/util:compat", - "//tensorflow/python/util:deprecation", - "//tensorflow/python/util:tf_export", - "//third_party/py/numpy", - "@pypi_packaging//:pkg", - ] + if_xla_available([ - "//tensorflow/python:_pywrap_tfcompile", - ]), -) - py_strict_library( name = "byte_swap_tensor", srcs = ["byte_swap_tensor.py"], @@ -405,7 +336,10 @@ py_strict_library( py_strict_library( name = "constant_op", - srcs = ["constant_op.py"], + srcs = [ + "constant_op.py", + "constant_tensor_conversion.py", + ], srcs_version = "PY3", visibility = visibility + [ "//smartass:__subpackages__", @@ -776,6 +710,23 @@ py_strict_library( ], ) +py_strict_library( + name = "override_binary_operator", + srcs = ["override_binary_operator.py"], + srcs_version = "PY3", + deps = [ + ":dtypes", + ":ops", + ":tensor", + ":tensor_shape", + "//tensorflow/python/ops:math_ops_gen", + "//tensorflow/python/ops/numpy_ops:np_dtypes", + "//tensorflow/python/util:nest", + "//tensorflow/python/util:traceback_utils", + "//third_party/py/numpy", + ], +) + cc_library( name = "py_context_manager", srcs = ["py_context_manager.cc"], @@ -1652,6 +1603,7 @@ py_strict_library( ":constant_op", ":dtypes", ":ops", + ":override_binary_operator", ":tensor", ":tensor_shape", ":tensor_spec", @@ -1662,10 +1614,10 @@ py_strict_library( "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:tf2", "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:math_ops_gen", "//tensorflow/python/ops:sparse_ops_gen", "//tensorflow/python/saved_model:nested_structure_coder", "//tensorflow/python/types:internal", - "//tensorflow/python/util:_pywrap_utils", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", ], @@ -1723,7 +1675,6 @@ py_strict_library( visibility = visibility, deps = [ "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python/util:_pywrap_utils", "//tensorflow/python/util:nest", "//tensorflow/python/util:tf_export", ], @@ -1869,7 +1820,6 @@ pytype_strict_library( "//tensorflow/python/types:core", "//tensorflow/python/types:internal", "//tensorflow/python/types:trace", - "//tensorflow/python/util:_pywrap_utils", "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:nest", @@ -2000,7 +1950,9 @@ pytype_strict_library( "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", - ], + ] + if_xla_available([ + "//tensorflow/python:_pywrap_tfcompile", + ]), ) pytype_strict_library( @@ -2013,10 +1965,12 @@ pytype_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", - ], + ] + if_xla_available([ + "//tensorflow/python:_pywrap_tfcompile", + ]), ) -py_strict_library( +pytype_strict_library( name = "stack", srcs = ["stack.py"], visibility = visibility + ["//tensorflow:internal"], @@ -2057,7 +2011,6 @@ py_strict_library( 
"//tensorflow/python/saved_model:nested_structure_coder", "//tensorflow/python/types:core", "//tensorflow/python/types:internal", - "//tensorflow/python/util:_pywrap_utils", "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:object_identity", @@ -2103,7 +2056,7 @@ py_strict_library( ], ) -py_strict_library( +pytype_strict_library( name = "traceable_stack", srcs = ["traceable_stack.py"], srcs_version = "PY3", @@ -2129,7 +2082,7 @@ py_strict_library( deps = [], ) -py_strict_library( +pytype_strict_library( name = "test_lib", srcs = ["test_util.py"], srcs_version = "PY3", @@ -3254,7 +3207,7 @@ tf_py_strict_test( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "graph_building_benchmark", size = "medium", srcs = ["graph_building_benchmark.py"], diff --git a/tensorflow/python/framework/composite_tensor.py b/tensorflow/python/framework/composite_tensor.py index 05b4f672793f3e..6e1651ab5e7b88 100644 --- a/tensorflow/python/framework/composite_tensor.py +++ b/tensorflow/python/framework/composite_tensor.py @@ -17,7 +17,6 @@ import abc from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import -from tensorflow.python.util import _pywrap_utils from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -99,9 +98,6 @@ def _convert_variables_to_tensors(self): return self -_pywrap_utils.RegisterType("CompositeTensor", CompositeTensor) - - def replace_composites_with_components(structure): """Recursively replaces CompositeTensors with their components. diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index 9c2d8d21a7c0b8..1371d6495e2e64 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -25,6 +25,11 @@ from tensorflow.core.protobuf import struct_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import execute +# Import constant_tensor_conversion.py to register tensor conversion functions +# for builtins. These functions were previously in this file, but were +# refactored out so they can be registered at TF import time without importing +# all of constant_op.py. +from tensorflow.python.framework import constant_tensor_conversion # pylint: disable=unused-import from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor as tensor_lib @@ -329,24 +334,6 @@ def is_constant(tensor_or_op): return op.type == "Const" -def _constant_tensor_conversion_function(v, dtype=None, name=None, - as_ref=False): - _ = as_ref - return constant(v, dtype=dtype, name=name) - -# Register the conversion function for the "unconvertible" types -# as a conversion to a constant. 
-tensor_conversion_registry.register_tensor_conversion_function_internal( - tensor_conversion_registry._CONSTANT_OP_CONVERTIBLES, # pylint: disable=protected-access - _constant_tensor_conversion_function, - 0) - -tensor_conversion_registry.register_tensor_conversion_function( - (list, tuple), _constant_tensor_conversion_function, 100) -tensor_conversion_registry.register_tensor_conversion_function( - object, _constant_tensor_conversion_function, 200) - - def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None, diff --git a/tensorflow/python/framework/constant_tensor_conversion.py b/tensorflow/python/framework/constant_tensor_conversion.py new file mode 100644 index 00000000000000..c02cd37c2ac9da --- /dev/null +++ b/tensorflow/python/framework/constant_tensor_conversion.py @@ -0,0 +1,45 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tensor conversion factory functions for builtins to constant Tensors.""" + +from tensorflow.python.framework import tensor_conversion_registry + + +# Factory function for tensor conversion for builtins. Import constant_op.py +# in-line so that it is only imported when it is needed. This file is imported +# at TF import time, thus that helps reduce import slowness. +def _constant_tensor_conversion_function( + v, dtype=None, name=None, as_ref=False +): + from tensorflow.python.framework import constant_op # pylint: disable=g-import-not-at-top + + _ = as_ref + return constant_op.constant(v, dtype=dtype, name=name) + + +# Register the conversion function for the "unconvertible" types +# as a conversion to a constant. 
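Note: the registrations that follow mirror the ones removed from constant_op.py above, so user-visible conversion behavior is unchanged; only the module doing the registering moves. An illustrative sketch, not part of the patch, assuming TensorFlow has been imported so these registrations have run:

import tensorflow as tf

# Lists and tuples (priority 100) and plain Python objects (priority 200)
# still convert to constant tensors via the function registered below.
t = tf.convert_to_tensor([[1, 2], [3, 4]])
print(t.dtype)  # int32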
+tensor_conversion_registry.register_tensor_conversion_function_internal( + tensor_conversion_registry._CONSTANT_OP_CONVERTIBLES, # pylint: disable=protected-access + _constant_tensor_conversion_function, + 0, +) + +tensor_conversion_registry.register_tensor_conversion_function( + (list, tuple), _constant_tensor_conversion_function, 100 +) +tensor_conversion_registry.register_tensor_conversion_function( + object, _constant_tensor_conversion_function, 200 +) diff --git a/tensorflow/python/framework/experimental/BUILD b/tensorflow/python/framework/experimental/BUILD index 2c146b1f865e6e..2d7a8f11129a7f 100644 --- a/tensorflow/python/framework/experimental/BUILD +++ b/tensorflow/python/framework/experimental/BUILD @@ -2,6 +2,10 @@ load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_python_pybind_extension") +load( + "//tensorflow/tools/test:performance.bzl", + "cuda_py_benchmark_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -178,7 +182,7 @@ cuda_py_strict_test( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "graph_building_test", size = "small", srcs = ["graph_building_test.py"], diff --git a/tensorflow/python/framework/override_binary_operator.py b/tensorflow/python/framework/override_binary_operator.py new file mode 100644 index 00000000000000..6e5081b8811b70 --- /dev/null +++ b/tensorflow/python/framework/override_binary_operator.py @@ -0,0 +1,169 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Binary operator override class for Tensor overrides.""" +import numbers +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops.numpy_ops import np_dtypes +from tensorflow.python.util import nest +from tensorflow.python.util import traceback_utils + + +def _maybe_get_dtype(x): + """Returns a numpy type if available from x. Skips if x is numpy.ndarray.""" + # Don't put np.ndarray in this list, because np.result_type looks at the + # value (not just dtype) of np.ndarray to decide the result type. + if isinstance(x, numbers.Real): + return x + if isinstance(x, tensor_lib.Tensor): + return x.dtype.as_numpy_dtype + if isinstance(x, dtypes.DType): + return x.as_numpy_dtype + if isinstance(x, tensor_shape.TensorShape): + return np.int32 + if isinstance(x, (list, tuple)): + raise ValueError(f"Cannot determine dtype. Got sequence {x}.") + return x + + +def maybe_promote_tensors(*tensors, force_same_dtype=False): + """Promotes tensors if numpy style promotion is enabled. 
+ + This function promotes `tensors` according to numpy promotion rules + if numpy style promotion is enabled. Otherwise, if + `force_same_dtype` is `True`, it force-casts `tensors[1:]` to + `tensor[0]`'s dtype. Note that this force-cast can be problematic. + For example, when some `tensors[1:]` elements can be silently + downcasted. + + Args: + *tensors: the list of tensors to promote. + force_same_dtype: bool (optional, default to `False`). When numpy + style promotion is disabled and `force_same_dtype` is `True`, + this function will force-casts `tensors[1:]` to `tensor[0]`'s + dtype (which could be problematic). + + Returns: + The promoted list of tensors. + """ + if ops.is_auto_dtype_conversion_enabled(): + return tensors + if not tensors: + return tensors + if not ops.is_numpy_style_type_promotion(): + if not force_same_dtype: + return tensors + promoted_tensors = [] + promoted_tensors.append(tensors[0]) + dtype = tensors[0].dtype.base_dtype + for tensor in tensors[1:]: + promoted_tensors.append( + ops.convert_to_tensor(tensor, dtype, name="x")) + return promoted_tensors + result_type = np_dtypes._result_type( # pylint: disable=protected-access + *[_maybe_get_dtype(x) for x in nest.flatten(tensors)]) + def _promote_or_cast(x): + if isinstance(x, tensor_lib.Tensor): + x = gen_math_ops.cast(x, result_type) + else: + x = ops.convert_to_tensor(x, result_type) + return x + return [_promote_or_cast(x) for x in tensors] + + +# pylint: disable=protected-access +def override_binary_operator_helper( + func, op_name, clazz_object=tensor_lib.Tensor): + """Register operators with different tensor and scalar versions. + + If `clazz_object` is `SparseTensor`, assumes `func` takes `(sp_indices, + sp_values, sp_shape, dense)` and outputs `(new_sp_values)`. + + Args: + func: the operator + op_name: name of the operator being overridden + clazz_object: class to override for. Either `Tensor` or `SparseTensor`. + """ + + @traceback_utils.filter_traceback + def binary_op_wrapper(x, y): + with ops.name_scope(None, op_name, [x, y]) as name: + try: + # force_same_dtype=False to preserve existing TF behavior + # TODO(b/178860388): Figure out why binary_op_wrapper and + # r_binary_op_wrapper use different force_same_dtype values. + x, y = maybe_promote_tensors(x, y) + return func(x, y, name=name) + except (TypeError, ValueError) as e: + # Even if dispatching the op failed, the RHS may be a tensor aware + # object that can implement the operator with knowledge of itself + # and the tensor. + # If the RHS is not tensor aware we still want to raise the + # original error from the LHS, because it may be more + # informative. 
+ if hasattr(type(y), "__r%s__" % op_name): + try: + r_op = getattr(y, "__r%s__" % op_name) + out = r_op(x) + if out is NotImplemented: + raise + return out + except (TypeError, ValueError): + raise e + else: + raise + + @traceback_utils.filter_traceback + def binary_op_wrapper_sparse(sp_x, y): + with ops.name_scope(None, op_name, [sp_x, y]) as name: + y = ops.convert_to_tensor(y, dtype=sp_x.dtype.base_dtype, name="y") + # use the passed-in SparseTensor class to avoid having to import + # SparseTensor, which would cause a cyclic dep with math_ops + return clazz_object( + sp_x.indices, + func(sp_x.indices, sp_x.values, sp_x.dense_shape, y, name=name), + sp_x.dense_shape) + + @traceback_utils.filter_traceback + def r_binary_op_wrapper(y, x): + with ops.name_scope(None, op_name, [x, y]) as name: + # TODO(b/178860388): Figure out why binary_op_wrapper and + # r_binary_op_wrapper use different force_same_dtype values. + y, x = maybe_promote_tensors(y, x, force_same_dtype=True) + return func(x, y, name=name) + + # Propagate func.__doc__ to the wrappers + try: + doc = func.__doc__ + except AttributeError: + doc = None + binary_op_wrapper.__doc__ = doc + r_binary_op_wrapper.__doc__ = doc + binary_op_wrapper_sparse.__doc__ = doc + + if clazz_object is tensor_lib.Tensor: + clazz_object._override_operator("__%s__" % op_name, binary_op_wrapper) + del binary_op_wrapper + clazz_object._override_operator("__r%s__" % op_name, r_binary_op_wrapper) + del r_binary_op_wrapper + else: + clazz_object._override_operator("__%s__" % op_name, + binary_op_wrapper_sparse) + del binary_op_wrapper_sparse diff --git a/tensorflow/python/framework/python_api_dispatcher.cc b/tensorflow/python/framework/python_api_dispatcher.cc index ae50c87dc334be..805007dc4982bf 100644 --- a/tensorflow/python/framework/python_api_dispatcher.cc +++ b/tensorflow/python/framework/python_api_dispatcher.cc @@ -29,12 +29,28 @@ namespace py_dispatch { namespace { +PyObject* ImportTypeFromModule(const char* module_name, const char* type_name) { + static PyObject* given_type = [module_name, type_name]() { + PyObject* module = PyImport_ImportModule(module_name); + PyObject* attr = + module ? 
PyObject_GetAttrString(module, type_name) : nullptr; + if (attr == nullptr) { + PyErr_WriteUnraisable(nullptr); + PyErr_Clear(); + } + if (module) Py_DECREF(module); + return attr; + }(); + return given_type; +} + std::vector& GetRegisteredDispatchableTypes() { static std::vector* registered_dispatchable_types = new std::vector(); if (registered_dispatchable_types->empty()) { - static PyObject* composite_tensor = - swig::GetRegisteredPyObject("CompositeTensor"); + static PyObject* composite_tensor = ImportTypeFromModule( + "tensorflow.python.framework.composite_tensor", + "CompositeTensor"); Py_INCREF(composite_tensor); registered_dispatchable_types->push_back( Safe_PyObjectPtr(composite_tensor)); diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index 0b870b54c96662..b578bf6e0a1545 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import override_binary_operator from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec @@ -32,10 +33,10 @@ from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec_registry from tensorflow.python.ops import array_ops_stack +from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_sparse_ops from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.types import internal -from tensorflow.python.util import _pywrap_utils from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access @@ -371,7 +372,6 @@ def _is_eager(self): SparseTensorValue = collections.namedtuple("SparseTensorValue", ["indices", "values", "dense_shape"]) tf_export(v1=["SparseTensorValue"])(SparseTensorValue) -_pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue) @tf_export("SparseTensorSpec") @@ -574,3 +574,68 @@ def is_sparse(x): `tf.compat.v1.SparseTensorValue`. """ return isinstance(x, (SparseTensor, SparseTensorValue)) + + +# Conversion table for __truediv__. None entries mean no conversion required. +_TRUEDIV_TABLE = { + dtypes.uint8: dtypes.float32, + dtypes.int8: dtypes.float32, + dtypes.uint16: dtypes.float32, + dtypes.int16: dtypes.float32, + dtypes.uint32: dtypes.float64, + dtypes.int32: dtypes.float64, + dtypes.uint64: dtypes.float64, + dtypes.int64: dtypes.float64, + dtypes.bfloat16: None, + dtypes.float16: None, + dtypes.float32: None, + dtypes.float64: None, + dtypes.complex64: None, + dtypes.complex128: None, +} + + +# NOTE: the support of "sparse (true)div dense" is currently not baked in into +# "tf.(true_)div()". Until such an API decision is made, the supported usage is +# to explicitly use the "/" operator to invoke either truediv or div. 
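Note: as the comment above says, the supported spelling is the "/" operator. An illustrative sketch, not part of the patch, of what the operator registrations at the end of this file enable:

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import sparse_tensor

sp = sparse_tensor.SparseTensor(
    indices=[[0, 0], [1, 1]], values=[2, 8], dense_shape=[2, 2])
dense = constant_op.constant([[2, 1], [1, 4]])
# __truediv__ dispatches to the sparse/dense element-wise division; int32
# inputs are promoted to float64 per _TRUEDIV_TABLE.
quotient = sp / dense  # SparseTensor with values [1.0, 2.0]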
+def _sparse_dense_truediv(sp_indices, sp_values, sp_shape, y, name=None): + """Internal helper function for 'sp_t / dense_t'.""" + with ops.name_scope( + name, "truediv", [sp_indices, sp_values, sp_shape, y] + ) as name: + sp_values = ops.convert_to_tensor(sp_values, name="sp_values") + y = ops.convert_to_tensor(y, name="y") + x_dtype = sp_values.dtype.base_dtype + y_dtype = y.dtype.base_dtype + if x_dtype != y_dtype: + raise TypeError( + "`x` and `y` must have the same dtype, " + f"got {x_dtype!r} != {y_dtype!r}." + ) + try: + dtype = _TRUEDIV_TABLE[x_dtype] + except KeyError as exc: + raise TypeError( + f"Invalid dtype {x_dtype!r} in __truediv__. Expected one " + f"of {{{', '.join([repr(x) for x in _TRUEDIV_TABLE.keys()])}}}." + ) from exc + if dtype is not None: + sp_values = gen_math_ops.cast(sp_values, dtype) + y = gen_math_ops.cast(y, dtype) + return gen_sparse_ops.sparse_dense_cwise_div( + sp_indices, sp_values, sp_shape, y, name=name + ) + + +# NOTE(aselle): When integer division is added for sparse_dense_cwise, +# div, truediv, and floordiv should be delegated appropriately for +# Python semantics, analogous to dense cwise tensor operations. +override_binary_operator.override_binary_operator_helper( + gen_sparse_ops.sparse_dense_cwise_div, "div", SparseTensor +) # pylint: disable=protected-access +override_binary_operator.override_binary_operator_helper( + _sparse_dense_truediv, "truediv", SparseTensor +) # pylint: disable=protected-access +override_binary_operator.override_binary_operator_helper( + gen_sparse_ops.sparse_dense_cwise_mul, "mul", SparseTensor +) # pylint: disable=protected-access diff --git a/tensorflow/python/framework/stack.py b/tensorflow/python/framework/stack.py index 5a1e8fbd1311fd..a91fc99be530e9 100644 --- a/tensorflow/python/framework/stack.py +++ b/tensorflow/python/framework/stack.py @@ -14,39 +14,43 @@ # ============================================================================== """Classes used to handle thread-local stacks.""" +from collections.abc import Iterator import threading +from typing import Generic, Optional, TypeVar from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export +T = TypeVar("T") -class DefaultStack(threading.local): + +class DefaultStack(threading.local, Generic[T]): """A thread-local stack of objects for providing implicit defaults.""" def __init__(self): super().__init__() self._enforce_nesting = True - self.stack = [] + self.stack: list[T] = [] - def get_default(self): + def get_default(self) -> Optional[T]: return self.stack[-1] if self.stack else None - def reset(self): + def reset(self) -> None: self.stack = [] - def is_cleared(self): + def is_cleared(self) -> bool: return not self.stack @property - def enforce_nesting(self): + def enforce_nesting(self) -> bool: return self._enforce_nesting @enforce_nesting.setter - def enforce_nesting(self, value): + def enforce_nesting(self, value: bool): self._enforce_nesting = value @tf_contextlib.contextmanager - def get_controller(self, default): + def get_controller(self, default: T) -> Iterator[T]: """A context manager for manipulating a default stack.""" self.stack.append(default) try: diff --git a/tensorflow/python/framework/tensor.py b/tensorflow/python/framework/tensor.py index 5fa83194866e8d..823d31d38eeb12 100644 --- a/tensorflow/python/framework/tensor.py +++ b/tensorflow/python/framework/tensor.py @@ -40,7 +40,6 @@ from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.types import core as 
core_tf_types from tensorflow.python.types import internal -from tensorflow.python.util import _pywrap_utils from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import object_identity @@ -1455,7 +1454,6 @@ def do_decode(self, value, decode_fn): nested_structure_coder.register_codec(_BoundedTensorSpecCodec()) trace_type.register_serializable(BoundedTensorSpec) -_pywrap_utils.RegisterType("TensorSpec", TensorSpec) # Note: we do not include Tensor names when constructing TypeSpecs. type_spec.register_type_spec_from_value_converter( diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 888b681cee369a..4c982e87873f09 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -18,7 +18,7 @@ import collections from collections import OrderedDict -from collections.abc import Iterator +from collections.abc import Iterable, Iterator, Callable, Collection, Sequence import contextlib import functools import gc @@ -30,16 +30,21 @@ import tempfile import threading import time -from typing import Union +from typing import Any, cast, Union, Optional, overload, TypeVar import unittest from absl.testing import parameterized import numpy as np from google.protobuf import descriptor_pool +from google.protobuf import message from google.protobuf import text_format from tensorflow.core.config import flags from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.core.framework import tensor_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python import pywrap_sanitizers from tensorflow.python import tf2 @@ -99,13 +104,19 @@ from tensorflow.python.util.tf_export import tf_export +_F = TypeVar("_F", bound=Callable[..., Any]) +_T = TypeVar("_T") +_TC = TypeVar("_TC", bound=type["TensorFlowTestCase"]) + + # If the below import is made available through the BUILD rule, then this # function is overridden and will instead return True and cause Tensorflow # graphs to be compiled with XLA. -def is_xla_enabled(): +def is_xla_enabled() -> bool: return False +# pytype: disable=import-error try: from tensorflow.python.framework.is_xla_test_true import is_xla_enabled # pylint: disable=g-import-not-at-top, unused-import except Exception: # pylint: disable=broad-except @@ -114,7 +125,7 @@ def is_xla_enabled(): # Uses the same mechanism as above to selectively enable/disable MLIR # compilation. 
-def is_mlir_bridge_enabled(): +def is_mlir_bridge_enabled() -> Optional[bool]: return None @@ -125,36 +136,39 @@ def is_mlir_bridge_enabled(): from tensorflow.python.framework.is_mlir_bridge_test_true import is_mlir_bridge_enabled # pylint: disable=g-import-not-at-top, unused-import except ImportError: pass +# pytype: enable=import-error -def is_asan_enabled(): +def is_asan_enabled() -> bool: """Check if ASAN is enabled.""" return pywrap_sanitizers.is_asan_enabled() -def is_msan_enabled(): +def is_msan_enabled() -> bool: """Check if MSAN is enabled.""" return pywrap_sanitizers.is_msan_enabled() -def is_tsan_enabled(): +def is_tsan_enabled() -> bool: """Check if TSAN is enabled.""" return pywrap_sanitizers.is_tsan_enabled() -def is_ubsan_enabled(): +def is_ubsan_enabled() -> bool: """Check if UBSAN is enabled.""" return pywrap_sanitizers.is_ubsan_enabled() -def _get_object_count_by_type(exclude=()): +def _get_object_count_by_type( + exclude: Iterable[Any] = (), +) -> collections.Counter[str]: return ( collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) - collections.Counter([type(obj).__name__ for obj in exclude])) @tf_export("test.gpu_device_name") -def gpu_device_name(): +def gpu_device_name() -> str: """Returns the name of a GPU device if available or a empty string. This method should only be used in tests written with `tf.test.TestCase`. @@ -175,7 +189,9 @@ def gpu_device_name(): return "" -def assert_ops_in_graph(expected_ops, graph): +def assert_ops_in_graph( + expected_ops: dict[str, str], graph: ops.Graph +) -> dict[str, node_def_pb2.NodeDef]: """Assert all expected operations are found. Args: @@ -188,8 +204,8 @@ def assert_ops_in_graph(expected_ops, graph): Raises: ValueError: If the expected ops are not present in the graph. """ - actual_ops = {} - gd = graph.as_graph_def() + actual_ops: dict[str, node_def_pb2.NodeDef] = {} + gd = cast(graph_pb2.GraphDef, graph.as_graph_def()) for node in gd.node: if node.name in expected_ops: if expected_ops[node.name] != node.op: @@ -203,7 +219,9 @@ def assert_ops_in_graph(expected_ops, graph): @tf_export("test.assert_equal_graph_def", v1=[]) -def assert_equal_graph_def_v2(expected, actual): +def assert_equal_graph_def_v2( + expected: graph_pb2.GraphDef, actual: graph_pb2.GraphDef +) -> None: """Asserts that two `GraphDef`s are (mostly) the same. Compares two `GraphDef` protos for equality, ignoring versions and ordering of @@ -224,8 +242,12 @@ def assert_equal_graph_def_v2(expected, actual): @tf_export(v1=["test.assert_equal_graph_def"]) -def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False, - hash_table_shared_name=False): +def assert_equal_graph_def_v1( + actual: graph_pb2.GraphDef, + expected: graph_pb2.GraphDef, + checkpoint_v2: bool = False, + hash_table_shared_name: bool = False +) -> None: """Asserts that two `GraphDef`s are (mostly) the same. 
Compares two `GraphDef` protos for equality, ignoring versions and ordering of @@ -248,8 +270,12 @@ def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False, hash_table_shared_name) -def assert_equal_graph_def(actual, expected, checkpoint_v2=False, - hash_table_shared_name=False): +def assert_equal_graph_def( + actual: graph_pb2.GraphDef, + expected: graph_pb2.GraphDef, + checkpoint_v2: bool = False, + hash_table_shared_name: bool = False +)-> None: if not isinstance(actual, graph_pb2.GraphDef): raise TypeError("Expected tf.GraphDef for actual, got %s" % type(actual).__name__) @@ -271,7 +297,11 @@ def assert_equal_graph_def(actual, expected, checkpoint_v2=False, raise AssertionError(compat.as_str(diff)) -def assert_meta_graph_protos_equal(tester, a, b): +def assert_meta_graph_protos_equal( + tester: "TensorFlowTestCase", + a: meta_graph_pb2.MetaGraphDef, + b: meta_graph_pb2.MetaGraphDef, +) -> None: """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.""" # Carefully check the collection_defs tester.assertEqual(set(a.collection_def), set(b.collection_def)) @@ -279,7 +309,7 @@ def assert_meta_graph_protos_equal(tester, a, b): for k in collection_keys: a_value = a.collection_def[k] b_value = b.collection_def[k] - proto_type = ops.get_collection_proto_type(k) + proto_type = cast(type[message.Message], ops.get_collection_proto_type(k)) if proto_type: a_proto = proto_type() b_proto = proto_type() @@ -315,11 +345,12 @@ def assert_meta_graph_protos_equal(tester, a, b): _SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part" -def _strip_checkpoint_v2_randomized(graph_def): +def _strip_checkpoint_v2_randomized(graph_def: graph_pb2.GraphDef) -> None: for node in graph_def.node: - delete_keys = [] + delete_keys: list[str] = [] for attr_key in node.attr: - attr_tensor_value = node.attr[attr_key].tensor + attr_tensor_value = cast( + tensor_pb2.TensorProto, node.attr[attr_key].tensor) if attr_tensor_value and len(attr_tensor_value.string_val) == 1: attr_tensor_string_value = attr_tensor_value.string_val[0] if (attr_tensor_string_value and @@ -333,9 +364,9 @@ def _strip_checkpoint_v2_randomized(graph_def): _TABLE_SHARED_NAME_PATTERN = r"hash_table_[0-9a-z\-]+" -def _strip_hash_table_shared_name(graph_def): +def _strip_hash_table_shared_name(graph_def: graph_pb2.GraphDef) -> None: for node in graph_def.node: - delete_keys = [] + delete_keys: list[str] = [] if node.op == "HashTableV2" and "shared_name" in node.attr: if re.match(compat.as_bytes(_TABLE_SHARED_NAME_PATTERN), node.attr["shared_name"].s): @@ -344,35 +375,37 @@ def _strip_hash_table_shared_name(graph_def): del node.attr[attr_key] -def IsGoogleCudaEnabled(): +def IsGoogleCudaEnabled() -> bool: return _pywrap_util_port.IsGoogleCudaEnabled() -def IsBuiltWithROCm(): +def IsBuiltWithROCm() -> bool: return _pywrap_util_port.IsBuiltWithROCm() -def IsBuiltWithXLA(): +def IsBuiltWithXLA() -> bool: return _pywrap_util_port.IsBuiltWithXLA() -def IsBuiltWithNvcc(): +def IsBuiltWithNvcc() -> bool: return _pywrap_util_port.IsBuiltWithNvcc() -def GpuSupportsHalfMatMulAndConv(): +def GpuSupportsHalfMatMulAndConv() -> bool: return _pywrap_util_port.GpuSupportsHalfMatMulAndConv() -def IsMklEnabled(): +def IsMklEnabled() -> bool: return _pywrap_util_port.IsMklEnabled() -def InstallStackTraceHandler(): +def InstallStackTraceHandler() -> None: _pywrap_stacktrace_handler.InstallStacktraceHandler() -def NHWCToNCHW(input_tensor): +def NHWCToNCHW( + input_tensor: Union[tensor_lib.Tensor, list[int]] +) -> Union[tensor_lib.Tensor, list[int]]: 
"""Converts the input from the NHWC format to NCHW. Args: @@ -391,7 +424,9 @@ def NHWCToNCHW(input_tensor): return [input_tensor[a] for a in new_axes[ndims]] -def NHWCToNCHW_VECT_C(input_shape_or_tensor): +def NHWCToNCHW_VECT_C( + input_shape_or_tensor: Union[tensor_lib.Tensor, list[int]] +)-> Union[tensor_lib.Tensor, list[int]]: """Transforms the input from the NHWC layout to NCHW_VECT_C layout. Note: Does not include quantization or type conversion steps, which should @@ -409,7 +444,7 @@ def NHWCToNCHW_VECT_C(input_shape_or_tensor): """ permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]} is_tensor = isinstance(input_shape_or_tensor, tensor_lib.Tensor) - temp_shape = ( + temp_shape: list[int] = ( input_shape_or_tensor.shape.as_list() if is_tensor else input_shape_or_tensor) if temp_shape[-1] % 4 != 0: @@ -426,7 +461,9 @@ def NHWCToNCHW_VECT_C(input_shape_or_tensor): return [temp_shape[a] for a in permutation] -def NCHW_VECT_CToNHWC(input_shape_or_tensor): +def NCHW_VECT_CToNHWC( + input_shape_or_tensor: Union[tensor_lib.Tensor, list[int]] +) -> Union[tensor_lib.Tensor, list[int]]: """Transforms the input from the NCHW_VECT_C layout to NHWC layout. Note: Does not include de-quantization or type conversion steps, which should @@ -443,7 +480,7 @@ def NCHW_VECT_CToNHWC(input_shape_or_tensor): """ permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]} is_tensor = isinstance(input_shape_or_tensor, tensor_lib.Tensor) - input_shape = ( + input_shape: list[int] = ( input_shape_or_tensor.shape.as_list() if is_tensor else input_shape_or_tensor) if input_shape[-1] != 4: @@ -458,7 +495,9 @@ def NCHW_VECT_CToNHWC(input_shape_or_tensor): return nhwc_shape -def NCHWToNHWC(input_tensor): +def NCHWToNHWC( + input_tensor: Union[tensor_lib.Tensor, list[int]] +) -> Union[tensor_lib.Tensor, list[int]]: """Converts the input from the NCHW format to NHWC. Args: @@ -477,7 +516,7 @@ def NCHWToNHWC(input_tensor): return [input_tensor[a] for a in new_axes[ndims]] -def skip_if(condition): +def skip_if(condition: Union[Callable[[], bool], bool]) -> Callable[[_F], _F]: """Skips the decorated function if condition is or evaluates to True. Args: @@ -488,7 +527,7 @@ def skip_if(condition): The wrapped function """ - def real_skip_if(fn): + def real_skip_if(fn: _F) -> _F: def wrapper(*args, **kwargs): if callable(condition): @@ -504,7 +543,11 @@ def wrapper(*args, **kwargs): @contextlib.contextmanager -def skip_if_error(test_obj, error_type, messages=None): +def skip_if_error( + test_obj: unittest.TestCase, + error_type: type[Exception], + messages: Union[str, list[str], None] = None +) -> Iterator[None]: """Context manager to skip cases not considered failures by the tests. Note that this does not work if used in setUpClass/tearDownClass. @@ -535,17 +578,17 @@ def skip_if_error(test_obj, error_type, messages=None): raise -def enable_c_shapes(fn): +def enable_c_shapes(fn: _F) -> _F: """No-op. TODO(b/74620627): Remove this.""" return fn -def with_c_shapes(cls): +def with_c_shapes(cls: type[_T]) -> type[_T]: """No-op. TODO(b/74620627): Remove this.""" return cls -def enable_control_flow_v2(fn): +def enable_control_flow_v2(fn: _F) -> _F: """Decorator for enabling CondV2 and WhileV2 on a test. Note this enables using CondV2 and WhileV2 after running the test class's @@ -572,7 +615,7 @@ def wrapper(*args, **kwargs): return wrapper -def with_control_flow_v2(cls): +def with_control_flow_v2(cls: _TC) -> _TC: """Adds methods that call original methods with WhileV2 and CondV2 enabled. 
Note this enables CondV2 and WhileV2 in new methods after running the test @@ -627,7 +670,7 @@ def testDisabledForV2(self): return cls -def disable_control_flow_v2(unused_msg): +def disable_control_flow_v2(unused_msg: str) -> Callable[[_F], _F]: """Decorator for a function in a with_control_flow_v2 enabled test class. Blocks the function from being run with v2 control flow ops. @@ -639,14 +682,14 @@ def disable_control_flow_v2(unused_msg): The wrapped function with _disable_control_flow_v2 attr set to True. """ - def wrapper(func): + def wrapper(func: _F) -> _F: func._disable_control_flow_v2 = True return func return wrapper -def enable_output_all_intermediates(fn): +def enable_output_all_intermediates(fn: _F) -> _F: """Force-enable outputing all intermediates from functional control flow ops. Args: @@ -669,26 +712,27 @@ def wrapper(*args, **kwargs): return wrapper -def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2): +def assert_no_new_pyobjects_executing_eagerly( + warmup_iters: int = 2, +) -> Callable[[Callable[..., Any]], Callable[..., None]]: """Decorator for asserting that no new Python objects persist after a test. - Runs the test multiple times executing eagerly, first as a warmup and then to - let objects accumulate. The warmup helps ignore caches which do not grow as - the test is run repeatedly. + Returns a decorator that runs the test multiple times executing eagerly, + first as a warmup and then to let objects accumulate. The warmup helps ignore + caches which do not grow as the test is run repeatedly. Useful for checking that there are no missing Py_DECREFs in the C exercised by a bit of Python. Args: - func: The function to test. warmup_iters: The numer of warmup iterations, excluded from measuring. Returns: - The wrapped function performing the test. + A decorator function which can be applied to the test function. """ - def wrap_f(f): - def decorator(self, *args, **kwargs): + def wrap_f(f: Callable[..., Any]) -> Callable[..., None]: + def decorator(self: "TensorFlowTestCase", *args, **kwargs) -> None: """Warms up, gets object counts, runs the test, checks for new objects.""" with context.eager_mode(): gc.disable() @@ -780,15 +824,12 @@ def decorator(self, *args, **kwargs): "The following objects were newly created: %s" % str(obj_count_by_type)) gc.enable() - return decorator + return tf_decorator.make_decorator(f, decorator) - if func is None: - return wrap_f - else: - return wrap_f(func) + return wrap_f -def assert_no_new_tensors(f): +def assert_no_new_tensors(f: _F) -> _F: """Decorator for asserting that no new Tensors persist after a test. Mainly useful for checking that code using the Python C API has correctly @@ -807,10 +848,10 @@ def assert_no_new_tensors(f): The decorated test case. """ - def decorator(self, **kwargs): + def decorator(self: "TensorFlowTestCase", **kwargs): """Finds existing Tensors, runs the test, checks for new Tensors.""" - def _is_tensorflow_object(obj): + def _is_tensorflow_object(obj) -> bool: try: return isinstance(obj, (tensor_lib.Tensor, variables.Variable, @@ -821,7 +862,7 @@ def _is_tensorflow_object(obj): tensors_before = set( id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj)) - outside_executed_eagerly = context.executing_eagerly() + outside_executed_eagerly = cast(bool, context.executing_eagerly()) # Run the test in a new graph so that collections get cleared when it's # done, but inherit the graph key so optimizers behave. 
outside_graph_key = ops.get_default_graph()._graph_key @@ -847,12 +888,12 @@ def _is_tensorflow_object(obj): ))) return result - return decorator + return tf_decorator.make_decorator(f, decorator) -def _find_reference_cycle(objects, idx): +def _find_reference_cycle(objects: Sequence[Any], idx: int) -> bool: - def get_ignore_reason(obj, denylist): + def get_ignore_reason(obj: Any, denylist: Collection[Any]) -> Optional[str]: """Tests whether an object should be omitted from the dependency graph.""" if len(denylist) > 100: return "" @@ -869,7 +910,9 @@ def get_ignore_reason(obj, denylist): # Note: this function is meant to help with diagnostics. Its output is purely # a human-readable representation, so you may freely modify it to suit your # needs. - def describe(obj, denylist, leaves_only=False): + def describe( + obj: Any, denylist: Collection[Any], leaves_only: bool = False, + ) -> str: """Returns a custom human-readable summary of obj. Args: @@ -901,7 +944,12 @@ def describe(obj, denylist, leaves_only=False): else: return "{}, {}".format(type(obj), id(obj)) - def build_ref_graph(obj, graph, reprs, denylist): + def build_ref_graph( + obj: Any, + graph: dict[int, list[int]], + reprs: dict[int, str], + denylist: tuple[Any, ...], + ) -> None: """Builds a reference graph as -> . Args: @@ -927,7 +975,12 @@ def build_ref_graph(obj, graph, reprs, denylist): build_ref_graph(r, graph, reprs, denylist) reprs[r_id] = describe(r, denylist) - def find_cycle(el, graph, reprs, path): + def find_cycle( + el: int, + graph: dict[int, list[int]], + reprs: dict[int, str], + path: tuple[int, ...], + ) -> Optional[bool]: """Finds and prints a single cycle in the dependency graph.""" if el not in graph: return @@ -943,8 +996,8 @@ def find_cycle(el, graph, reprs, path): return False obj = objects[idx] - graph = {} # referrer ID -> object ID - reprs = {} # object ID -> description + graph: dict[int, list[int]] = {} # referrer ID -> object ID + reprs: dict[int, str] = {} # object ID -> description build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason, describe, build_ref_graph, find_cycle)) for k in graph: @@ -953,7 +1006,7 @@ def find_cycle(el, graph, reprs, path): return False -def assert_no_garbage_created(f): +def assert_no_garbage_created(f: _F) -> _F: """Test method decorator to assert that no garbage has been created. Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters @@ -969,7 +1022,7 @@ def assert_no_garbage_created(f): # FIXME(power) -- Update documentation, we no longer care if garbage is # created, we only want to verify we don't have memory leaks. - def decorator(self, **kwargs): + def decorator(self: "TensorFlowTestCase", **kwargs): """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage.""" gc.disable() previous_debug_flags = gc.get_debug() @@ -995,7 +1048,7 @@ def decorator(self, **kwargs): logging.error("Object %d of %d", i, len(gc.garbage) - previous_garbage) - def _safe_object_str(obj): + def _safe_object_str(obj) -> str: return "<%s %d>" % (obj.__class__.__name__, id(obj)) logging.error(" Object type: %s", _safe_object_str(obj)) @@ -1033,7 +1086,7 @@ def _safe_object_str(obj): return decorator -def _combine_named_parameters(**kwargs): +def _combine_named_parameters(**kwargs) -> list[OrderedDict[str, Any]]: """Generate combinations based on its keyword arguments. Two sets of returned combinations can be concatenated using +. 
Their product @@ -1049,7 +1102,7 @@ def _combine_named_parameters(**kwargs): corresponding keyword argument values. """ sort_by_key = lambda k: k[0] - combinations = [] + combinations: list[list[tuple[str, Any]]] = [] for key, values in sorted(kwargs.items(), key=sort_by_key): if not isinstance(values, list): values = [values] @@ -1058,7 +1111,9 @@ def _combine_named_parameters(**kwargs): return [OrderedDict(result) for result in itertools.product(*combinations)] -def generate_combinations_with_testcase_name(**kwargs): +def generate_combinations_with_testcase_name( + **kwargs, +) -> list[OrderedDict[str, Any]]: """Generate combinations based on its keyword arguments using combine(). This function calls combine() and appends a testcase name to the list of @@ -1075,7 +1130,7 @@ def generate_combinations_with_testcase_name(**kwargs): corresponding keyword argument values. """ combinations = _combine_named_parameters(**kwargs) - named_combinations = [] + named_combinations: list[OrderedDict[str, Any]] = [] for combination in combinations: assert isinstance(combination, OrderedDict) name = "".join([ @@ -1091,7 +1146,7 @@ def generate_combinations_with_testcase_name(**kwargs): return named_combinations -def run_all_in_graph_and_eager_modes(cls): +def run_all_in_graph_and_eager_modes(cls: _TC) -> _TC: """Execute all test methods in the given class with and without eager.""" base_decorator = run_in_graph_and_eager_modes for name in dir(cls): @@ -1107,7 +1162,7 @@ def run_all_in_graph_and_eager_modes(cls): return cls -def run_class_in_v1_v2(cls): +def run_class_in_v1_v2(cls: _TC) -> _TC: """Execute all test methods in a given class in v1 and v2 modes.""" base_decorator = run_in_v1_v2 for name in dir(cls): @@ -1126,7 +1181,7 @@ def run_class_in_v1_v2(cls): return cls -def enable_nested_function_shape_inference(fn): +def enable_nested_function_shape_inference(fn: _F) -> _F: """Decorator for enabling nested_function_shape_inference on a test. This function returns a decorator intended to be applied to test methods in @@ -1163,7 +1218,7 @@ def wrapper(*args, **kwargs): return wrapper -def enable_quantized_dtypes_training(fn): +def enable_quantized_dtypes_training(fn: _F) -> _F: """Decorator for enabling quantized_dtypes_training on a test. This function returns a decorator intended to be applied to test methods in @@ -1200,7 +1255,7 @@ def wrapper(*args, **kwargs): return wrapper -def enable_eager_op_as_function(fn): +def enable_eager_op_as_function(fn: _F) -> _F: """Returns the same fn. This will be removed once all usages are removed. Args: @@ -1216,8 +1271,27 @@ def wrapper(*args, **kwargs): return wrapper +@overload +def with_eager_op_as_function( + cls: type[_T], + only_as_function: bool = False, +) -> type[_T]: + ... + + +@overload +def with_eager_op_as_function( + cls: None = None, + only_as_function: bool = False, +) -> Callable[[type[_T]], type[_T]]: + ... + + @tf_export("test.with_eager_op_as_function") -def with_eager_op_as_function(cls=None, only_as_function=False): # pylint: disable=unused-argument +def with_eager_op_as_function( + cls: Optional[type[_T]] = None, + only_as_function: bool = False, # pylint: disable=unused-argument +) -> Union[Callable[[type[_T]], type[_T]], type[_T]]: """Returns the same class. This will be removed once all usages are removed. 
Args: @@ -1228,16 +1302,16 @@ def with_eager_op_as_function(cls=None, only_as_function=False): # pylint: disa cls """ - def decorator(cls): + def decorator(cls: type[_T]) -> type[_T]: return cls if cls is not None: return decorator(cls) - return decorator + return decorator # pytype: disable=bad-return-type -def enable_graph_building_optimization(fn): +def enable_graph_building_optimization(fn: _F) -> _F: """Decorator for enabling graph_building_optimization on a test. This function returns a decorator intended to be applied to test methods in @@ -1273,7 +1347,7 @@ def wrapper(*args, **kwargs): return wrapper -def add_graph_building_optimization_tests(cls=None): +def add_graph_building_optimization_tests(cls: _TC) -> _TC: """Adds methods with graph_building_optimization enabled to the test suite. Example: @@ -1302,25 +1376,19 @@ def testBarWithGraphBuildingOptimization(self): cls with new test methods added. """ - def decorator(cls): - if flags.config().graph_building_optimization.value(): - return cls - - for name, value in cls.__dict__.copy().items(): - if (callable(value) and - (name.startswith(unittest.TestLoader.testMethodPrefix) or - name.startswith("benchmark"))): - setattr(cls, name + "WithGraphBuildingOptimization", - enable_graph_building_optimization(value)) + if flags.config().graph_building_optimization.value(): return cls - if cls is not None: - return decorator(cls) - - return decorator + for name, value in cls.__dict__.copy().items(): + if (callable(value) and + (name.startswith(unittest.TestLoader.testMethodPrefix) or + name.startswith("benchmark"))): + setattr(cls, name + "WithGraphBuildingOptimization", + enable_graph_building_optimization(value)) + return cls -def disable_eager_op_as_function(unused_msg): +def disable_eager_op_as_function(unused_msg: str) -> Callable[[_F], _F]: """Decorator for a function in a with_eager_op_as_function enabled test class. Blocks the function from being run with eager_op_as_function enabled. @@ -1334,7 +1402,7 @@ def disable_eager_op_as_function(unused_msg): return _disable_test(execute_func=False) -def set_xla_env_flag(func=None, flag=""): +def set_xla_env_flag(flag: str = "") -> Callable[[_F], _F]: """Decorator for setting XLA_FLAGS prior to running a test. This function returns a decorator intended to be applied to test methods in @@ -1351,14 +1419,14 @@ def testFoo(self): ... Args: - func: The function to be wrapped. flag: The xla flag to be set in the XLA_FLAGS env variable. Returns: - The wrapped function. + A decorator which sets the configured flag in XLA_FLAGS for the decorated + function. """ - def decorator(f): + def decorator(f: _F) -> _F: @functools.wraps(f) def decorated(*args, **kwargs): @@ -1377,13 +1445,12 @@ def decorated(*args, **kwargs): return decorated - if func is not None: - return decorator(func) - return decorator -def build_as_function_and_v1_graph(func=None): +def build_as_function_and_v1_graph( + func: Callable[..., Any], +) -> Callable[..., None]: """Run a test case in v1 graph mode and inside tf.function in eager mode. WARNING: This decorator can only be used in test cases that statically checks @@ -1400,47 +1467,46 @@ def build_as_function_and_v1_graph(func=None): Decorated test case function. 
""" - def decorator(f): - if tf_inspect.isclass(f): - raise ValueError( - "`run_in_graph_mode_and_function` only supports test methods.") - - @parameterized.named_parameters(("_v1_graph", "v1_graph"), - ("_function", "function")) - @functools.wraps(f) - def decorated(self, run_mode, *args, **kwargs): - if run_mode == "v1_graph": - with ops.Graph().as_default(): - f(self, *args, **kwargs) - elif run_mode == "function": - - @def_function.function - def function_in_eager(): - f(self, *args, **kwargs) - - # Create a new graph for the eagerly executed version of this test for - # better isolation. - graph_for_eager_test = ops.Graph() - with graph_for_eager_test.as_default(), context.eager_mode(): - function_in_eager() - ops.dismantle_graph(graph_for_eager_test) - else: - raise ValueError("Unknown run mode %s" % run_mode) - - return decorated + if tf_inspect.isclass(func): + raise ValueError( + "`run_in_graph_mode_and_function` only supports test methods.") + + @parameterized.named_parameters(("_v1_graph", "v1_graph"), + ("_function", "function")) + @functools.wraps(func) + def decorated( + self: "TensorFlowTestCase", + run_mode: str, + *args, + **kwargs, + ) -> None: + if run_mode == "v1_graph": + with ops.Graph().as_default(): + func(self, *args, **kwargs) + elif run_mode == "function": + + @def_function.function + def function_in_eager(): + func(self, *args, **kwargs) - if func is not None: - return decorator(func) + # Create a new graph for the eagerly executed version of this test for + # better isolation. + graph_for_eager_test = ops.Graph() + with graph_for_eager_test.as_default(), context.eager_mode(): + function_in_eager() + ops.dismantle_graph(graph_for_eager_test) + else: + raise ValueError("Unknown run mode %s" % run_mode) - return decorator + return decorated -def run_in_async_and_sync_mode(f): +def run_in_async_and_sync_mode(f: _F) -> _F: """Execute the test in async mode and sync mode.""" @parameterized.named_parameters([("Async", True), ("", False)]) @functools.wraps(f) - def decorator(self, async_mode, *args, **kwargs): + def decorator(self: "TensorFlowTestCase", async_mode: bool, *args, **kwargs): if async_mode: with context.execution_mode(context.ASYNC): f(self, *args, **kwargs) @@ -1450,10 +1516,35 @@ def decorator(self, async_mode, *args, **kwargs): return decorator -def run_in_graph_and_eager_modes(func=None, - config=None, - use_gpu=True, - assert_no_eager_garbage=False): +@overload +def run_in_graph_and_eager_modes( + func: Callable[..., Any], + config: Optional[config_pb2.ConfigProto] = None, + use_gpu: bool = True, + assert_no_eager_garbage: bool = False, +) -> Callable[..., None]: + ... + + +@overload +def run_in_graph_and_eager_modes( + func: None = None, + config: Optional[config_pb2.ConfigProto] = None, + use_gpu: bool = True, + assert_no_eager_garbage: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., None]]: + ... + + +def run_in_graph_and_eager_modes( + func: Optional[Callable[..., Any]] = None, + config: Optional[config_pb2.ConfigProto] = None, + use_gpu: bool = True, + assert_no_eager_garbage: bool = False, +) -> Union[ + Callable[[Callable[..., Any]], Callable[..., None]], + Callable[..., None], +]: """Execute the decorated test with and without enabling eager execution. This function returns a decorator intended to be applied to test methods in @@ -1511,13 +1602,13 @@ def test_foo(self): eager execution enabled. 
""" - def decorator(f): + def decorator(f: Callable[..., Any]) -> Callable[..., None]: if tf_inspect.isclass(f): raise ValueError( "`run_in_graph_and_eager_modes` only supports test methods. " "Did you mean to use `run_all_in_graph_and_eager_modes`?") - def decorated(self, *args, **kwargs): + def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> None: logging.info("Running %s in GRAPH mode.", f.__name__) try: with context.graph_mode(), self.subTest("graph_mode"): @@ -1536,7 +1627,7 @@ def decorated(self, *args, **kwargs): except unittest.case.SkipTest: pass - def run_eagerly(self, **kwargs): + def run_eagerly(self: "TensorFlowTestCase", **kwargs) -> None: logging.info("Running %s in EAGER mode.", f.__name__) if not use_gpu: with ops.device("/device:CPU:0"): @@ -1573,17 +1664,15 @@ def run_eagerly(self, **kwargs): return decorator -def run_in_v1_v2(func=None, - device_to_use: str = None, - assert_no_eager_garbage: bool = False): +def run_in_v1_v2( + device_to_use: Optional[str] = None, + assert_no_eager_garbage: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., None]]: """Execute the decorated test in v1 and v2 modes. The overall execution is similar to that of `run_in_graph_and_eager_mode`. Args: - func: A test function/method to be decorated. If `func` is None, this method - returns a decorator the can be applied to a function. Otherwise, an - already applied decorator is returned. device_to_use: A string in the following format: "/device:CPU:0". assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage collector and asserts that no extra garbage has been created when running @@ -1600,14 +1689,13 @@ def run_in_v1_v2(func=None, A decorator that runs a given test in v1 and v2 modes. """ - decorator_tag = "wrapped_with_v1_v2_decorator" - if hasattr(func, decorator_tag): - # Already decorated with this very same decorator - return func - - def decorator(f): + def decorator(f: Callable[..., Any]) -> Callable[..., None]: + decorator_tag = "wrapped_with_v1_v2_decorator" + if hasattr(f, decorator_tag): + # Already decorated with this very same decorator + return f - def decorated(self, *args, **kwargs): + def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> None: logging.info("Running %s in V1 mode.", f.__name__) try: with self.subTest("V1_mode"): @@ -1616,7 +1704,7 @@ def decorated(self, *args, **kwargs): except unittest.case.SkipTest: pass - def run_v2(self, **kwargs): + def run_v2(self: "TensorFlowTestCase", **kwargs) -> None: logging.info("Running %s in V2 mode.", f.__name__) if device_to_use: with ops.device(device_to_use): @@ -1644,20 +1732,17 @@ def run_v2(self, **kwargs): tf_decorated.__dict__[decorator_tag] = True return tf_decorated - if func is not None: - return decorator(func) - return decorator -def py_func_if_in_function(f): +def py_func_if_in_function(f: _F) -> _F: def decorated(*args, **kwds): if not ops.inside_function(): return f(*args, **kwds) - tensor_args = [] - tensor_indices = [] + tensor_args: list[Union[tensor_lib.Tensor, variables.Variable]] = [] + tensor_indices: list[int] = [] for i, arg in enumerate(args): if isinstance(arg, (tensor_lib.Tensor, variables.Variable)): tensor_args.append(arg) @@ -1674,7 +1759,7 @@ def inner_f(*inner_tensor_args): return tf_decorator.make_decorator(f, decorated) -def also_run_as_tf_function(f): +def also_run_as_tf_function(f: Callable[..., Any]) -> Callable[..., None]: """Runs the decorated test twice--once as is, once inside a tf.function. 
This allows you to run a test both in eager execution and inside a @@ -1694,9 +1779,9 @@ def also_run_as_tf_function(f): tf.function. """ - def decorated(*args, **kwds): + def decorated(*args, **kwds) -> None: - def bound_f(): + def bound_f() -> None: f(*args, **kwds) with context.eager_mode(): @@ -1709,59 +1794,64 @@ def bound_f(): return decorated -def deprecated_graph_mode_only(func=None): +@overload +def deprecated_graph_mode_only(func: _F) -> _F: + ... + + +@overload +def deprecated_graph_mode_only(func: _TC) -> Optional[_TC]: + ... + + +def deprecated_graph_mode_only(func: Union[_TC, _F]) -> Union[_TC, _F]: """Execute the decorated test in graph mode. - This function returns a decorator intended to be applied to tests that are not - compatible with eager mode. When this decorator is applied, the test body will - be run in an environment where API calls construct graphs instead of executing - eagerly. + This is a decorator intended to be applied to tests that are not compatible + with eager mode. When this decorator is applied, the test body will be run in + an environment where API calls construct graphs instead of executing eagerly. `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and `run_in_graph_and_eager_modes` are available decorators for different v1/v2/eager/graph combinations. Args: - func: function to be annotated. If `func` is None, this method returns a - decorator the can be applied to a function. If `func` is not None this - returns the decorator applied to `func`. + func: function or class to be annotated. + If `func` is a function this returns the decorator applied to `func`. + If `func` is a unit test class this returns that class with the decorator + applied to all test functions within that class. Returns: - Returns a decorator that will run the decorated test method in graph mode. + Returns a function or class that will run the decorated test(s) + in graph mode. 
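A standalone sketch of the class-or-function dispatch used by decorators like the one above: when handed a class, the decorator re-applies itself to each test method. The `trace_calls` name is hypothetical and the sketch omits TensorFlow specifics such as setUp handling.

import inspect
import unittest
from typing import Any, Callable, TypeVar

_T = TypeVar("_T")


def trace_calls(obj: _T) -> _T:
  """Wraps a function, or every test method when given a class."""
  if inspect.isclass(obj):
    for name, value in list(vars(obj).items()):
      if callable(value) and name.startswith(unittest.TestLoader.testMethodPrefix):
        setattr(obj, name, trace_calls(value))
    return obj

  func: Callable[..., Any] = obj  # type: ignore[assignment]

  def wrapper(*args: Any, **kwargs: Any) -> Any:
    print("calling", func.__name__)
    return func(*args, **kwargs)

  return wrapper  # type: ignore[return-value]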
""" - def decorator(f): - if tf_inspect.isclass(f): - setup = f.__dict__.get("setUp") - if setup is not None: - setattr(f, "setUp", decorator(setup)) - - for name, value in f.__dict__.copy().items(): - if (callable(value) and - name.startswith(unittest.TestLoader.testMethodPrefix)): - setattr(f, name, decorator(value)) - - return f + if tf_inspect.isclass(func): + setup = func.__dict__.get("setUp") + if setup is not None: + setattr(func, "setUp", deprecated_graph_mode_only(setup)) - def decorated(self, *args, **kwargs): - if context.executing_eagerly(): - with context.graph_mode(): - return f(self, *args, **kwargs) - else: - return f(self, *args, **kwargs) + for name, value in func.__dict__.copy().items(): + if (callable(value) and + name.startswith(unittest.TestLoader.testMethodPrefix)): + setattr(func, name, deprecated_graph_mode_only(value)) - return decorated + return func - if func is not None: - return decorator(func) + def decorated(*args, **kwargs): + if context.executing_eagerly(): + with context.graph_mode(): + return func(*args, **kwargs) + else: + return func(*args, **kwargs) - return decorator + return tf_decorator.make_decorator(func, decorated) run_deprecated_v1 = deprecated_graph_mode_only -def run_all_in_deprecated_graph_mode_only(cls): +def run_all_in_deprecated_graph_mode_only(cls: _TC) -> _TC: """Execute all tests in a class in graph mode.""" base_decorator = deprecated_graph_mode_only for name in dir(cls): @@ -1847,73 +1937,57 @@ def run_v2_only(func=None, reason=None): return _run_vn_only(func=func, v2=True, reason=reason) -def run_gpu_only(func=None): +def run_gpu_only(func: _F) -> _F: """Execute the decorated test only if a GPU is available. This function is intended to be applied to tests that require the presence of a GPU. If a GPU is absent, it will simply be skipped. Args: - func: function to be annotated. If `func` is None, this method returns a - decorator the can be applied to a function. If `func` is not None this - returns the decorator applied to `func`. + func: function to be annotated. Returns: - Returns a decorator that will conditionally skip the decorated test method. + Returns a function that will conditionally skip the decorated test method. """ - def decorator(f): - if tf_inspect.isclass(f): - raise ValueError("`run_gpu_only` only supports test methods.") - - def decorated(self, *args, **kwargs): - if not is_gpu_available(): - self.skipTest("Test requires GPU") + if tf_inspect.isclass(func): + raise ValueError("`run_gpu_only` only supports test methods.") - return f(self, *args, **kwargs) + def decorated(self: "TensorFlowTestCase", *args, **kwargs): + if not is_gpu_available(): + self.skipTest("Test requires GPU") - return decorated - - if func is not None: - return decorator(func) + return func(self, *args, **kwargs) - return decorator + return decorated -def run_cuda_only(func=None): +def run_cuda_only(func: _F) -> _F: """Execute the decorated test only if a GPU is available. This function is intended to be applied to tests that require the presence of a CUDA GPU. If a CUDA GPU is absent, it will simply be skipped. Args: - func: function to be annotated. If `func` is None, this method returns a - decorator the can be applied to a function. If `func` is not None this - returns the decorator applied to `func`. + func: function to be annotated. Returns: - Returns a decorator that will conditionally skip the decorated test method. + Returns a function that will conditionally skip the decorated test method. 
""" - def decorator(f): - if tf_inspect.isclass(f): - raise ValueError("`run_cuda_only` only supports test methods.") - - def decorated(self, *args, **kwargs): - if not is_gpu_available(cuda_only=True): - self.skipTest("Test requires CUDA GPU") + if tf_inspect.isclass(func): + raise ValueError("`run_cuda_only` only supports test methods.") - return f(self, *args, **kwargs) + def decorated(self: "TensorFlowTestCase", *args, **kwargs): + if not is_gpu_available(cuda_only=True): + self.skipTest("Test requires CUDA GPU") - return decorated - - if func is not None: - return decorator(func) + return func(self, *args, **kwargs) - return decorator + return decorated -def run_gpu_or_tpu(func=None): +def run_gpu_or_tpu(func: _F) -> _F: """Execute the decorated test only if a physical GPU or TPU is available. This function is intended to be applied to tests that require the presence @@ -1923,33 +1997,30 @@ def run_gpu_or_tpu(func=None): - If both GPU and TPU are absent, the test will be skipped. Args: - func: function to be annotated. If `func` is None, this method returns a - decorator the can be applied to a function. If `func` is not None this - returns the decorator applied to `func`. + func: function to be annotated. Returns: - Returns a decorator that will conditionally skip the decorated test method. + Returns a function that will conditionally skip the decorated test method. """ - def decorator(f): - if tf_inspect.isclass(f): - raise ValueError("`run_gpu_or_tpu` only supports test methods.") + if tf_inspect.isclass(func): + raise ValueError("`run_gpu_or_tpu` only supports test methods.") - def decorated(self, *args, **kwargs): - if config.list_physical_devices("GPU"): - return f(self, "GPU", *args, **kwargs) + def decorated(self: "TensorFlowTestCase", *args, **kwargs): + if config.list_physical_devices("GPU"): + return func(self, "GPU", *args, **kwargs) - if config.list_physical_devices("TPU"): - return f(self, "TPU", *args, **kwargs) + if config.list_physical_devices("TPU"): + return func(self, "TPU", *args, **kwargs) - self.skipTest("Test requires GPU or TPU") + self.skipTest("Test requires GPU or TPU") - return decorated - - return decorator if func is None else decorator(func) + return decorated -def with_forward_compatibility_horizons(*horizons): +def with_forward_compatibility_horizons( + *horizons: Optional[tuple[int, int, int]] +) -> Callable[[Callable[..., Any]], Callable[..., None]]: """Executes the decorated test with the specified forward-compat horizons. 
Args: @@ -1967,19 +2038,19 @@ def with_forward_compatibility_horizons(*horizons): (len(horizon) == 3 and all(isinstance(x, int) for x in horizon))): raise ValueError("Bad horizon value: %r" % horizon) - def decorator(f): + def decorator(f: Callable[..., Any]) -> Callable[..., None]: if tf_inspect.isclass(f): raise ValueError("`with_forward_compatibility_horizons` only " "supports test methods.") - def decorated(self, *args, **kwargs): + def decorated(*args, **kwargs): for horizon in horizons: if horizon is None: - f(self, *args, **kwargs) + f(*args, **kwargs) else: (year, month, day) = horizon with forward_compatibility_horizon(year, month, day): - f(self, *args, **kwargs) - return decorated + f(*args, **kwargs) + return tf_decorator.make_decorator(f, decorated) return decorator @@ -1987,7 +2058,10 @@ def decorated(self, *args, **kwargs): @deprecation.deprecated(None, "Use `tf.config.list_physical_devices('GPU')` instead.") @tf_export("test.is_gpu_available") -def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): +def is_gpu_available( + cuda_only: bool = False, + min_cuda_compute_capability: Optional[tuple[int, int]] = None, +) -> bool: """Returns whether TensorFlow can access a GPU. Warning: if a non-GPU version of the package is installed, the function would @@ -2043,7 +2117,7 @@ def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): @contextlib.contextmanager -def device(use_gpu): +def device(use_gpu: bool) -> Iterator[None]: """Uses gpu when requested and available.""" if use_gpu and is_gpu_available(): dev = "/device:GPU:0" @@ -2054,28 +2128,28 @@ def device(use_gpu): @contextlib.contextmanager -def use_gpu(): +def use_gpu() -> Iterator[None]: """Uses gpu when requested and available.""" with device(use_gpu=True): yield @contextlib.contextmanager -def force_gpu(): +def force_gpu() -> Iterator[None]: """Force the gpu to be used.""" with ops.device("/device:GPU:0"): yield @contextlib.contextmanager -def force_cpu(): +def force_cpu() -> Iterator[None]: """Force the cpu to be used.""" with ops.device("/device:CPU:0"): yield @contextlib.contextmanager -def deterministic_ops(): +def deterministic_ops() -> Iterator[None]: """Enables deterministic ops.""" try: config.enable_op_determinism() @@ -2087,10 +2161,10 @@ def deterministic_ops(): class CapturedWrites: """A utility class to load the captured writes made to a stream.""" - def __init__(self, capture_location): + def __init__(self, capture_location: str): self.capture_location = capture_location - def contents(self): + def contents(self) -> str: """Get the captured writes as a single string.""" with open(self.capture_location) as tmp_file: output_data = "".join(tmp_file.readlines()) @@ -2169,7 +2243,7 @@ def run(self, *args, **kwargs): raise -def disable_cudnn_autotune(func): +def disable_cudnn_autotune(func: _F) -> _F: """Disable autotuning during the call to this function. Some tests want to base assertions on a graph being isomorphic with a copy. @@ -2182,46 +2256,39 @@ def disable_cudnn_autotune(func): Decorated function. 
""" - def decorator(f): + def decorated(*args, **kwargs): + original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE") + os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false" + original_xla_flags = os.environ.get("XLA_FLAGS") + new_xla_flags = "--xla_gpu_autotune_level=0" + if original_xla_flags: + new_xla_flags = original_xla_flags + " " + new_xla_flags + os.environ["XLA_FLAGS"] = new_xla_flags - def decorated(self, *args, **kwargs): - original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE") - os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false" - original_xla_flags = os.environ.get("XLA_FLAGS") - new_xla_flags = "--xla_gpu_autotune_level=0" - if original_xla_flags: - new_xla_flags = original_xla_flags + " " + new_xla_flags - os.environ["XLA_FLAGS"] = new_xla_flags + result = func(*args, **kwargs) - result = f(self, *args, **kwargs) - - if (original_tf_cudnn_use_autotune is None): - del os.environ["TF_CUDNN_USE_AUTOTUNE"] - else: - os.environ["TF_CUDNN_USE_AUTOTUNE"] = original_tf_cudnn_use_autotune - if (original_xla_flags is None): - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = original_xla_flags - - return result - - return tf_decorator.make_decorator(func, decorated) + if (original_tf_cudnn_use_autotune is None): + del os.environ["TF_CUDNN_USE_AUTOTUNE"] + else: + os.environ["TF_CUDNN_USE_AUTOTUNE"] = original_tf_cudnn_use_autotune + if (original_xla_flags is None): + del os.environ["XLA_FLAGS"] + else: + os.environ["XLA_FLAGS"] = original_xla_flags - if func is not None: - return decorator(func) + return result - return decorator + return tf_decorator.make_decorator(func, decorated) # The description is just for documentation purposes. -def enable_tf_xla_constant_folding(description): +def enable_tf_xla_constant_folding(description: str) -> Callable[[_F], _F]: if not isinstance(description, str): raise ValueError("'description' should be string, got {}".format( type(description))) - def enable_tf_xla_constant_folding_impl(func): + def enable_tf_xla_constant_folding_impl(func: _F) -> _F: """Enable constant folding during the call to this function. Some tests fail without constant folding. @@ -2233,119 +2300,103 @@ def enable_tf_xla_constant_folding_impl(func): Decorated function. """ - def decorator(f): - - def decorated(self, *args, **kwargs): - original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled() - pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False) - result = f(self, *args, **kwargs) - pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var) - return result - - return decorated - - if func is not None: - return decorator(func) + def decorated(*args, **kwargs): + original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled() + pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False) + result = func(*args, **kwargs) + pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var) + return result - return decorator + return tf_decorator.make_decorator(func, decorated) return enable_tf_xla_constant_folding_impl # Updates test function by selectively disabling it. 
-def _disable_test(execute_func): - - def disable_test_impl(func): +def _disable_test(execute_func: bool) -> Callable[[_F], _F]: - def decorator(func): + def disable_test_impl(func: _F) -> _F: - def decorated(self, *args, **kwargs): - if execute_func: - return func(self, *args, **kwargs) - - return tf_decorator.make_decorator(func, decorated) - - if func is not None: - return decorator(func) + def decorated(*args, **kwargs): + if execute_func: + return func(*args, **kwargs) - return decorator + return tf_decorator.make_decorator(func, decorated) return disable_test_impl # The description is just for documentation purposes. -def disable_xla(description): # pylint: disable=unused-argument +def disable_xla(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """Execute the test method only if xla is not enabled.""" execute_func = not is_xla_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_mlir_bridge(description): # pylint: disable=unused-argument +def disable_mlir_bridge(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """Execute the test method only if MLIR bridge is not enabled.""" execute_func = not is_mlir_bridge_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_asan(description): # pylint: disable=unused-argument +def disable_asan(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """Execute the test method only if ASAN is not enabled.""" execute_func = not is_asan_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_msan(description): # pylint: disable=unused-argument +def disable_msan(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """Execute the test method only if MSAN is not enabled.""" execute_func = not is_msan_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_tsan(description): # pylint: disable=unused-argument +def disable_tsan(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """Execute the test method only if TSAN is not enabled.""" execute_func = not is_tsan_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_ubsan(description): # pylint: disable=unused-argument +def disable_ubsan(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """Execute the test method only if UBSAN is not enabled.""" execute_func = not is_ubsan_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. 
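A compact sketch of the conditional-disable pattern behind the `disable_*` helpers above: a factory captures a boolean at decoration time and either runs or silently skips the wrapped test. `is_feature_enabled` is a placeholder predicate, not a TensorFlow query.

import functools
from typing import Any, Callable, TypeVar

_F = TypeVar("_F", bound=Callable[..., Any])


def is_feature_enabled() -> bool:
  return False  # Placeholder for a real build/runtime query.


def run_unless_feature_enabled(description: str) -> Callable[[_F], _F]:
  """Runs the test only when the feature is off; `description` is documentation only."""
  execute_func = not is_feature_enabled()

  def disable_test_impl(func: _F) -> _F:
    @functools.wraps(func)
    def decorated(*args: Any, **kwargs: Any) -> Any:
      if execute_func:
        return func(*args, **kwargs)
      return None  # The test body is skipped entirely.
    return decorated  # type: ignore[return-value]

  return disable_test_impl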
-def disable_tfrt(unused_description): +def disable_tfrt( + unused_description: str, # pylint: disable=unused-argument +) -> Callable[[Union[_TC, _F]], Union[_TC, _F, None]]: - def disable_tfrt_impl(cls_or_func): + def disable_tfrt_impl(cls_or_func: Union[_TC, _F]) -> Union[_TC, _F, None]: """Execute the test only if tfrt is not enabled.""" if tf_inspect.isclass(cls_or_func): if tfrt_utils.enabled(): return None else: - return cls_or_func + return cast(_TC, cls_or_func) else: - def decorator(func): - - def decorated(self, *args, **kwargs): - if tfrt_utils.enabled(): - return - else: - return func(self, *args, **kwargs) - - return decorated - - if cls_or_func is not None: - return decorator(cls_or_func) + func = cast(Callable[..., Any], cls_or_func) + def decorated(*args, **kwargs): + if tfrt_utils.enabled(): + return + else: + return func(*args, **kwargs) - return decorator + return tf_decorator.make_decorator(cls_or_func, decorated) return disable_tfrt_impl -def for_all_test_methods(decorator, *args, **kwargs): +def for_all_test_methods( + decorator: Callable[..., Any], *args, **kwargs, +) -> Callable[[_TC], _TC]: """Generate class-level decorator from given method-level decorator. It is expected for the given decorator to take some arguments and return @@ -2360,7 +2411,7 @@ def for_all_test_methods(decorator, *args, **kwargs): decorator. """ - def all_test_methods_impl(cls): + def all_test_methods_impl(cls: _TC) -> _TC: """Apply decorator to all test methods in class.""" for name in dir(cls): value = getattr(cls, name) @@ -2373,44 +2424,39 @@ def all_test_methods_impl(cls): # The description is just for documentation purposes. -def no_xla_auto_jit(description): # pylint: disable=unused-argument +def no_xla_auto_jit(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument """This test is not intended to be run with XLA auto jit enabled.""" execute_func = not is_xla_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def xla_allow_fallback(description): # pylint: disable=unused-argument +def xla_allow_fallback(description: str) -> Callable[[_F], _F]: # pylint: disable=unused-argument - def xla_allow_fallback_impl(func): + def xla_allow_fallback_impl(func: _F) -> _F: """Allow fallback to TF even though testing xla.""" - def decorator(func): - - def decorated(self, *args, **kwargs): - if is_xla_enabled(): - # Update the global XLABuildOpsPassFlags to enable lazy compilation, - # which allows the compiler to fall back to TF classic. Remember the - # old value so that we can reset it. - old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True) - result = func(self, *args, **kwargs) - pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value) - return result - else: - return func(self, *args, **kwargs) - - return decorated - - if func is not None: - return decorator(func) + def decorated(*args, **kwargs): + if is_xla_enabled(): + # Update the global XLABuildOpsPassFlags to enable lazy compilation, + # which allows the compiler to fall back to TF classic. Remember the + # old value so that we can reset it. + old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True) + result = func(*args, **kwargs) + pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value) + return result + else: + return func(*args, **kwargs) - return decorator + return tf_decorator.make_decorator(func, decorated) return xla_allow_fallback_impl # The description is just for documentation purposes. 
-def run_without_tensor_float_32(description): # pylint: disable=unused-argument +def run_without_tensor_float_32( + description: str, # pylint: disable=unused-argument +) -> Callable[[Callable[..., Any]], Callable[..., None]]: """Execute test with TensorFloat-32 disabled. While almost every real-world deep learning model runs fine with @@ -2426,24 +2472,24 @@ def run_without_tensor_float_32(description): # pylint: disable=unused-argument Decorator which runs a test with TensorFloat-32 disabled. """ - def decorator(f): + def decorator(f: Callable[..., Any]) -> Callable[..., None]: @functools.wraps(f) - def decorated(self, *args, **kwargs): + def decorated(*args, **kwargs): allowed = config.tensor_float_32_execution_enabled() try: config.enable_tensor_float_32_execution(False) - f(self, *args, **kwargs) + f(*args, **kwargs) finally: config.enable_tensor_float_32_execution(allowed) - return decorated + return tf_decorator.make_decorator(f, decorated) return decorator # The description is just for documentation purposes. -def run_all_without_tensor_float_32(description): # pylint: disable=unused-argument +def run_all_without_tensor_float_32(description: str) -> Callable[[_TC], _TC]: # pylint: disable=unused-argument """Execute all tests in a class with TensorFloat-32 disabled.""" return for_all_test_methods(run_without_tensor_float_32, description) @@ -2585,7 +2631,7 @@ def _ClearCachedSession(self): self._cached_session.close() self._cached_session = None - def get_temp_dir(self): + def get_temp_dir(self) -> str: """Returns a unique temporary directory for the test to use. If you call this method multiple times during in a test, it will return the @@ -2814,7 +2860,11 @@ def evaluate( # pylint: disable=redefined-outer-name @contextlib.contextmanager def session( - self, graph=None, config=None, use_gpu=True, force_gpu=False + self, + graph: Optional[ops.Graph] = None, + config: Optional[config_pb2.ConfigProto] = None, + use_gpu: bool = True, + force_gpu: bool = False, ) -> Iterator[s.Session]: """A context manager for a TensorFlow Session for use in executing tests. @@ -2859,11 +2909,13 @@ def testMyOperator(self): yield sess @contextlib.contextmanager - def cached_session(self, - graph=None, - config=None, - use_gpu=True, - force_gpu=False) -> Iterator[s.Session]: + def cached_session( + self, + graph: Optional[ops.Graph] = None, + config: Optional[config_pb2.ConfigProto] = None, + use_gpu: bool = True, + force_gpu: bool = False, + ) -> Iterator[s.Session]: """Returns a TensorFlow Session for use in executing tests. This method behaves differently than self.session(): for performance reasons @@ -2913,11 +2965,13 @@ def testMyOperator(self): @contextlib.contextmanager @deprecation.deprecated(None, "Use `self.session()` or " "`self.cached_session()` instead.") - def test_session(self, - graph=None, - config=None, - use_gpu=True, - force_gpu=False): + def test_session( + self, + graph: Optional[ops.Graph] = None, + config: Optional[config_pb2.ConfigProto] = None, + use_gpu: bool = True, + force_gpu: bool = False, + ) -> Iterator[s.Session]: """Use cached_session instead.""" if self.id().endswith(".test_session"): self.skipTest( @@ -2947,7 +3001,13 @@ class _CheckedThread(object): method. """ - def __init__(self, testcase, target, args=None, kwargs=None): + def __init__( + self, + testcase: "TensorFlowTestCase", + target: Callable[..., Any], + args: Optional[tuple[Any, ...]] = None, + kwargs: Optional[dict[str, Any]] = None, + ): """Constructs a new instance of _CheckedThread. 
Args: @@ -2959,21 +3019,21 @@ def __init__(self, testcase, target, args=None, kwargs=None): """ self._testcase = testcase self._target = target - self._args = () if args is None else args - self._kwargs = {} if kwargs is None else kwargs + self._args: tuple[Any, ...] = () if args is None else args + self._kwargs: dict[str, Any] = {} if kwargs is None else kwargs self._thread = threading.Thread(target=self._protected_run) self._exception = None self._is_thread_joined = False - def _protected_run(self): + def _protected_run(self) -> None: """Target for the wrapper thread. Sets self._exception on failure.""" try: self._target(*self._args, **self._kwargs) except Exception as e: # pylint: disable=broad-except self._exception = e - def start(self): + def start(self) -> None: """Starts the thread's activity. This must be called at most once per _CheckedThread object. It arranges @@ -2981,7 +3041,7 @@ def start(self): """ self._thread.start() - def join(self): + def join(self) -> None: """Blocks until the thread terminates. Raises: @@ -2993,7 +3053,7 @@ def join(self): if self._exception is not None: self._testcase.fail("Error in checkedThread: %s" % str(self._exception)) - def is_alive(self): + def is_alive(self) -> bool: """Returns whether the thread is alive. This method returns True just before the run() method starts @@ -3004,7 +3064,7 @@ def is_alive(self): """ return self._thread.is_alive() - def check_termination(self): + def check_termination(self) -> None: """Returns whether the checked thread was properly used and did terminate. Every checked thread should be "join"ed after starting, and before the @@ -3026,7 +3086,12 @@ def check_termination(self): else: self._testcase.fail("A checked thread was not joined.") - def checkedThread(self, target, args=None, kwargs=None): + def checkedThread( + self, + target: Callable[..., Any], + args: Optional[tuple[Any, ...]] = None, + kwargs: Optional[dict[str, Any]] = None, + ) -> _CheckedThread: """Returns a Thread wrapper that asserts 'target' completes successfully. 
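The _CheckedThread wrapper annotated above exists so that an exception raised on a worker thread fails the test instead of being lost. A small self-contained sketch of the same idea, independent of the TensorFlow test framework:

import threading
from typing import Any, Callable, Optional


class CheckedThread:
  """Runs `target` on a thread and re-raises its exception on join()."""

  def __init__(self, target: Callable[..., Any], *args: Any, **kwargs: Any):
    self._target = target
    self._args = args
    self._kwargs = kwargs
    self._exception: Optional[BaseException] = None
    self._thread = threading.Thread(target=self._protected_run)

  def _protected_run(self) -> None:
    try:
      self._target(*self._args, **self._kwargs)
    except BaseException as e:  # pylint: disable=broad-except
      self._exception = e

  def start(self) -> None:
    self._thread.start()

  def join(self) -> None:
    self._thread.join()
    if self._exception is not None:
      raise self._exception


t = CheckedThread(lambda x: x / 0, 1)
t.start()
try:
  t.join()
except ZeroDivisionError:
  print("worker failure surfaced on join")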
This method should be used to create all threads in test cases, as @@ -3648,8 +3713,13 @@ def assertRaisesWithPredicateMatch(self, exception_type, else: def predicate(e): - err_str = e.message if isinstance(e, errors.OpError) else str(e) - op = e.op if isinstance(e, errors.OpError) else None + if isinstance(e, errors.OpError): + e = cast(errors.OpError, e) + err_str = cast(str, e.message) + op = e.op + else: + err_str = str(e) + op = None while op is not None: err_str += "\nCaused by: " + op.name op = op._original_op # pylint: disable=protected-access @@ -3748,7 +3818,8 @@ def assertDictEqual(self, a, b, msg=None): def _GetPyList(self, a): """Converts `a` to a nested python list.""" if isinstance(a, ragged_tensor.RaggedTensor): - return self.evaluate(a).to_list() + a = cast(ragged_tensor_value.RaggedTensorValue, self.evaluate(a)) + return a.to_list() elif isinstance(a, tensor_lib.Tensor): a = self.evaluate(a) return a.tolist() if isinstance(a, np.ndarray) else a @@ -3802,7 +3873,9 @@ def _assertListCloseRecursive(self, a, b, rtol, atol, msg, path="value"): # pylint: enable=invalid-name @contextlib.contextmanager - def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu): + def _constrain_devices_and_set_default( + self, sess: s.Session, use_gpu: bool, force_gpu: bool, + ) -> Iterator[s.Session]: """Set the session and its graph to global default and constrain devices.""" if context.executing_eagerly(): yield None @@ -3822,10 +3895,17 @@ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu): with sess.graph.device("/device:CPU:0"): yield sess - def _create_session(self, graph, config, force_gpu): + def _create_session( + self, + graph: Optional[ops.Graph], + config: Optional[config_pb2.ConfigProto], + force_gpu: bool, + ) -> s.Session: """See session() for details.""" - def prepare_config(config): + def prepare_config( + config: Optional[config_pb2.ConfigProto], + ) -> config_pb2.ConfigProto: """Returns a config for sessions. Args: @@ -3861,11 +3941,13 @@ def prepare_config(config): return ErrorLoggingSession(graph=graph, config=prepare_config(config)) - def _get_cached_session(self, - graph=None, - config=None, - force_gpu=False, - crash_if_inconsistent_args=True): + def _get_cached_session( + self, + graph: Optional[ops.Graph] = None, + config: Optional[config_pb2.ConfigProto] = None, + force_gpu: bool = False, + crash_if_inconsistent_args: bool = True, + ) -> s.Session: """See cached_session() for documentation.""" if self._cached_session is None: sess = self._create_session( @@ -3896,7 +3978,7 @@ def _get_cached_session(self, return self._cached_session -ASSIGNED_PORTS = set() +ASSIGNED_PORTS: set[int] = set() lock = threading.Lock() @@ -3919,11 +4001,13 @@ def pick_unused_port(): @tf_export("test.create_local_cluster") -def create_local_cluster(num_workers, - num_ps, - protocol="grpc", - worker_config=None, - ps_config=None): +def create_local_cluster( + num_workers: int, + num_ps: int, + protocol: str = "grpc", + worker_config: Optional[config_pb2.ConfigProto] = None, + ps_config: Optional[config_pb2.ConfigProto] = None, +) -> tuple[list[server_lib.Server], list[server_lib.Server]]: """Create and start local servers and return the associated `Server` objects. 
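The session helpers above are generator functions wrapped in contextlib.contextmanager, which is why their return annotations are Iterator[...] rather than the yielded type itself. A generic sketch of that annotation pattern, using a made-up `scoped_value` helper:

import contextlib
from collections.abc import Iterator


@contextlib.contextmanager
def scoped_value(value: int) -> Iterator[int]:
  """Yields `value`; the generator itself is typed as Iterator[int]."""
  print("entering scope")
  try:
    yield value
  finally:
    print("leaving scope")


with scoped_value(42) as v:
  print(v)  # Type checkers see `v` as int.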
"PS" stands for "parameter server": a task responsible for storing and @@ -4006,7 +4090,9 @@ def create_local_cluster(num_workers, return workers, ps_servers -def get_node_def_from_graph(node_name, graph_def): +def get_node_def_from_graph( + node_name: str, graph_def: graph_pb2.GraphDef, +) -> Optional[node_def_pb2.NodeDef]: """Returns the `NodeDef` instance for given node name in the graph def. This method explores only the NodeDefs in `graph_def.node`. @@ -4024,7 +4110,7 @@ def get_node_def_from_graph(node_name, graph_def): return None -def set_producer_version(graph, producer_version): +def set_producer_version(graph: ops.Graph, producer_version: int) -> None: """Sets graph.graph_def_versions.producer to `producer_version`.""" # The C API doesn't expose altering GraphDefVersions. We can indirectly set # it via import_graph_def though. @@ -4089,7 +4175,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): @contextlib.contextmanager -def run_functions_eagerly(run_eagerly): +def run_functions_eagerly(run_eagerly: bool) -> Iterator[None]: """Runs functions eagerly if `run_eagerly` is true. WARNING: Setting `run_eagerly` to True in tests running in V1 graph mode @@ -4134,17 +4220,17 @@ def __init__(self, name, label): self.label = label self.Reset() - def Reset(self): + def Reset(self) -> None: self.last_value = _test_metrics_util.test_counter_value( self.name, self.label) - def Get(self): + def Get(self) -> int: value = _test_metrics_util.test_counter_value(self.name, self.label) return value - self.last_value @tf_export("test.experimental.sync_devices") -def sync_devices(): +def sync_devices() -> None: """Synchronizes all devices. By default, GPUs run asynchronously. This means that when you run an op on the diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 69680857a3b037..1407aa328b7056 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -1196,11 +1196,11 @@ def __init__(self, *args, **kwargs): self.accumulation = [] @unittest.expectedFailure - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def test_has_leak(self): self.accumulation.append([1.]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def test_has_no_leak(self): self.not_accumulating = [1.] diff --git a/tensorflow/python/framework/traceable_stack.py b/tensorflow/python/framework/traceable_stack.py index bce16048a24983..8a1fde77e6d506 100644 --- a/tensorflow/python/framework/traceable_stack.py +++ b/tensorflow/python/framework/traceable_stack.py @@ -14,21 +14,32 @@ # ============================================================================== """A simple stack that associates filename and line numbers with each object.""" +from collections.abc import Iterator import inspect +import types +from typing import cast, Generic, Optional, TypeVar -class TraceableObject(object): +T = TypeVar("T") + + +class TraceableObject(Generic[T]): """Wrap an object together with its the code definition location.""" # Return codes for the set_filename_and_line_from_caller() method. 
SUCCESS, HEURISTIC_USED, FAILURE = (0, 1, 2) - def __init__(self, obj, filename=None, lineno=None): + def __init__( + self, + obj: T, + filename: Optional[str] = None, + lineno: Optional[int] = None, + ): self.obj = obj self.filename = filename self.lineno = lineno - def set_filename_and_line_from_caller(self, offset=0): + def set_filename_and_line_from_caller(self, offset: int = 0) -> int: """Set filename and line using the caller's stack frame. If the requested stack information is not available, a heuristic may @@ -49,6 +60,9 @@ def set_filename_and_line_from_caller(self, offset=0): """ retcode = self.SUCCESS frame = inspect.currentframe() + if not frame: + return self.FAILURE + frame = cast(types.FrameType, frame) # Offset is defined in "Args" as relative to the caller. We are one frame # beyond the caller. for _ in range(offset + 1): @@ -57,9 +71,10 @@ def set_filename_and_line_from_caller(self, offset=0): # If the offset is too large then we use the largest offset possible. retcode = self.HEURISTIC_USED break + parent = cast(types.FrameType, parent) frame = parent self.filename = frame.f_code.co_filename - self.lineno = frame.f_lineno + self.lineno = cast(int, frame.f_lineno) return retcode def copy_metadata(self): @@ -67,19 +82,22 @@ def copy_metadata(self): return self.__class__(None, filename=self.filename, lineno=self.lineno) -class TraceableStack(object): +class TraceableStack(Generic[T]): """A stack of TraceableObjects.""" - def __init__(self, existing_stack=None): + def __init__( + self, existing_stack: Optional[list[TraceableObject[T]]] = None, + ): """Constructor. Args: existing_stack: [TraceableObject, ...] If provided, this object will set its new stack to a SHALLOW COPY of existing_stack. """ - self._stack = existing_stack[:] if existing_stack else [] + self._stack: list[TraceableObject[T]] = (existing_stack[:] if existing_stack + else []) - def push_obj(self, obj, offset=0): + def push_obj(self, obj: T, offset: int = 0): """Add object to the stack and record its filename and line information. Args: @@ -98,27 +116,27 @@ def push_obj(self, obj, offset=0): # beyond the caller and need to compensate. return traceable_obj.set_filename_and_line_from_caller(offset + 1) - def pop_obj(self): + def pop_obj(self) -> T: """Remove last-inserted object and return it, without filename/line info.""" return self._stack.pop().obj - def peek_top_obj(self): + def peek_top_obj(self) -> T: """Return the most recent stored object.""" return self._stack[-1].obj - def peek_objs(self): + def peek_objs(self) -> Iterator[T]: """Return iterator over stored objects ordered newest to oldest.""" return (t_obj.obj for t_obj in reversed(self._stack)) - def peek_traceable_objs(self): + def peek_traceable_objs(self) -> Iterator[TraceableObject[T]]: """Return iterator over stored TraceableObjects ordered newest to oldest.""" return reversed(self._stack) - def __len__(self): + def __len__(self) -> int: """Return number of items on the stack, and used for truth-value testing.""" return len(self._stack) - def copy(self): + def copy(self) -> "TraceableStack[T]": """Return a copy of self referencing the same objects but in a new list. This method is implemented to support thread-local stacks. 
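The Generic[T] parameterisation added to TraceableObject and TraceableStack above lets the element type flow through push/pop for type checkers. A minimal standalone generic stack showing the same typing idea (all names here are illustrative; assumes Python 3.9+):

from typing import Generic, Optional, TypeVar

T = TypeVar("T")


class Entry(Generic[T]):
  """Pairs a value with the source location it was recorded from."""

  def __init__(self, obj: T, filename: Optional[str] = None,
               lineno: Optional[int] = None):
    self.obj = obj
    self.filename = filename
    self.lineno = lineno


class Stack(Generic[T]):

  def __init__(self) -> None:
    self._stack: list[Entry[T]] = []

  def push(self, obj: T) -> None:
    self._stack.append(Entry(obj))

  def pop(self) -> T:
    return self._stack.pop().obj


s: "Stack[str]" = Stack()
s.push("/device:CPU:0")
print(s.pop())  # Type checkers infer `str` here.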
diff --git a/tensorflow/python/framework/type_spec.py b/tensorflow/python/framework/type_spec.py index 26911cdd97bc80..beb278b1624f2c 100644 --- a/tensorflow/python/framework/type_spec.py +++ b/tensorflow/python/framework/type_spec.py @@ -33,7 +33,6 @@ from tensorflow.python.types import core as core_types from tensorflow.python.types import internal from tensorflow.python.types import trace -from tensorflow.python.util import _pywrap_utils from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import nest @@ -1057,6 +1056,3 @@ def register_type_spec_from_value_converter(type_object, _, type_object = tf_decorator.unwrap(type_object) _TYPE_CONVERSION_FUNCTION_REGISTRY.append( (type_object, converter_fn, allow_subclass)) - - -_pywrap_utils.RegisterType("TypeSpec", TypeSpec) diff --git a/tensorflow/python/grappler/BUILD b/tensorflow/python/grappler/BUILD index 0a19f8fbcf89c3..366ebfa1927674 100644 --- a/tensorflow/python/grappler/BUILD +++ b/tensorflow/python/grappler/BUILD @@ -227,7 +227,7 @@ cuda_py_strict_test( size = "small", srcs = ["cluster_test.py"], python_version = "PY3", - shard_count = 10, + shard_count = 5, tags = [ "grappler", "no_pip", # tf_optimizer is not available in pip. diff --git a/tensorflow/python/grappler/remapper_test.py b/tensorflow/python/grappler/remapper_test.py index 6d693431f60ea4..91f283c5969792 100644 --- a/tensorflow/python/grappler/remapper_test.py +++ b/tensorflow/python/grappler/remapper_test.py @@ -227,6 +227,8 @@ def test_conv2d_biasadd_act_fusion(self): """Test Conv2D+BiasAdd+Relu fusion.""" if not test_util.is_gpu_available(): self.skipTest('No GPU available') + if test.is_built_with_rocm(): + self.skipTest('ROCm does not support conv biasadd fusion') N, H, W, C = (5, 3, 3, 8) # pylint: disable=invalid-name # The runtime fusion requires the output dims to be 32-bit aligned. 
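The remapper_test change just above adds an early skip when the build cannot exercise the fusion under test. A small unittest-only sketch of that guard style; `gpu_available` and `built_with_rocm` are stand-in predicates rather than the TensorFlow ones:

import unittest


def gpu_available() -> bool:
  return False  # Stand-in for a real device query.


def built_with_rocm() -> bool:
  return False  # Stand-in for a real build-configuration query.


class FusionTest(unittest.TestCase):

  def test_conv_biasadd_fusion(self):
    if not gpu_available():
      self.skipTest("No GPU available")
    if built_with_rocm():
      self.skipTest("ROCm does not support conv biasadd fusion")
    self.assertEqual(2 + 2, 4)  # Placeholder for the real fused-op check.


if __name__ == "__main__":
  unittest.main()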
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index c9be0a3cc5ba10..fe1a5022c5296f 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -52,6 +52,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import standard_ops +from tensorflow.python.ops import tensor_getitem_override from tensorflow.python.ops import variable_scope from tensorflow.python.ops.ragged import ragged_getitem from tensorflow.python.ops.ragged import ragged_tensor @@ -1559,7 +1560,7 @@ def handle(self, args, kwargs): return self.NOT_SUPPORTED for slicing_op in [ - array_ops._slice_helper, # pylint: disable=protected-access + tensor_getitem_override._slice_helper, # pylint: disable=protected-access array_ops.boolean_mask, array_ops.boolean_mask_v2, ragged_getitem.ragged_tensor_getitem diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD index 4852a3c1768527..80cb2b53072a28 100644 --- a/tensorflow/python/kernel_tests/array_ops/BUILD +++ b/tensorflow/python/kernel_tests/array_ops/BUILD @@ -469,9 +469,11 @@ cuda_py_strict_test( cuda_py_strict_test( name = "manip_ops_test", - size = "small", + size = "medium", srcs = ["manip_ops_test.py"], - tags = ["no_windows_gpu"], + tags = [ + "no_windows_gpu", + ], deps = [ "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:errors", diff --git a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py index 0a3e6a2eb29b0c..f2cd5d0fd2afef 100644 --- a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py @@ -47,6 +47,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import tensor_getitem_override from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_v1 from tensorflow.python.ops import variables @@ -688,7 +689,7 @@ def testInt64GPU(self): s = array_ops.strided_slice(x, begin, end, strides) self.assertAllEqual([3.], self.evaluate(s)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() @test_util.assert_no_garbage_created def testTensorSliceEagerMemory(self): with context.eager_mode(): @@ -697,7 +698,7 @@ def testTensorSliceEagerMemory(self): # Tests that slicing an EagerTensor doesn't leak memory inputs[0] # pylint: disable=pointless-statement - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() @test_util.assert_no_garbage_created def testVariableSliceEagerMemory(self): if sys.version_info.major == 3 and sys.version_info.minor in (11, 12): @@ -788,7 +789,7 @@ def testTensorIndexing(self): def testTensorIndexingTypeError(self): with self.session(): checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR) - expected = re.escape(array_ops._SLICE_TYPE_ERROR) + expected = re.escape(tensor_getitem_override._SLICE_TYPE_ERROR) with self.assertRaisesRegex(TypeError, expected): _ = checker["foo"] with self.assertRaisesRegex(TypeError, expected): diff --git a/tensorflow/python/kernel_tests/array_ops/constant_op_test.py b/tensorflow/python/kernel_tests/array_ops/constant_op_test.py index 5fb4fb659d8f19..55cb3e049c0b32 100644 --- 
a/tensorflow/python/kernel_tests/array_ops/constant_op_test.py +++ b/tensorflow/python/kernel_tests/array_ops/constant_op_test.py @@ -208,7 +208,7 @@ def testExplicitShapeNumPy(self): shape=[2, 3, 5]) self.assertEqual(c.get_shape(), [2, 3, 5]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testEagerMemory(self): """Tests PyObject refs are managed correctly when executing eagerly.""" constant_op.constant([[1.]]) diff --git a/tensorflow/python/kernel_tests/array_ops/manip_ops_test.py b/tensorflow/python/kernel_tests/array_ops/manip_ops_test.py index 65291165da8827..35e2c3c0f86e36 100644 --- a/tensorflow/python/kernel_tests/array_ops/manip_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/manip_ops_test.py @@ -105,11 +105,25 @@ def testEmptyInput(self): self._testAll(np.zeros([0, 1]), 1, 1) self._testAll(np.zeros([1, 0]), 1, 1) + @test_util.run_v2_only + def testLargeInput(self): + with test_util.force_cpu(): + # Num elements just over INT_MAX for int32 to ensure no overflow + np_input = np.arange(0, 128 * 524289 * 33, dtype=np.int8).reshape( + 128, -1, 33 + ) + + for shift in range(-5, 5): + roll = manip_ops.roll(np_input, shift, 0) + self.assertAllEqual(roll[shift], np_input[0], msg=f"shift={shift}") + self.assertAllEqual(roll[0], np_input[-shift], msg=f"shift={shift}") + @test_util.run_deprecated_v1 def testInvalidInputShape(self): # The input should be 1-D or higher, checked in shape function. - with self.assertRaisesRegex(ValueError, - "Shape must be at least rank 1 but is rank 0"): + with self.assertRaisesRegex( + ValueError, "Shape must be at least rank 1 but is rank 0" + ): manip_ops.roll(7, 1, 0) @test_util.run_deprecated_v1 diff --git a/tensorflow/python/kernel_tests/image_ops/BUILD b/tensorflow/python/kernel_tests/image_ops/BUILD index 93bef64e928122..c63a8bccd5d5e9 100644 --- a/tensorflow/python/kernel_tests/image_ops/BUILD +++ b/tensorflow/python/kernel_tests/image_ops/BUILD @@ -1,6 +1,10 @@ # Tests of TensorFlow image kernels written using the Python API. load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test") +load( + "//tensorflow/tools/test:performance.bzl", + "tf_py_benchmark_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -64,7 +68,7 @@ tf_py_strict_test( ], ) -tf_py_strict_test( +tf_py_benchmark_test( name = "decode_jpeg_op_test", srcs = ["decode_jpeg_op_test.py"], data = ["//tensorflow/core:image_testdata"], diff --git a/tensorflow/python/kernel_tests/image_ops/draw_bounding_box_op_test.py b/tensorflow/python/kernel_tests/image_ops/draw_bounding_box_op_test.py index a66d8d8a9a2a13..f7641c63e7f7e3 100644 --- a/tensorflow/python/kernel_tests/image_ops/draw_bounding_box_op_test.py +++ b/tensorflow/python/kernel_tests/image_ops/draw_bounding_box_op_test.py @@ -135,7 +135,7 @@ def testDrawBoundingBoxHalf(self): image, dtype=dtypes.half, colors=colors) # generate_bound_box_proposals is only available on GPU. - @test_util.run_gpu_only() + @test_util.run_gpu_only def testGenerateBoundingBoxProposals(self): # Op only exists on GPU. 
with self.cached_session(use_gpu=True): diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 82dcd51214819e..936e4204d3ace8 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -256,7 +256,7 @@ cuda_py_strict_test( name = "linear_operator_circulant_test", size = "medium", srcs = ["linear_operator_circulant_test.py"], - shard_count = 32, + shard_count = 50, tags = [ "no_cuda11", # TODO(b/197522782): reenable test after fixing. "optonly", # times out, b/79171797 @@ -412,7 +412,7 @@ cuda_py_strict_test( name = "linear_operator_low_rank_update_test", size = "medium", srcs = ["linear_operator_low_rank_update_test.py"], - shard_count = 10, + shard_count = 15, tags = ["optonly"], deps = [ "//tensorflow/python/framework:config", @@ -516,12 +516,14 @@ cuda_py_strict_test( name = "linear_operator_tridiag_test", size = "medium", srcs = ["linear_operator_tridiag_test.py"], - shard_count = 5, + shard_count = 10, tags = [ "no_windows_gpu", "optonly", ], - xla_enable_strict_auto_jit = True, + # TODO(b/313470344): XLA temporarily disabled due to empty shards on 3.12. + xla_enable_strict_auto_jit = False, + xla_enabled = False, deps = [ "//tensorflow/python/framework:config", "//tensorflow/python/framework:test_lib", @@ -881,7 +883,7 @@ cuda_py_strict_test( name = "tridiagonal_matmul_op_test", size = "medium", srcs = ["tridiagonal_matmul_op_test.py"], - shard_count = 10, + shard_count = 5, deps = [ "//tensorflow/python/client:session", "//tensorflow/python/eager:context", diff --git a/tensorflow/python/kernel_tests/linalg/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg/linalg_grad_test.py index 3f37a3585101d1..f478a85cd63df8 100644 --- a/tensorflow/python/kernel_tests/linalg/linalg_grad_test.py +++ b/tensorflow/python/kernel_tests/linalg/linalg_grad_test.py @@ -240,7 +240,7 @@ def Test(self): lambda x: linalg_ops.matrix_inverse(x, adjoint=True), dtype, shape)) - if not test_lib.is_built_with_rocm(): + if True: # not test_lib.is_built_with_rocm(): # TODO(rocm) : # re-enable this test when upstream issues are resolved # see commit msg for details diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py index ee84171e67ca81..ddd6879ba020a0 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py @@ -286,7 +286,8 @@ def dtypes_to_test(): def optional_tests(): """List of optional test names to run.""" return [ - "operator_matmul_with_same_type", + # TODO: b/310008894 - Re-enable this optional test. + # "operator_matmul_with_same_type", "operator_solve_with_same_type", ] @@ -371,7 +372,8 @@ def setUp(self): def optional_tests(): """List of optional test names to run.""" return [ - "operator_matmul_with_same_type", + # TODO: b/310008894 - Re-enable this optional test. + # "operator_matmul_with_same_type", "operator_solve_with_same_type", ] @@ -445,7 +447,8 @@ def skip_these_tests(): def optional_tests(): """List of optional test names to run.""" return [ - "operator_matmul_with_same_type", + # TODO: b/310008894 - Re-enable this optional test. 
+ # "operator_matmul_with_same_type", "operator_solve_with_same_type", ] @@ -649,7 +652,8 @@ def operator_shapes_infos(): def optional_tests(): """List of optional test names to run.""" return [ - "operator_matmul_with_same_type", + # TODO: b/310008894 - Re-enable this optional test. + # "operator_matmul_with_same_type", "operator_solve_with_same_type", ] diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_inversion_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_inversion_test.py index ca5cb6e0d1f7a9..bcb1360ea6b2eb 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_inversion_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_inversion_test.py @@ -33,6 +33,13 @@ class LinearOperatorInversionTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" + # TODO: b/311343496 - Re-enable this test. + @staticmethod + def skip_these_tests() -> list[str]: + return [ + "test_saved_model", + ] + def tearDown(self): config.enable_tensor_float_32_execution(self.tf32_keep_) diff --git a/tensorflow/python/kernel_tests/linalg/matrix_square_root_op_test.py b/tensorflow/python/kernel_tests/linalg/matrix_square_root_op_test.py index 73d9d9263e9262..28511c609a9d8f 100644 --- a/tensorflow/python/kernel_tests/linalg/matrix_square_root_op_test.py +++ b/tensorflow/python/kernel_tests/linalg/matrix_square_root_op_test.py @@ -25,7 +25,6 @@ from tensorflow.python.platform import test -@test_util.run_all_without_tensor_float_32 class SquareRootOpTest(test.TestCase): def _verifySquareRoot(self, matrix, np_type): @@ -65,16 +64,19 @@ def _testMatrices(self, matrix1, matrix2): self._verifySquareRootComplex(matrix2) self._verifySquareRootComplex(self._makeBatch(matrix1, matrix2)) + @test_util.run_without_tensor_float_32 def testSymmetricPositiveDefinite(self): matrix1 = np.array([[2., 1.], [1., 2.]]) matrix2 = np.array([[3., -1.], [-1., 3.]]) self._testMatrices(matrix1, matrix2) + @test_util.run_without_tensor_float_32 def testAsymmetric(self): matrix1 = np.array([[0., 4.], [-1., 5.]]) matrix2 = np.array([[33., 24.], [48., 57.]]) self._testMatrices(matrix1, matrix2) + @test_util.run_without_tensor_float_32 def testIdentityMatrix(self): # 2x2 identity = np.array([[1., 0], [0, 1.]]) @@ -83,11 +85,13 @@ def testIdentityMatrix(self): identity = np.array([[1., 0, 0], [0, 1., 0], [0, 0, 1.]]) self._verifySquareRootReal(identity) + @test_util.run_without_tensor_float_32 def testEmpty(self): self._verifySquareRootReal(np.empty([0, 2, 2])) self._verifySquareRootReal(np.empty([2, 0, 0])) @test_util.run_in_graph_and_eager_modes(use_gpu=True) + @test_util.run_without_tensor_float_32 def testWrongDimensions(self): # The input to the square root should be at least a 2-dimensional tensor. 
tensor = constant_op.constant([1., 2.]) @@ -95,12 +99,14 @@ def testWrongDimensions(self): gen_linalg_ops.matrix_square_root(tensor) @test_util.run_in_graph_and_eager_modes(use_gpu=True) + @test_util.run_without_tensor_float_32 def testNotSquare(self): with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): tensor = constant_op.constant([[1., 0., -1.], [-1., 1., 0.]]) self.evaluate(gen_linalg_ops.matrix_square_root(tensor)) @test_util.run_in_graph_and_eager_modes(use_gpu=True) + @test_util.run_without_tensor_float_32 def testConcurrentExecutesWithoutError(self): matrix_shape = [5, 5] seed = [42, 24] diff --git a/tensorflow/python/kernel_tests/linalg/sparse/BUILD b/tensorflow/python/kernel_tests/linalg/sparse/BUILD index 9463e04bf8f0bc..d4d8d65195db00 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/BUILD +++ b/tensorflow/python/kernel_tests/linalg/sparse/BUILD @@ -85,7 +85,7 @@ cuda_py_strict_test( size = "medium", srcs = ["csr_sparse_matrix_grad_test.py"], main = "csr_sparse_matrix_grad_test.py", - shard_count = 50, + shard_count = 3, deps = [ "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", @@ -126,7 +126,7 @@ cuda_py_strict_test( size = "medium", srcs = ["csr_sparse_matrix_sparse_mat_mul_grad_test.py"], main = "csr_sparse_matrix_sparse_mat_mul_grad_test.py", - shard_count = 50, + shard_count = 10, deps = [ "//tensorflow/python/framework:ops", "//tensorflow/python/framework:test_lib", diff --git a/tensorflow/python/kernel_tests/math_ops/BUILD b/tensorflow/python/kernel_tests/math_ops/BUILD index 0903a13db22c46..237e39435036ef 100644 --- a/tensorflow/python/kernel_tests/math_ops/BUILD +++ b/tensorflow/python/kernel_tests/math_ops/BUILD @@ -1,6 +1,10 @@ # Tests of TensorFlow math kernels written using the Python API. 
load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test") +load( + "//tensorflow/tools/test:performance.bzl", + "cuda_py_benchmark_test", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -283,7 +287,7 @@ cuda_py_strict_test( name = "cwise_ops_unary_test", size = "medium", srcs = ["cwise_ops_unary_test.py"], - shard_count = 50, + shard_count = 10, tags = [ "no_windows", # TODO(b/207048097): re-enable ], @@ -369,7 +373,7 @@ cuda_py_strict_test( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "reduce_benchmark_test", srcs = ["reduce_benchmark_test.py"], deps = [ diff --git a/tensorflow/python/kernel_tests/math_ops/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/math_ops/cwise_ops_binary_test.py index 8a1d14be8417ee..b0cf1d0058c1a5 100644 --- a/tensorflow/python/kernel_tests/math_ops/cwise_ops_binary_test.py +++ b/tensorflow/python/kernel_tests/math_ops/cwise_ops_binary_test.py @@ -883,7 +883,10 @@ def testPowNegativeExponentGpu(self): z = math_ops.pow(x, y) self.assertAllEqual(self.evaluate(z), [0, 1, 1, 1, -1]) - def testFloorModInfDenominator(self): + @test.disable_with_predicate( + pred=test.is_built_with_rocm, skip_message="On ROCm this test fails" + ) + def testFloorModfInfDenominator(self): """Regression test for GitHub issue #58369.""" if not test_util.is_gpu_available(): self.skipTest("Requires GPU") diff --git a/tensorflow/python/kernel_tests/math_ops/cwise_ops_unary_test.py b/tensorflow/python/kernel_tests/math_ops/cwise_ops_unary_test.py index 29daaea0b1643a..24c06cedce2443 100644 --- a/tensorflow/python/kernel_tests/math_ops/cwise_ops_unary_test.py +++ b/tensorflow/python/kernel_tests/math_ops/cwise_ops_unary_test.py @@ -445,8 +445,6 @@ def f(x): self._compareBoth(x, compute_f32(np.vectorize(math.erfc)), math_ops.erfc) self._compareBoth(x, compute_f32(np.square), math_ops.square) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="On ROCm this test fails") def testInt8Basic(self): x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int8) self._compareCpu(x, np.abs, math_ops.abs) @@ -455,14 +453,10 @@ def testInt8Basic(self): self._compareBoth(x, np.negative, _NEG) self._compareBoth(x, np.sign, math_ops.sign) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="On ROCm this test fails") def testUInt8Basic(self): x = np.arange(6).reshape(1, 3, 2).astype(np.uint8) self._compareBoth(x, np.square, math_ops.square) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="On ROCm this test fails") def testInt16Basic(self): x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int16) self._compareCpu(x, np.abs, math_ops.abs) @@ -471,8 +465,6 @@ def testInt16Basic(self): self._compareBoth(x, np.negative, _NEG) self._compareBoth(x, np.sign, math_ops.sign) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="On ROCm this test fails") def testUInt16Basic(self): x = np.arange(6).reshape(1, 3, 2).astype(np.uint16) self._compareBoth(x, np.square, math_ops.square) @@ -491,8 +483,6 @@ def testInt32Basic(self): self._compareBothSparse(x, np.square, math_ops.square) self._compareBothSparse(x, np.sign, math_ops.sign) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="On ROCm this test fails") def testUInt32Basic(self): x = np.arange(6).reshape(1, 3, 2).astype(np.uint32) self._compareBoth(x, np.square, math_ops.square) @@ -514,8 +504,6 @@ def testInt64Square(self): self._compareCpu(x, 
np.square, math_ops.square) self._compareBothSparse(x, np.square, math_ops.square) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="On ROCm this test fails") def testUInt64Basic(self): x = np.arange(6).reshape(1, 3, 2).astype(np.uint64) self._compareBoth(x, np.square, math_ops.square) diff --git a/tensorflow/python/kernel_tests/math_ops/segment_reduction_ops_d9m_test.py b/tensorflow/python/kernel_tests/math_ops/segment_reduction_ops_d9m_test.py index fbd5f9501c0933..3c166b86fabc74 100644 --- a/tensorflow/python/kernel_tests/math_ops/segment_reduction_ops_d9m_test.py +++ b/tensorflow/python/kernel_tests/math_ops/segment_reduction_ops_d9m_test.py @@ -89,9 +89,6 @@ def testUnsortedOps(self): result = op(data, segment_ids, num_segments) self.evaluate(result) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, - skip_message="No ROCm support for complex types in segment reduction ops") @test_util.run_cuda_only def testUnsortedOpsComplex(self): for op in [ diff --git a/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py b/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py index f22ddc973cd045..71c2f3a208f122 100644 --- a/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py @@ -207,11 +207,11 @@ def _DtypesToTest(self, use_gpu): if use_gpu: # It is important that float32 comes first, since we are using its # gradients as a reference for fp16 gradients. - out = [dtypes.float32] + out = [dtypes.float32, dtypes.bfloat16] if test_util.GpuSupportsHalfMatMulAndConv(): out.append(dtypes.float16) if not test.is_built_with_rocm(): - out.extend([dtypes.float64, dtypes.bfloat16]) + out.extend([dtypes.float64]) return out return [dtypes.float32, dtypes.float64, dtypes.float16, dtypes.bfloat16] @@ -460,7 +460,7 @@ def _VerifyDilatedConvValuesParameters( op_name, rtol=1e-4, ): - if use_gpu and not test.is_gpu_available(cuda_only=True): + if use_gpu and not test.is_gpu_available(): self.skipTest("GPU not available") expected_results = [] computed_results = [] @@ -520,7 +520,7 @@ def _VerifyValues(self, gpu_only=False, test_grappler_layout_optimizer=False, tol=1e-5): - if gpu_only and not test.is_gpu_available(cuda_only=True): + if gpu_only and not test.is_gpu_available(): return tensors = [] dilations = list(dilations) @@ -577,7 +577,7 @@ def _VerifyValuesParameters( test_grappler_layout_optimizer=False, tol=1e-5, ): - if (gpu_only and not use_gpu) or not test.is_gpu_available(cuda_only=True): + if (gpu_only and not use_gpu) or not test.is_gpu_available(): self.skipTest("GPU not available") if ( test_grappler_layout_optimizer or data_format != "NHWC" @@ -1330,8 +1330,12 @@ def MakeConv2d(inputs, filters): results[0], results[1], atol=tol_to_use, rtol=tol_to_use) @test_util.run_in_graph_and_eager_modes + @test.disable_with_predicate( + pred=test.is_built_with_rocm, + skip_message="MIOpen does not support group conv yet!", + ) def testConv2DGroupConvFwd(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): data_formats = ["NHWC", "NCHW"] else: data_formats = ["NHWC"] @@ -1347,7 +1351,11 @@ def testConv2DGroupConvFwd(self): dtype=dtypes.float32) @test_util.deprecated_graph_mode_only - @test_util.run_cuda_only + @test_util.run_gpu_only + @test.disable_with_predicate( + pred=test.is_built_with_rocm, + skip_message="MIOpen does not support group conv yet!", + ) def testInputGradientGroupConv(self): for data_format in ["NCHW", 
"NHWC"]: for test_input in [True, False]: @@ -1369,7 +1377,11 @@ def testInputGradientGroupConv(self): max_err=0.005) @test_util.deprecated_graph_mode_only - @test_util.run_cuda_only + @test_util.run_gpu_only + @test.disable_with_predicate( + pred=test.is_built_with_rocm, + skip_message="MIOpen does not support group conv yet!", + ) def testFilterGradientGroupConv(self): for data_format in ["NCHW", "NHWC"]: for test_input in [True, False]: @@ -1407,7 +1419,7 @@ def _RunAndVerifyBackpropInput(self, use_gpu, err, dilations=(1, 1)): - if use_gpu and not test.is_gpu_available(cuda_only=True): + if use_gpu and not test.is_gpu_available(): return x1 = self._CreateNumpyTensor(filter_sizes) x2 = self._CreateNumpyTensor(output_sizes) @@ -1893,7 +1905,7 @@ def _RunAndVerifyBackpropFilterDilation(self, input_sizes, filter_sizes, @test_util.deprecated_graph_mode_only def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 3, 6, 1], @@ -1908,7 +1920,7 @@ def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self): @test_util.deprecated_graph_mode_only def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 2, 3, 1], @@ -1923,7 +1935,7 @@ def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self): @test_util.deprecated_graph_mode_only def testConv2DEmptyBackpropFilterDilation1x2(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 2, 3, 1], @@ -1938,7 +1950,7 @@ def testConv2DEmptyBackpropFilterDilation1x2(self): @test_util.deprecated_graph_mode_only def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 3, 4, 3], @@ -1953,7 +1965,7 @@ def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self): @test_util.deprecated_graph_mode_only def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 3, 3, 1], @@ -1968,7 +1980,7 @@ def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self): @test_util.deprecated_graph_mode_only def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropInputDilation( input_sizes=[1, 3, 6, 1], @@ -1983,7 +1995,7 @@ def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self): @test_util.deprecated_graph_mode_only def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self): - if 
test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropInputDilation( input_sizes=[1, 2, 3, 1], @@ -1998,7 +2010,7 @@ def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self): @test_util.deprecated_graph_mode_only def testConv2DEmptyBackpropInputDilation1x2(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropInputDilation( input_sizes=[0, 2, 3, 1], @@ -2013,7 +2025,7 @@ def testConv2DEmptyBackpropInputDilation1x2(self): @test_util.deprecated_graph_mode_only def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): # The GPU version of this test is not very stable. So adjusting the # error threshold to 1e-4. @@ -2030,7 +2042,7 @@ def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self): @test_util.deprecated_graph_mode_only def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropInputDilation( input_sizes=[1, 3, 3, 1], @@ -2053,7 +2065,7 @@ def _RunAndVerifyBackpropInputExplicitPadding(self, use_gpu, dilations=(1, 1), err=2e-5): - if use_gpu and not test.is_gpu_available(cuda_only=True): + if use_gpu and not test.is_gpu_available(): return if not use_gpu and dilations != (1, 1): return # Non-default dilations is currently not supported on the CPU. @@ -2215,7 +2227,7 @@ def _RunAndVerifyBackpropFilterExplicitPadding(self, use_gpu, dilations=(1, 1), err=1e-5): - if use_gpu and not test.is_gpu_available(cuda_only=True): + if use_gpu and not test.is_gpu_available(): return if not use_gpu and dilations != (1, 1): return # Non-default dilations is currently not supported on the CPU. @@ -3513,7 +3525,6 @@ def testConv2D3x3FilterStride1x1Valid(self): def testConv2D3x3FilterStride1x1Same(self): self._RunTestCases([1, 1], "SAME") - class Conv2DBenchmark(test.Benchmark): def benchmarkGPUConvStackFirst(self): diff --git a/tensorflow/python/kernel_tests/nn_ops/depthwise_conv_op_base.py b/tensorflow/python/kernel_tests/nn_ops/depthwise_conv_op_base.py index a9f63ad6ce9a94..4e466a0a1c876a 100644 --- a/tensorflow/python/kernel_tests/nn_ops/depthwise_conv_op_base.py +++ b/tensorflow/python/kernel_tests/nn_ops/depthwise_conv_op_base.py @@ -407,7 +407,7 @@ def _VerifyValues(self, interface_result, np_result, atol=tolerance, rtol=tolerance) @test_util.run_v1_only("b/120545219") - @test_util.run_cuda_only + @test_util.run_gpu_only def testDepthwiseConv2DCudnn(self): for index, (input_size, filter_size, _, stride, padding, dilations) in enumerate(ConfigsToTest()): @@ -510,10 +510,10 @@ def testDepthwiseConv2DExplicit(self): "Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: " "%s", index, input_size, filter_size, stride, padding) # double datatype is currently not supported for convolution ops - # on the ROCm platform and its support for bfloat16 is unknown. 
- data_types = [dtypes.float16, dtypes.float32] + # on the ROCm platform + data_types = [dtypes.float16, dtypes.float32, dtypes.bfloat16] if not test.is_built_with_rocm(): - data_types.extend([dtypes.float64, dtypes.bfloat16]) + data_types.extend([dtypes.float64]) data_formats = ["NHWC", "NCHW"] if test.is_gpu_available() else ["NHWC"] for data_type in data_types: for data_format in data_formats: @@ -529,8 +529,7 @@ def testDepthwiseConv2DExplicit(self): dilations=dilations, tolerance=tolerance) - -# This is testing against hand calculated results. + # This is testing against hand calculated results. def _VerifyHandValues(self, tensor_in_sizes, filter_in_sizes, stride, padding, expected, use_gpu): @@ -736,7 +735,7 @@ def _ConstructAndTestGradient(self, self.assertLess(err, tolerance) @test_util.run_v1_only("b/120545219") - @test_util.run_cuda_only + @test_util.run_gpu_only def testDepthwiseConv2DInputGradCudnn(self): for index, (input_size, filter_size, output_size, stride, padding, dilations) in enumerate(CheckGradConfigsToTest()): @@ -832,10 +831,10 @@ def testDepthwiseConv2DInputGradExplicit(self): "stride: %d, padding: %s", index, input_size, filter_size, stride, padding) # double datatype is currently not supported for convolution ops - # on the ROCm platform and its support for bfloat16 is unknown. - data_types = [dtypes.float16, dtypes.float32] + # on the ROCm platform + data_types = [dtypes.float16, dtypes.float32, dtypes.bfloat16] if not test.is_built_with_rocm(): - data_types.extend([dtypes.float64, dtypes.bfloat16]) + data_types.extend([dtypes.float64]) data_formats = ["NHWC", "NCHW"] if test.is_gpu_available() else ["NHWC"] for data_type in data_types: for data_format in data_formats: @@ -852,7 +851,7 @@ def testDepthwiseConv2DInputGradExplicit(self): dilations=dilations) @test_util.run_v1_only("b/120545219") - @test_util.run_cuda_only + @test_util.run_gpu_only def testDepthwiseConv2DFilterGradCudnn(self): for index, (input_size, filter_size, output_size, stride, padding, dilations) in enumerate(CheckGradConfigsToTest()): @@ -945,10 +944,10 @@ def testDepthwiseConv2DFilterGradExplicit(self): "stride: %d, padding: %s", index, input_size, filter_size, stride, padding) # double datatype is currently not supported for convolution ops - # on the ROCm platform and its support for bfloat16 is unknown. - data_types = [dtypes.float16, dtypes.float32] + # on the ROCm platform + data_types = [dtypes.float16, dtypes.float32, dtypes.bfloat16] if not test.is_built_with_rocm(): - data_types.extend([dtypes.float64, dtypes.bfloat16]) + data_types.extend([dtypes.float64]) data_formats = ["NHWC", "NCHW"] if test.is_gpu_available() else ["NHWC"] for data_type in data_types: for data_format in data_formats: @@ -999,14 +998,14 @@ def testDepthwiseConv2DInputGradCompare(self): padding) self._CompareBackpropInput(input_size, filter_size, output_size, stride, padding, "float32") - # Convolutions on the ROCm platform don't support double dtype. And its - # support for bf16 is unknown. So, we skip these tests. - if test.is_built_with_rocm(): - continue - self._CompareBackpropInput(input_size, filter_size, output_size, stride, - padding, "float64") self._CompareBackpropInput(input_size, filter_size, output_size, stride, padding, "bfloat16") + # Convolutions on the ROCm platform don't support double dtype. + # So, we skip these tests. 
+ if not test.is_built_with_rocm(): + self._CompareBackpropInput( + input_size, filter_size, output_size, stride, padding, "float64" + ) @test_util.run_gpu_only def testDepthwiseConv2DInputGradExplicitCompare(self): @@ -1020,14 +1019,13 @@ def testDepthwiseConv2DInputGradExplicitCompare(self): padding) self._CompareBackpropInput(input_size, filter_size, output_size, stride, padding, "float32") - # Convolutions on the ROCm platform don't support double dtype. And its - # support for bf16 is unknown. So, we skip these tests. - if test.is_built_with_rocm(): - continue - self._CompareBackpropInput(input_size, filter_size, output_size, stride, - padding, "float64") self._CompareBackpropInput(input_size, filter_size, output_size, stride, padding, "bfloat16") + # Convolutions on the ROCm platform don't support double dtype. + if not test.is_built_with_rocm(): + self._CompareBackpropInput( + input_size, filter_size, output_size, stride, padding, "float64" + ) def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes, stride, padding, dtype): @@ -1080,15 +1078,13 @@ def testDepthwiseConv2DFilterGradCompare(self): padding) self._CompareBackpropFilter(input_size, filter_size, output_size, stride, padding, "float32") - # Convolutions on the ROCm platform don't support double dtype. And its - # support for bf16 is unknown. So, we skip these tests. - if test.is_built_with_rocm(): - continue - self._CompareBackpropFilter(input_size, filter_size, output_size, stride, - padding, "float64") - self._CompareBackpropFilter(input_size, filter_size, output_size, stride, padding, "bfloat16") + # Convolutions on the ROCm platform don't support double dtype. + if not test.is_built_with_rocm(): + self._CompareBackpropFilter( + input_size, filter_size, output_size, stride, padding, "float64" + ) @test_util.run_gpu_only def testDepthwiseConv2DFilterGradExplicitCompare(self): @@ -1102,15 +1098,13 @@ def testDepthwiseConv2DFilterGradExplicitCompare(self): padding) self._CompareBackpropFilter(input_size, filter_size, output_size, stride, padding, "float32") - # Convolutions on the ROCm platform don't support double dtype. And its - # support for bf16 is unknown. So, we skip these tests. - if test.is_built_with_rocm(): - continue - self._CompareBackpropFilter(input_size, filter_size, output_size, stride, - padding, "float64") - self._CompareBackpropFilter(input_size, filter_size, output_size, stride, padding, "bfloat16") + # Convolutions on the ROCm platform don't support double dtype. + if not test.is_built_with_rocm(): + self._CompareBackpropFilter( + input_size, filter_size, output_size, stride, padding, "float64" + ) def _CompareForward(self, input_sizes, filter_sizes, output_sizes, stride, padding, dtype): @@ -1146,16 +1140,15 @@ def testDepthwiseConv2DForwardCompare(self): padding) self._CompareForward(input_size, filter_size, output_size, stride, padding, "float32") - # Convolutions on the ROCm platform don't support double dtype. And its - # support for bf16 is unknown. So, we skip these tests. - if test.is_built_with_rocm(): - continue - self._CompareForward(input_size, filter_size, output_size, stride, - padding, "float64") - self._CompareForward(input_size, filter_size, output_size, stride, padding, "bfloat16") + # Convolutions on the ROCm platform don't support double dtype. 
+ if not test.is_built_with_rocm(): + self._CompareForward( + input_size, filter_size, output_size, stride, padding, "float64" + ) + @test_util.run_gpu_only def testDepthwiseConv2DForwardExplicitCompare(self): for index, (input_size, filter_size, output_size, stride, padding, @@ -1166,14 +1159,14 @@ def testDepthwiseConv2DForwardExplicitCompare(self): "Testing DepthwiseConv2DForwardCompare, %dth config: %r * %r, " "stride: %d, padding: %s", index, input_size, filter_size, stride, padding) - # Convolutions on the ROCm platform don't support double dtype. And its - # support for bf16 is unknown. So, we skip these tests. - if test.is_built_with_rocm(): - continue - self._CompareForward(input_size, filter_size, output_size, stride, - padding, "float64") + self._CompareForward(input_size, filter_size, output_size, stride, padding, "float32") - self._CompareForward(input_size, filter_size, output_size, stride, padding, "bfloat16") + + # Convolutions on the ROCm platform don't support double dtype. + if not test.is_built_with_rocm(): + self._CompareForward( + input_size, filter_size, output_size, stride, padding, "float64" + ) diff --git a/tensorflow/python/kernel_tests/nn_ops/losses_test.py b/tensorflow/python/kernel_tests/nn_ops/losses_test.py index 7da91f686a849c..b339738b485800 100644 --- a/tensorflow/python/kernel_tests/nn_ops/losses_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/losses_test.py @@ -101,7 +101,7 @@ def testLossWithSampleSpecificWeightsAllZero(self): with self.cached_session(): self.assertAlmostEqual(0.0, self.evaluate(loss), 3) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testEagerNoMemoryLeaked(self): # This is a somewhat convoluted way of testing that nothing gets added to # a global collection. 
@@ -244,7 +244,7 @@ def testAllCorrectInt32Labels(self): self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') self.assertAlmostEqual(self.evaluate(loss), 0.0, 3) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testEagerNoMemoryLeaked(self): logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) diff --git a/tensorflow/python/kernel_tests/nn_ops/rnn_test.py b/tensorflow/python/kernel_tests/nn_ops/rnn_test.py index e517f4ecc8864c..f13a2521d44516 100644 --- a/tensorflow/python/kernel_tests/nn_ops/rnn_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/rnn_test.py @@ -240,7 +240,7 @@ def testUnbalancedOutputIsAccepted(self): self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1]) self.assertAllEqual(4, state) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testEagerMemory(self): with context.eager_mode(): cell = TensorArrayStateRNNCell() diff --git a/tensorflow/python/kernel_tests/sparse_ops/sparse_xent_op_test_base.py b/tensorflow/python/kernel_tests/sparse_ops/sparse_xent_op_test_base.py index a30d82591da5c9..381e5c093f007e 100644 --- a/tensorflow/python/kernel_tests/sparse_ops/sparse_xent_op_test_base.py +++ b/tensorflow/python/kernel_tests/sparse_ops/sparse_xent_op_test_base.py @@ -71,7 +71,7 @@ def testSingleClass(self): self.assertAllClose([0.0, 0.0, 0.0], tf_loss) self.assertAllClose([[0.0], [0.0], [0.0]], tf_gradient) - @test_util.run_gpu_only() + @test_util.run_gpu_only def _testInvalidLabelGPU(self, invalid_label_gradient=np.nan): labels = [4, 3, 0, -1] logits = [[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 2., 3., 4.], diff --git a/tensorflow/python/kernel_tests/summary_ops/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops/summary_ops_test.py index ae41fc42fb0260..cdccc3da78c9ea 100644 --- a/tensorflow/python/kernel_tests/summary_ops/summary_ops_test.py +++ b/tensorflow/python/kernel_tests/summary_ops/summary_ops_test.py @@ -996,7 +996,7 @@ def testNoMemoryLeak_graphMode(self): with context.graph_mode(), ops.Graph().as_default(): summary_ops.create_file_writer_v2(logdir) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNoMemoryLeak_eagerMode(self): logdir = self.get_temp_dir() with summary_ops.create_file_writer_v2(logdir).as_default(): @@ -1495,12 +1495,13 @@ def f(): assert context.executing_eagerly() logdir = self.get_temp_dir() writer = summary_ops.create_file_writer_v2(logdir) - summary_ops.trace_on(graph=True, profiler=True) profiler_outdir = self.get_temp_dir() + summary_ops.trace_on( + graph=True, profiler=True, profiler_outdir=profiler_outdir + ) with writer.as_default(): f() - summary_ops.trace_export( - name='foo', step=1, profiler_outdir=profiler_outdir) + summary_ops.trace_export(name='foo', step=1) writer.close() @test_util.run_v2_only diff --git a/tensorflow/python/module/module_test.py b/tensorflow/python/module/module_test.py index 64972f3f850768..bcfa84b14d1507 100644 --- a/tensorflow/python/module/module_test.py +++ b/tensorflow/python/module/module_test.py @@ -17,6 +17,8 @@ import abc import collections import itertools +import sys +import unittest from absl.testing import parameterized @@ -514,6 +516,8 @@ class DangerousModule(module.Module): self.assertLen(mod.variables, 1) self.assertEqual(mod.variables[0], mod.normal_variable) + @unittest.skipIf(sys.version_info.major == 3 and 
sys.version_info.minor == 12, + reason="b/313658911: _TupleWrapper __dict__ attribute error") def test_with_path(self): mod = module.Module() mod.w = variables.Variable(1.) @@ -531,6 +535,8 @@ def test_with_path(self): ("decoder", "w", 0, 0, "k"): mod.decoder.w[0][0]["k"], ("decoder", "w", 0, 1, "k"): mod.decoder.w[0][1]["k"]},) + @unittest.skipIf(sys.version_info.major == 3 and sys.version_info.minor == 12, + reason="b/313658911: _TupleWrapper __dict__ attribute error") def test_cycles_with_path(self): mod = module.Module() mod.w = variables.Variable(1.) diff --git a/tensorflow/python/modules_with_exports.py b/tensorflow/python/modules_with_exports.py index 5f86568227670c..793823905688ce 100644 --- a/tensorflow/python/modules_with_exports.py +++ b/tensorflow/python/modules_with_exports.py @@ -31,6 +31,10 @@ from tensorflow.core.protobuf.config_pb2 import * from tensorflow.core.util.event_pb2 import * +# Checkpoint Sharding +from tensorflow.python.checkpoint.sharding import sharding_util +from tensorflow.python.checkpoint.sharding import sharding_policies + # Compat from tensorflow.python.compat import v2_compat @@ -117,6 +121,7 @@ from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import sets from tensorflow.python.ops import stateful_random_ops +from tensorflow.python.ops import tensor_getitem_override from tensorflow.python.ops import while_v2 from tensorflow.python.ops.linalg import linalg from tensorflow.python.ops.linalg.sparse import sparse @@ -170,6 +175,7 @@ # Summary from tensorflow.python.summary import summary +from tensorflow.python.summary import tb_summary # TPU from tensorflow.python.tpu import api diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index af1057af5a975c..a615f94cc04835 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -2,6 +2,10 @@ load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test") load("//tensorflow/core/platform:build_config_root.bzl", "tf_additional_xla_deps_py") load("//tensorflow/python:build_defs.bzl", "tf_gen_op_strict_wrapper_private_py") +load( + "//tensorflow/tools/test:performance.bzl", + "cuda_py_benchmark_test", +) visibility = [ "//engedu/ml/tf_from_scratch:__pkg__", @@ -307,6 +311,7 @@ tf_gen_op_strict_wrapper_private_py( visibility = [ "//learning/brain/python/ops:__pkg__", "//tensorflow/compiler/tests:__pkg__", + "//tensorflow/dtensor/python/tests:__pkg__", "//tensorflow/python:__pkg__", "//tensorflow/python/kernel_tests/image_ops:__pkg__", "//tensorflow/python/ops/parallel_for:__pkg__", @@ -482,6 +487,7 @@ tf_gen_op_strict_wrapper_private_py( name = "parsing_ops_gen", visibility = [ "//learning/brain/python/ops:__pkg__", + "//tensorflow/dtensor/python/tests:__pkg__", "//tensorflow/python:__pkg__", "//tensorflow/python/autograph/operators:__pkg__", "//tensorflow/python/data/ops:__pkg__", @@ -514,6 +520,7 @@ tf_gen_op_strict_wrapper_private_py( visibility = [ "//learning/brain/python/ops:__pkg__", "//tensorflow/compiler/tests:__pkg__", + "//tensorflow/dtensor/python/tests:__pkg__", "//tensorflow/python:__pkg__", "//tensorflow/python/kernel_tests/random:__pkg__", ], @@ -775,7 +782,10 @@ cuda_py_strict_test( py_strict_library( name = "array_ops", - srcs = ["array_ops.py"], + srcs = [ + "array_ops.py", + "tensor_getitem_override.py", + ], srcs_version = "PY3", visibility = visibility, deps = [ @@ -1481,7 +1491,6 @@ py_strict_library( ":ctc_ops_gen", 
":custom_gradient", ":functional_ops", - # TODO(b/280454072) Remove inplace_ops and compat when forward compatibility window expires. ":inplace_ops", ":linalg_ops", ":map_fn", @@ -1489,7 +1498,6 @@ py_strict_library( ":nn_grad", ":nn_ops", ":sparse_ops", - "//tensorflow/python/compat", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", @@ -1629,6 +1637,7 @@ py_strict_library( ":cudnn_rnn_grad", ":gradients_util", ":image_grad", + ":io_ops", ":linalg_grad", ":linalg_ops", ":logging_ops", @@ -2052,7 +2061,10 @@ py_strict_library( py_strict_library( name = "math_ops", - srcs = ["math_ops.py"], + srcs = [ + "math_ops.py", + "tensor_math_operator_overrides.py", + ], srcs_version = "PY3", deps = [ ":array_ops", @@ -2068,20 +2080,20 @@ py_strict_library( "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:override_binary_operator", "//tensorflow/python/framework:sparse_tensor", "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/framework:tensor_util", - "//tensorflow/python/ops/numpy_ops:np_dtypes", "//tensorflow/python/platform:tf_logging", + "//tensorflow/python/util:_pywrap_utils", "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:dispatch", "//tensorflow/python/util:nest", "//tensorflow/python/util:tf_decorator_py", "//tensorflow/python/util:tf_export", - "//tensorflow/python/util:traceback_utils", "//third_party/py/numpy", ], ) @@ -2101,49 +2113,9 @@ py_strict_library( ], ) -py_strict_library( +alias( name = "resource_variable_ops", - srcs = ["resource_variable_ops.py"], - srcs_version = "PY3", - deps = [ - ":array_ops", - ":array_ops_gen", - ":handle_data_util", - ":math_ops", - ":resource_variable_ops_gen", - ":state_ops", - ":state_ops_gen", - ":variables", - "//tensorflow/core:protos_all_py", - "//tensorflow/core/function/trace_type", - "//tensorflow/python/checkpoint:tensor_callable", - "//tensorflow/python/client:pywrap_tf_session", - "//tensorflow/python/compat", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:record", - "//tensorflow/python/eager:tape", - "//tensorflow/python/framework:auto_control_deps_utils", - "//tensorflow/python/framework:composite_tensor", - "//tensorflow/python/framework:composite_tensor_gradient", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:cpp_shape_inference_proto_py", - "//tensorflow/python/framework:device", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:indexed_slices", - "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor", - "//tensorflow/python/framework:tensor_conversion_registry", - "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/saved_model:nested_structure_coder", - "//tensorflow/python/trackable:base", - "//tensorflow/python/types:core", - "//tensorflow/python/util:_pywrap_utils", - "//tensorflow/python/util:compat", - "//tensorflow/python/util:deprecation", - "//tensorflow/python/util:tf_export", - "//third_party/py/numpy", - ], + actual = ":variables", ) py_strict_library( @@ -2998,7 +2970,6 @@ py_strict_library( "//tensorflow/dtensor/python:api", "//tensorflow/dtensor/python:layout", "//tensorflow/python/eager:context", - 
"//tensorflow/python/eager:profiler", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", @@ -3006,6 +2977,7 @@ py_strict_library( "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_util", "//tensorflow/python/platform:tf_logging", + "//tensorflow/python/profiler:profiler_v2", "//tensorflow/python/trackable:resource", "//tensorflow/python/training:training_util", "//tensorflow/python/util:deprecation", @@ -3134,46 +3106,78 @@ py_strict_library( py_strict_library( name = "variables", - srcs = ["variables.py"], + srcs = [ + "resource_variable_ops.py", + "variables.py", + ], srcs_version = "PY3", deps = [ ":array_ops", + ":array_ops_gen", ":array_ops_stack", ":control_flow_ops", + ":handle_data_util", ":math_ops", ":math_ops_gen", + ":resource_variable_ops_gen", ":state_ops", + ":state_ops_gen", "//tensorflow/core:protos_all_py", + "//tensorflow/core/function/trace_type", "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python/checkpoint:tensor_callable", + "//tensorflow/python/client:pywrap_tf_session", + "//tensorflow/python/compat", "//tensorflow/python/eager:context", + "//tensorflow/python/eager:record", + "//tensorflow/python/eager:tape", + "//tensorflow/python/framework:auto_control_deps_utils", + "//tensorflow/python/framework:composite_tensor", + "//tensorflow/python/framework:composite_tensor_gradient", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:cpp_shape_inference_proto_py", + "//tensorflow/python/framework:device", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_conversion_registry", "//tensorflow/python/framework:tensor_shape", + "//tensorflow/python/saved_model:nested_structure_coder", "//tensorflow/python/trackable:base", - "//tensorflow/python/util:_pywrap_utils", + "//tensorflow/python/types:core", + "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:object_identity", "//tensorflow/python/util:tf_export", "//tensorflow/python/util:tf_should_use", "//tensorflow/python/util:traceback_utils", + "//third_party/py/numpy", ], ) -py_strict_library( +alias( name = "ref_variable", - srcs = ["ref_variable.py"], + actual = ":variable_v1", +) + +py_strict_library( + name = "variable_v1", + srcs = [ + "ref_variable.py", + "variable_v1.py", + ], srcs_version = "PY3", deps = [ ":array_ops", ":array_ops_gen", + ":cond", ":resource_variable_ops", ":resource_variables_toggle", ":state_ops", ":state_ops_gen", ":variable_scope", - ":variable_v1", ":variables", "//tensorflow/core:protos_all_py", "//tensorflow/python/eager:context", @@ -3187,19 +3191,6 @@ py_strict_library( "//tensorflow/python/types:core", "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", - ], -) - -py_strict_library( - name = "variable_v1", - srcs = ["variable_v1.py"], - srcs_version = "PY3", - deps = [ - ":cond", - ":state_ops", - ":variable_scope", - ":variables", - "//tensorflow/python/framework:ops", "//tensorflow/python/util:tf_export", "//tensorflow/python/util:tf_should_use", ], @@ -3735,7 +3726,7 @@ cuda_py_strict_test( main = "nn_fused_batchnorm_test.py", python_version = "PY3", shard_count = 24, - tags = ["no_rocm"], + tags = [], deps = [ ":array_ops", ":gradient_checker", @@ -4033,7 +4024,7 @@ py_strict_test( 
], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "accumulate_n_benchmark", size = "medium", srcs = ["accumulate_n_benchmark.py"], @@ -4055,7 +4046,7 @@ cuda_py_strict_test( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "batch_norm_benchmark", srcs = ["batch_norm_benchmark.py"], main = "batch_norm_benchmark.py", @@ -4077,7 +4068,7 @@ cuda_py_strict_test( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "collective_ops_benchmark", srcs = ["collective_ops_benchmark.py"], main = "collective_ops_benchmark.py", @@ -4092,7 +4083,7 @@ cuda_py_strict_test( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "concat_benchmark", srcs = ["concat_benchmark.py"], main = "concat_benchmark.py", @@ -4109,7 +4100,7 @@ cuda_py_strict_test( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "control_flow_ops_benchmark", srcs = ["control_flow_ops_benchmark.py"], main = "control_flow_ops_benchmark.py", @@ -4129,7 +4120,7 @@ cuda_py_strict_test( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "conv2d_benchmark", size = "large", srcs = ["conv2d_benchmark.py"], @@ -4150,7 +4141,7 @@ cuda_py_strict_test( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "split_benchmark", srcs = ["split_benchmark.py"], main = "split_benchmark.py", @@ -4169,7 +4160,7 @@ cuda_py_strict_test( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "transpose_benchmark", size = "medium", srcs = ["transpose_benchmark.py"], @@ -4187,7 +4178,7 @@ cuda_py_strict_test( ], ) -cuda_py_strict_test( +cuda_py_benchmark_test( name = "matmul_benchmark", size = "medium", srcs = ["matmul_benchmark.py"], diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index ab2eedaaadb589..437e504114ffc6 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import common_shapes from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op +from tensorflow.python.framework import constant_tensor_conversion # pylint: disable=unused-import from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import indexed_slices @@ -40,6 +41,7 @@ from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import shape_util +from tensorflow.python.ops import tensor_getitem_override # pylint: disable=unused-import # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_array_ops import * @@ -57,10 +59,6 @@ newaxis = None tf_export("newaxis").export_constant(__name__, "newaxis") -# We override the 'slice' for the "slice" op, so we keep Python's -# existing 'slice' for later use in this module. -_BaseSlice = slice - @tf_export("reshape", v1=["reshape", "manip.reshape"]) @dispatch.add_dispatch_support @@ -936,237 +934,6 @@ def rank_internal(input, name=None, optimize=True): return gen_array_ops.rank(input, name=name) -_SLICE_TYPE_ERROR = ( - "Only integers, slices (`:`), ellipsis (`...`), " - "tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid " - "indices") - -_SUPPORTED_SLICE_DTYPES = (dtypes.int16, dtypes.int32, dtypes.int32_ref, - dtypes.int64, dtypes.int64_ref) - - -def _check_index(idx): - """Check if a given value is a valid index into a tensor.""" - if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)): - return - - # Optimistic check. 
Assumptions: - # * any object with a dtype is supported - # * any object with a dtype has a sizeable shape attribute. - dtype = getattr(idx, "dtype", None) - if (dtype is None or dtypes.as_dtype(dtype) not in _SUPPORTED_SLICE_DTYPES or - idx.shape and len(idx.shape) == 1): - # TODO(slebedev): IndexError seems more appropriate here, but it - # will break `_slice_helper` contract. - raise TypeError(_SLICE_TYPE_ERROR + ", got {!r}".format(idx)) - - -def _is_undefined_dimension(d): - return isinstance(d, tensor_shape.Dimension) and d.value is None - - -@tf_export("__operators__.getitem", v1=[]) -@dispatch.add_dispatch_support -def _slice_helper(tensor, slice_spec, var=None): - """Overload for Tensor.__getitem__. - - This operation extracts the specified region from the tensor. - The notation is similar to NumPy with the restriction that - currently only support basic indexing. That means that - using a non-scalar tensor as input is not currently allowed. - - Some useful examples: - - ```python - # Strip leading and trailing 2 elements - foo = tf.constant([1,2,3,4,5,6]) - print(foo[2:-2]) # => [3,4] - - # Skip every other row and reverse the order of the columns - foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]]) - print(foo[::2,::-1]) # => [[3,2,1], [9,8,7]] - - # Use scalar tensors as indices on both dimensions - print(foo[tf.constant(0), tf.constant(2)]) # => 3 - - # Insert another dimension - foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]]) - print(foo[tf.newaxis, :, :]) # => [[[1,2,3], [4,5,6], [7,8,9]]] - print(foo[:, tf.newaxis, :]) # => [[[1,2,3]], [[4,5,6]], [[7,8,9]]] - print(foo[:, :, tf.newaxis]) # => [[[1],[2],[3]], [[4],[5],[6]], - [[7],[8],[9]]] - - # Ellipses (3 equivalent operations) - foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]]) - print(foo[tf.newaxis, :, :]) # => [[[1,2,3], [4,5,6], [7,8,9]]] - print(foo[tf.newaxis, ...]) # => [[[1,2,3], [4,5,6], [7,8,9]]] - print(foo[tf.newaxis]) # => [[[1,2,3], [4,5,6], [7,8,9]]] - - # Masks - foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]]) - print(foo[foo > 2]) # => [3, 4, 5, 6, 7, 8, 9] - ``` - - Notes: - - `tf.newaxis` is `None` as in NumPy. - - An implicit ellipsis is placed at the end of the `slice_spec` - - NumPy advanced indexing is currently not supported. - - Purpose in the API: - - This method is exposed in TensorFlow's API so that library developers - can register dispatching for `Tensor.__getitem__` to allow it to handle - custom composite tensors & other custom objects. - - The API symbol is not intended to be called by users directly and does - appear in TensorFlow's generated documentation. - - Args: - tensor: An tensor.Tensor object. - slice_spec: The arguments to Tensor.__getitem__. - var: In the case of variable slice assignment, the Variable object to slice - (i.e. tensor is the read-only view of this variable). - - Returns: - The appropriate slice of "tensor", based on "slice_spec". - - Raises: - ValueError: If a slice range is negative size. - TypeError: If the slice indices aren't int, slice, ellipsis, - tf.newaxis or scalar int32/int64 tensors. 
- """ - tensor = ops.convert_to_tensor(tensor) - # TODO(wangpeng): Consider supporting var - if var is None and ops._numpy_style_slicing: # pylint: disable=protected-access - return tensor._numpy_style_getitem(slice_spec) # pylint: disable=protected-access - - if (isinstance(slice_spec, bool) - or (isinstance(slice_spec, tensor_lib.Tensor) - and slice_spec.dtype == dtypes.bool) - or (isinstance(slice_spec, np.ndarray) - and slice_spec.dtype == bool)): - return boolean_mask(tensor=tensor, mask=slice_spec) - - if not isinstance(slice_spec, (list, tuple)): - slice_spec = [slice_spec] - - begin, end, strides = [], [], [] - index = 0 - - new_axis_mask, shrink_axis_mask = 0, 0 - begin_mask, end_mask = 0, 0 - ellipsis_mask = 0 - for s in slice_spec: - if isinstance(s, _BaseSlice): - # Finds the best dtype for begin, end, and strides. - dtype = None - for t in [s.start, s.stop, s.step]: - if t is None or not isinstance(t, tensor_lib.Tensor): - continue - if t.dtype == dtypes.int64: - dtype = dtypes.int64 - elif t.dtype == dtypes.int32 and dtype != dtypes.int64: - dtype = dtypes.int32 - elif t.dtype == dtypes.int16 and dtype is None: - dtype = dtypes.int16 - - if s.start is not None and not _is_undefined_dimension(s.start): - _check_index(s.start) - begin.append(s.start) - else: - if dtype is not None: - begin.append(constant_op.constant(0, dtype=dtype)) - else: - begin.append(0) - begin_mask |= (1 << index) - if s.stop is not None and not _is_undefined_dimension(s.stop): - _check_index(s.stop) - end.append(s.stop) - else: - if dtype is not None: - end.append(constant_op.constant(0, dtype=dtype)) - else: - end.append(0) - end_mask |= (1 << index) - if s.step is not None and not _is_undefined_dimension(s.step): - _check_index(s.step) - strides.append(s.step) - else: - if dtype is not None: - strides.append(constant_op.constant(1, dtype=dtype)) - else: - strides.append(1) - elif s is Ellipsis: - begin.append(0) - end.append(0) - strides.append(1) - ellipsis_mask |= (1 << index) - elif s is newaxis: - begin.append(0) - end.append(0) - strides.append(1) - new_axis_mask |= (1 << index) - else: - _check_index(s) - begin.append(s) - end.append(s + 1) - # TODO(mdan): Investigate why we can't set int32 here. - if ( - isinstance(s, tensor_lib.Tensor) - and (s.dtype == dtypes.int16 or s.dtype == dtypes.int64)): - strides.append(constant_op.constant(1, dtype=s.dtype)) - else: - strides.append(1) - shrink_axis_mask |= (1 << index) - index += 1 - - # stack possibly involves no tensors, so we must use op_scope correct graph. - with ops.name_scope( - None, - "strided_slice", [tensor] + begin + end + strides, - skip_on_eager=False) as name: - if begin: - packed_begin, packed_end, packed_strides = ( - array_ops_stack.stack(begin), - array_ops_stack.stack(end), - array_ops_stack.stack(strides)) - # TODO(mdan): Instead of implicitly casting, it's better to enforce the - # same dtypes. 
- if (packed_begin.dtype == dtypes.int64 or - packed_end.dtype == dtypes.int64 or - packed_strides.dtype == dtypes.int64): - if packed_begin.dtype != dtypes.int64: - packed_begin = gen_math_ops.cast(packed_begin, dtypes.int64) - if packed_end.dtype != dtypes.int64: - packed_end = gen_math_ops.cast(packed_end, dtypes.int64) - if packed_strides.dtype != dtypes.int64: - packed_strides = gen_math_ops.cast(packed_strides, dtypes.int64) - elif (packed_begin.dtype == dtypes.int16 and - packed_end.dtype == dtypes.int16 and - packed_strides.dtype == dtypes.int16): - if packed_begin.dtype != dtypes.int16: - packed_begin = gen_math_ops.cast(packed_begin, dtypes.int16) - if packed_end.dtype != dtypes.int16: - packed_end = gen_math_ops.cast(packed_end, dtypes.int16) - if packed_strides.dtype != dtypes.int16: - packed_strides = gen_math_ops.cast(packed_strides, dtypes.int16) - else: - var_empty = constant([], dtype=dtypes.int32) - packed_begin = packed_end = packed_strides = var_empty - return strided_slice( - tensor, - packed_begin, - packed_end, - packed_strides, - begin_mask=begin_mask, - end_mask=end_mask, - shrink_axis_mask=shrink_axis_mask, - new_axis_mask=new_axis_mask, - ellipsis_mask=ellipsis_mask, - var=var, - name=name) - - # pylint: disable=undefined-variable,protected-access,redefined-outer-name @tf_export("slice") @dispatch.add_dispatch_support @@ -1364,53 +1131,6 @@ def assign(val, name=None): return op -def _SliceHelperVar(var, slice_spec): - """Creates a slice helper object given a variable. - - This allows creating a sub-tensor from part of the current contents - of a variable. See `tf.Tensor.__getitem__` for detailed examples - of slicing. - - This function in addition also allows assignment to a sliced range. - This is similar to `__setitem__` functionality in Python. However, - the syntax is different so that the user can capture the assignment - operation for grouping or passing to `sess.run()` in TF1. - For example, - - ```python - import tensorflow as tf - A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32) - print(A[:2, :2]) # => [[1,2], [4,5]] - - A[:2,:2].assign(22. * tf.ones((2, 2)))) - print(A) # => [[22, 22, 3], [22, 22, 6], [7,8,9]] - ``` - - Note that assignments currently do not support NumPy broadcasting - semantics. - - Args: - var: An `ops.Variable` object. - slice_spec: The arguments to `Tensor.__getitem__`. - - Returns: - The appropriate slice of "tensor", based on "slice_spec". - As an operator. The operator also has a `assign()` method - that can be used to generate an assignment operator. - - Raises: - ValueError: If a slice range is negative size. - TypeError: TypeError: If the slice indices aren't int, slice, - ellipsis, tf.newaxis or int32/int64 tensors. - - """ - - return _slice_helper(var.value(), slice_spec, var) - - -tensor_lib.Tensor._override_operator("__getitem__", _slice_helper) - - @tf_export("parallel_stack") @dispatch.add_dispatch_support def parallel_stack(values, name="parallel_stack"): diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index 27f57869511514..2f1247d4226642 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -207,8 +207,12 @@ def _is_op_stateful(op): Returns: Boolean indicates whether the operation is stateless or not. """ + # TODO(pineapplejuice233): Remove these hardcode op names once they can be marked as + # stateless in TF. 
if op.type == "GlobalIterId": return False + if op.type == "UpdateFdoWithGlobalMinibatchStatistics": + return False if op.type == "CollectiveGatherV2" and op.get_attr("is_stateless"): return False return op._is_stateful diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py index 1bde62c0adbc20..ba4590c385569b 100644 --- a/tensorflow/python/ops/ctc_ops.py +++ b/tensorflow/python/ops/ctc_ops.py @@ -16,9 +16,6 @@ import uuid -# TODO(b/280454072) Remove compat and inplace_ops when foward compatibility -# window expires. -from tensorflow.python.compat import compat from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -1497,17 +1494,10 @@ def body(i, num_elems, *args): new_out = [] else: update_i = i + 1 if inclusive and not reverse else i - # TODO(b/280454072) Cleanup when foward compatibility window expires. - if compat.forward_compatible(2023, 10, 26): - new_out = [ - gen_array_ops.tensor_scatter_update(x, [[update_i]], [y]) - for x, y in zip(out, flat_accum) - ] - else: - new_out = [ - inplace_ops.alias_inplace_update(x, update_i, y) - for x, y in zip(out, flat_accum) - ] + new_out = [ + gen_array_ops.tensor_scatter_update(x, [[update_i]], [y]) + for x, y in zip(out, flat_accum) + ] i = i - 1 if reverse else i + 1 return [i, num_elems] + new_out + flat_accum @@ -1522,15 +1512,9 @@ def body(i, num_elems, *args): [[num_outputs], array_ops.shape(initial_accum)], 0) out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True) if inclusive: - # TODO(b/280454072) Cleanup when foward compatibility window expires. - if compat.forward_compatible(2023, 10, 26): - out = gen_array_ops.tensor_scatter_add( - out, [[init_i + (1 if reverse else 0)]], [initial_accum] - ) - else: - out = inplace_ops.alias_inplace_add( - out, init_i + (1 if reverse else 0), initial_accum - ) + out = gen_array_ops.tensor_scatter_add( + out, [[init_i + (1 if reverse else 0)]], [initial_accum] + ) outputs.append(out) loop_in = [init_i, num_elems] + outputs + flat_initial hostmem = [ diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index ae88a6d6306831..a45b9965078898 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import cudnn_rnn_grad # pylint: disable=unused-import from tensorflow.python.ops import gradients_util from tensorflow.python.ops import image_grad # pylint: disable=unused-import +from tensorflow.python.ops import io_ops # pylint: disable=unused-import from tensorflow.python.ops import linalg_grad # pylint: disable=unused-import from tensorflow.python.ops import linalg_ops # pylint: disable=unused-import from tensorflow.python.ops import logging_ops # pylint: disable=unused-import diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index b643278e3f9eb2..44085948a48689 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -1254,6 +1254,9 @@ def Grad(*grad): @test_util.enable_quantized_dtypes_training def testCustomGradientQuantizedDtypeTraining(self): + # TODO(b/309175067): Remove below skipTest() when fixed. 
+ if sys.platform == "darwin": + self.skipTest("This test fails in TF MacOS nightly and continuous builds") with context.eager_mode(): @custom_gradient.custom_gradient def F(x): diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index fd8995f6250a81..e476217347bed6 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -2030,7 +2030,8 @@ def random_brightness(image, max_delta, seed=None): Args: image: An image or images to adjust. - max_delta: float, must be non-negative. + max_delta: float, must be non-negative. This parameter controls the maximum + relative change in brightness. seed: A Python integer. Used to create a random seed. See `tf.compat.v1.set_random_seed` for behavior. diff --git a/tensorflow/python/ops/init_ops_test.py b/tensorflow/python/ops/init_ops_test.py index a0ef239581405a..0d34a764e5f6fd 100644 --- a/tensorflow/python/ops/init_ops_test.py +++ b/tensorflow/python/ops/init_ops_test.py @@ -172,9 +172,6 @@ def test_Orthogonal(self): self._runner( init_ops.Orthogonal(seed=123), tensor_shape, target_mean=0.) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, - skip_message='Disable subtest on ROCm due to missing QR op support') @test_util.run_gpu_only def testVariablePlacementWithOrthogonalInitializer(self): with ops.Graph().as_default() as g: diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py index 591d50d1c089a4..b9a4a32e425481 100644 --- a/tensorflow/python/ops/linalg/linear_operator.py +++ b/tensorflow/python/ops/linalg/linear_operator.py @@ -1656,7 +1656,10 @@ def _matmul( # pylint:disable=missing-docstring a_is_sparse=False, b_is_sparse=False, output_type=None, # pylint: disable=unused-argument - name=None): + grad_a=False, # pylint: disable=unused-argument + grad_b=False, # pylint: disable=unused-argument + name=None, +): if transpose_a or transpose_b: raise ValueError("Transposing not supported at this time.") if a_is_sparse or b_is_sparse: diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py index 3cdbe9ddba43aa..faa71ee0548b80 100644 --- a/tensorflow/python/ops/linalg/linear_operator_test_util.py +++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py @@ -387,9 +387,9 @@ def test_log_abs_det(self: "LinearOperatorDerivedClassTest"): return test_log_abs_det -@test_util.run_without_tensor_float_32("Use FP32 in matmul") def _test_operator_matmul_with_same_type(use_placeholder, shapes_info, dtype): """op_a.matmul(op_b), in the case where the same type is returned.""" + @test_util.run_without_tensor_float_32("Use FP32 in matmul") def test_operator_matmul_with_same_type( self: "LinearOperatorDerivedClassTest"): with self.session(graph=ops.Graph()) as sess: @@ -501,7 +501,6 @@ def _test_matmul_base( self.assertAC(op_matmul_v, mat_matmul_v) -@test_util.run_without_tensor_float_32("Use FP32 in matmul") def _test_matmul( use_placeholder, shapes_info, @@ -509,6 +508,7 @@ def _test_matmul( adjoint, adjoint_arg, blockwise_arg): + @test_util.run_without_tensor_float_32("Use FP32 in matmul") def test_matmul(self: "LinearOperatorDerivedClassTest"): _test_matmul_base( self, @@ -522,7 +522,6 @@ def test_matmul(self: "LinearOperatorDerivedClassTest"): return test_matmul -@test_util.run_without_tensor_float_32("Use FP32 in matmul") def _test_matmul_with_broadcast( use_placeholder, shapes_info, @@ -530,6 +529,7 @@ def 
_test_matmul_with_broadcast( adjoint, adjoint_arg, blockwise_arg): + @test_util.run_without_tensor_float_32("Use FP32 in matmul") def test_matmul_with_broadcast(self: "LinearOperatorDerivedClassTest"): _test_matmul_base( self, @@ -822,8 +822,8 @@ def test_diag_part(self: "LinearOperatorDerivedClassTest"): return test_diag_part -@test_util.run_without_tensor_float_32("Use FP32 in matmul") def _test_composite_tensor(use_placeholder, shapes_info, dtype): + @test_util.run_without_tensor_float_32("Use FP32 in matmul") def test_composite_tensor(self: "LinearOperatorDerivedClassTest"): with self.session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED @@ -863,8 +863,8 @@ def body(op): return test_composite_tensor -@test_util.run_without_tensor_float_32("Use FP32 in matmul") def _test_saved_model(use_placeholder, shapes_info, dtype): + @test_util.run_without_tensor_float_32("Use FP32 in matmul") def test_saved_model(self: "LinearOperatorDerivedClassTest"): with self.session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 3a120565a1603b..8fe7047cbd42ef 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -1664,13 +1664,15 @@ def _MatMulGradAgainstFirstOnly(op: ops.Operation, grad): t_b = op.get_attr("transpose_b") b = math_ops.conj(op.inputs[1]) if not t_a and not t_b: - grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True) + grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True, grad_a=True) elif not t_a and t_b: - grad_a = gen_math_ops.mat_mul(grad, b) + grad_a = gen_math_ops.mat_mul(grad, b, grad_a=True) elif t_a and not t_b: - grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True) + grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True, grad_a=True) elif t_a and t_b: - grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True) + grad_a = gen_math_ops.mat_mul( + b, grad, transpose_a=True, transpose_b=True, grad_a=True + ) return grad_a, None @@ -1680,13 +1682,15 @@ def _MatMulGradAgainstSecondOnly(op: ops.Operation, grad): t_b = op.get_attr("transpose_b") a = math_ops.conj(op.inputs[0]) if not t_a and not t_b: - grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True) + grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True, grad_b=True) elif not t_a and t_b: - grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True) + grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, grad_b=True) elif t_a and not t_b: - grad_b = gen_math_ops.mat_mul(a, grad) + grad_b = gen_math_ops.mat_mul(a, grad, grad_b=True) elif t_a and t_b: - grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True) + grad_b = gen_math_ops.mat_mul( + grad, a, transpose_a=True, transpose_b=True, grad_b=True + ) return None, grad_b @@ -1709,17 +1713,21 @@ def _MatMulGrad(op: ops.Operation, grad): a = math_ops.conj(op.inputs[0]) b = math_ops.conj(op.inputs[1]) if not t_a and not t_b: - grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True) - grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True) + grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True, grad_a=True) + grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True, grad_b=True) elif not t_a and t_b: - grad_a = gen_math_ops.mat_mul(grad, b) - grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True) + grad_a = gen_math_ops.mat_mul(grad, b, grad_a=True) + grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, grad_b=True) elif t_a and not t_b: - 
grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True) - grad_b = gen_math_ops.mat_mul(a, grad) + grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True, grad_a=True) + grad_b = gen_math_ops.mat_mul(a, grad, grad_b=True) elif t_a and t_b: - grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True) - grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True) + grad_a = gen_math_ops.mat_mul( + b, grad, transpose_a=True, transpose_b=True, grad_a=True + ) + grad_b = gen_math_ops.mat_mul( + grad, a, transpose_a=True, transpose_b=True, grad_b=True + ) return grad_a, grad_b @@ -1833,18 +1841,34 @@ def _BatchMatMulV2(op: ops.Operation, grad): if not adj_x: if not adj_y: - grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True) - grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False) + grad_x = math_ops.matmul( + grad, y, adjoint_a=False, adjoint_b=True, grad_a=True + ) + grad_y = math_ops.matmul( + x, grad, adjoint_a=True, adjoint_b=False, grad_b=True + ) else: - grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False) - grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False) + grad_x = math_ops.matmul( + grad, y, adjoint_a=False, adjoint_b=False, grad_a=True + ) + grad_y = math_ops.matmul( + grad, x, adjoint_a=True, adjoint_b=False, grad_b=True + ) else: if not adj_y: - grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True) - grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False) + grad_x = math_ops.matmul( + y, grad, adjoint_a=False, adjoint_b=True, grad_a=True + ) + grad_y = math_ops.matmul( + x, grad, adjoint_a=False, adjoint_b=False, grad_b=True + ) else: - grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True) - grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True) + grad_x = math_ops.matmul( + y, grad, adjoint_a=True, adjoint_b=True, grad_a=True + ) + grad_y = math_ops.matmul( + grad, x, adjoint_a=True, adjoint_b=True, grad_b=True + ) # Possibly reduce along the broadcasted batch dimensions, if broadcasting # is required. 
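The `grad_a`/`grad_b` (and `grad_x`/`grad_y`) attributes threaded through the gradient functions above are only hints telling the kernel which operand is a backpropagated gradient; the gradient math itself is unchanged. As a minimal sketch (assuming a TensorFlow 2.x install; the tensor shapes are arbitrary), the identities implemented for the non-transposed case can be checked against `tf.GradientTape`:

import tensorflow as tf

a = tf.random.normal([3, 4])
b = tf.random.normal([4, 5])
upstream = tf.random.normal([3, 5])  # gradient arriving at the MatMul output

with tf.GradientTape() as tape:
  tape.watch([a, b])
  y = tf.matmul(a, b)

# Gradients computed through the registered MatMul gradient function.
grad_a, grad_b = tape.gradient(y, [a, b], output_gradients=upstream)

# Closed-form identities for the non-transposed case:
#   dL/da = dL/dy @ b^T     dL/db = a^T @ dL/dy
tf.debugging.assert_near(grad_a, tf.matmul(upstream, b, transpose_b=True))
tf.debugging.assert_near(grad_b, tf.matmul(a, upstream, transpose_a=True))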
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 7645eaaae39311..29c695c0da2a40 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -68,7 +68,6 @@ API docstring: tensorflow.math """ import builtins -import numbers import numpy as np from tensorflow.python.eager import context @@ -76,6 +75,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops +from tensorflow.python.framework import override_binary_operator from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry @@ -89,18 +89,17 @@ from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gen_sparse_ops +from tensorflow.python.ops import tensor_math_operator_overrides # pylint: disable=unused-import # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_math_ops import * # pylint: enable=wildcard-import -from tensorflow.python.ops.numpy_ops import np_dtypes from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import _pywrap_utils from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import dispatch from tensorflow.python.util import nest -from tensorflow.python.util import tf_decorator -from tensorflow.python.util import traceback_utils from tensorflow.python.util.compat import collections_abc from tensorflow.python.util.tf_export import tf_export @@ -233,11 +232,6 @@ def linspace_nd(start, stop, num, name=None, axis=0): tf_export(v1=["arg_min"])(dispatch.add_dispatch_support(arg_min)) -# This is set by resource_variable_ops.py. It is included in this way since -# there is a circular dependency between math_ops and resource_variable_ops -_resource_variable_type = None - - def _set_doc(doc): def _decorator(func): @@ -997,8 +991,9 @@ def cast(x, dtype, name=None): """ base_type = dtypes.as_dtype(dtype).base_dtype - if isinstance( - x, (tensor_lib.Tensor, _resource_variable_type)) and base_type == x.dtype: + if ( + isinstance(x, tensor_lib.Tensor) or _pywrap_utils.IsResourceVariable(x) + ) and base_type == x.dtype: return x with ops.name_scope(name, "Cast", [x]) as name: if isinstance(x, sparse_tensor.SparseTensor): @@ -1388,150 +1383,6 @@ def to_complex128(x, name="ToComplex128"): return cast(x, dtypes.complex128, name=name) -tensor_lib.Tensor._override_operator("__neg__", gen_math_ops.neg) -tensor_lib.Tensor._override_operator("__abs__", abs) - - -def _maybe_get_dtype(x): - """Returns a numpy type if available from x. Skips if x is numpy.ndarray.""" - # Don't put np.ndarray in this list, because np.result_type looks at the - # value (not just dtype) of np.ndarray to decide the result type. - if isinstance(x, numbers.Real): - return x - if isinstance(x, tensor_lib.Tensor): - return x.dtype.as_numpy_dtype - if isinstance(x, dtypes.DType): - return x.as_numpy_dtype - if isinstance(x, tensor_shape.TensorShape): - return np.int32 - if isinstance(x, (list, tuple)): - raise ValueError(f"Cannot determine dtype. Got sequence {x}.") - return x - - -def maybe_promote_tensors(*tensors, force_same_dtype=False): - """Promotes tensors if numpy style promotion is enabled. 
- - This function promotes `tensors` according to numpy promotion rules - if numpy style promotion is enabled. Otherwise, if - `force_same_dtype` is `True`, it force-casts `tensors[1:]` to - `tensor[0]`'s dtype. Note that this force-cast can be problematic. - For example, when some `tensors[1:]` elements can be silently - downcasted. - - Args: - *tensors: the list of tensors to promote. - force_same_dtype: bool (optional, default to `False`). When numpy - style promotion is disabled and `force_same_dtype` is `True`, - this function will force-casts `tensors[1:]` to `tensor[0]`'s - dtype (which could be problematic). - - Returns: - The promoted list of tensors. - """ - if ops.is_auto_dtype_conversion_enabled(): - return tensors - if not tensors: - return tensors - if not ops.is_numpy_style_type_promotion(): - if not force_same_dtype: - return tensors - promoted_tensors = [] - promoted_tensors.append(tensors[0]) - dtype = tensors[0].dtype.base_dtype - for tensor in tensors[1:]: - promoted_tensors.append( - ops.convert_to_tensor(tensor, dtype, name="x")) - return promoted_tensors - result_type = np_dtypes._result_type( - *[_maybe_get_dtype(x) for x in nest.flatten(tensors)]) - def _promote_or_cast(x): - if isinstance(x, tensor_lib.Tensor): - x = cast(x, result_type) - else: - x = ops.convert_to_tensor(x, result_type) - return x - return [_promote_or_cast(x) for x in tensors] - - -def _OverrideBinaryOperatorHelper( - func, op_name, clazz_object=tensor_lib.Tensor): - """Register operators with different tensor and scalar versions. - - If `clazz_object` is `SparseTensor`, assumes `func` takes `(sp_indices, - sp_values, sp_shape, dense)` and outputs `(new_sp_values)`. - - Args: - func: the operator - op_name: name of the operator being overridden - clazz_object: class to override for. Either `Tensor` or `SparseTensor`. - """ - - @traceback_utils.filter_traceback - def binary_op_wrapper(x, y): - with ops.name_scope(None, op_name, [x, y]) as name: - try: - # force_same_dtype=False to preserve existing TF behavior - # TODO(b/178860388): Figure out why binary_op_wrapper and - # r_binary_op_wrapper use different force_same_dtype values. - x, y = maybe_promote_tensors(x, y) - return func(x, y, name=name) - except (TypeError, ValueError) as e: - # Even if dispatching the op failed, the RHS may be a tensor aware - # object that can implement the operator with knowledge of itself - # and the tensor. - # If the RHS is not tensor aware we still want to raise the - # original error from the LHS, because it may be more - # informative. - if hasattr(type(y), "__r%s__" % op_name): - try: - r_op = getattr(y, "__r%s__" % op_name) - out = r_op(x) - if out is NotImplemented: - raise - return out - except (TypeError, ValueError): - raise e - else: - raise - - @traceback_utils.filter_traceback - def binary_op_wrapper_sparse(sp_x, y): - with ops.name_scope(None, op_name, [sp_x, y]) as name: - y = ops.convert_to_tensor(y, dtype=sp_x.dtype.base_dtype, name="y") - return sparse_tensor.SparseTensor( - sp_x.indices, - func(sp_x.indices, sp_x.values, sp_x.dense_shape, y, name=name), - sp_x.dense_shape) - - @traceback_utils.filter_traceback - def r_binary_op_wrapper(y, x): - with ops.name_scope(None, op_name, [x, y]) as name: - # TODO(b/178860388): Figure out why binary_op_wrapper and - # r_binary_op_wrapper use different force_same_dtype values. 
- y, x = maybe_promote_tensors(y, x, force_same_dtype=True) - return func(x, y, name=name) - - # Propagate func.__doc__ to the wrappers - try: - doc = func.__doc__ - except AttributeError: - doc = None - binary_op_wrapper.__doc__ = doc - r_binary_op_wrapper.__doc__ = doc - binary_op_wrapper_sparse.__doc__ = doc - - if clazz_object is tensor_lib.Tensor: - clazz_object._override_operator("__%s__" % op_name, binary_op_wrapper) - del binary_op_wrapper - clazz_object._override_operator("__r%s__" % op_name, r_binary_op_wrapper) - del r_binary_op_wrapper - else: - clazz_object._override_operator("__%s__" % op_name, - binary_op_wrapper_sparse) - del binary_op_wrapper_sparse - - # Conversion table for __truediv__. None entries mean no conversion required. _TRUEDIV_TABLE = { dtypes.uint8: dtypes.float32, @@ -1551,33 +1402,6 @@ def r_binary_op_wrapper(y, x): } -# NOTE: the support of "sparse (true)div dense" is currently not baked in into -# "tf.(true_)div()". Until such an API decision is made, the supported usage is -# to explicitly use the "/" operator to invoke either truediv or div. -def _sparse_dense_truediv(sp_indices, sp_values, sp_shape, y, name=None): - """Internal helper function for 'sp_t / dense_t'.""" - with ops.name_scope(name, "truediv", - [sp_indices, sp_values, sp_shape, y]) as name: - sp_values = ops.convert_to_tensor(sp_values, name="sp_values") - y = ops.convert_to_tensor(y, name="y") - x_dtype = sp_values.dtype.base_dtype - y_dtype = y.dtype.base_dtype - if x_dtype != y_dtype: - raise TypeError(f"`x` and `y` must have the same dtype, " - f"got {x_dtype!r} != {y_dtype!r}.") - try: - dtype = _TRUEDIV_TABLE[x_dtype] - except KeyError: - raise TypeError( - f"Invalid dtype {x_dtype!r} in __truediv__. Expected one " - f"of {{{', '.join([repr(x) for x in _TRUEDIV_TABLE.keys()])}}}.") - if dtype is not None: - sp_values = cast(sp_values, dtype) - y = cast(y, dtype) - return gen_sparse_ops.sparse_dense_cwise_div( - sp_indices, sp_values, sp_shape, y, name=name) - - def _truediv_python3(x, y, name=None): with ops.name_scope(name, "truediv", [x, y]) as name: x = ops.convert_to_tensor(x, name="x") @@ -1881,26 +1705,6 @@ def _mul_dispatch(x, y, name=None): return multiply(x, y, name=name) -# NOTE(aselle): When integer division is added for sparse_dense_cwise, -# div, truediv, and floordiv should be delegated appropriately for -# Python semantics, analogous to dense cwise tensor operations. 
-_OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_div, "div", - sparse_tensor.SparseTensor) -_OverrideBinaryOperatorHelper(_sparse_dense_truediv, "truediv", - sparse_tensor.SparseTensor) -_OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul", - sparse_tensor.SparseTensor) - -_OverrideBinaryOperatorHelper(_add_dispatch, "add") -_OverrideBinaryOperatorHelper(subtract, "sub") -_OverrideBinaryOperatorHelper(_mul_dispatch, "mul") -_OverrideBinaryOperatorHelper(div, "div") -_OverrideBinaryOperatorHelper(truediv, "truediv") -_OverrideBinaryOperatorHelper(floordiv, "floordiv") -_OverrideBinaryOperatorHelper(mod, "mod") -_OverrideBinaryOperatorHelper(pow, "pow") - - @tf_export("math.logical_xor", v1=["math.logical_xor", "logical_xor"]) @dispatch.register_binary_elementwise_api @dispatch.add_dispatch_support @@ -1977,29 +1781,6 @@ def invert_(x, name=None): return gen_bitwise_ops.invert(x, name=name) -_OverrideBinaryOperatorHelper(and_, "and") -_OverrideBinaryOperatorHelper(or_, "or") -_OverrideBinaryOperatorHelper(xor_, "xor") -tensor_lib.Tensor._override_operator("__invert__", invert_) - - -def _promote_dtypes_decorator(fn): - def wrapper(x, y, *args, **kwargs): - x, y = maybe_promote_tensors(x, y) - return fn(x, y, *args, **kwargs) - return tf_decorator.make_decorator(fn, wrapper) - - -tensor_lib.Tensor._override_operator("__lt__", _promote_dtypes_decorator( - gen_math_ops.less)) -tensor_lib.Tensor._override_operator("__le__", _promote_dtypes_decorator( - gen_math_ops.less_equal)) -tensor_lib.Tensor._override_operator("__gt__", _promote_dtypes_decorator( - gen_math_ops.greater)) -tensor_lib.Tensor._override_operator("__ge__", _promote_dtypes_decorator( - gen_math_ops.greater_equal)) - - @tf_export("math.equal", "equal") @dispatch.register_binary_elementwise_api @dispatch.add_dispatch_support @@ -2109,7 +1890,7 @@ def tensor_equals(self, other): and ops.executing_eagerly_outside_functions() and (g is None or g.building_function) ): - self, other = maybe_promote_tensors(self, other) + self, other = override_binary_operator.maybe_promote_tensors(self, other) return gen_math_ops.equal(self, other, incompatible_shape_error=False) else: # In legacy graph mode, tensor equality is object equality @@ -2149,17 +1930,13 @@ def tensor_not_equals(self, other): tensor_lib.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() ): - self, other = maybe_promote_tensors(self, other) + self, other = override_binary_operator.maybe_promote_tensors(self, other) return gen_math_ops.not_equal(self, other, incompatible_shape_error=False) else: # In legacy graph mode, tensor equality is object equality return self is not other -tensor_lib.Tensor._override_operator("__eq__", tensor_equals) -tensor_lib.Tensor._override_operator("__ne__", tensor_not_equals) - - @tf_export("range") @dispatch.add_dispatch_support def range(start, limit=None, delta=1, dtype=None, name="range"): # pylint: disable=redefined-builtin @@ -3616,16 +3393,20 @@ def trace(x, name=None): @tf_export("linalg.matmul", "matmul") @dispatch.add_dispatch_support -def matmul(a, - b, - transpose_a=False, - transpose_b=False, - adjoint_a=False, - adjoint_b=False, - a_is_sparse=False, - b_is_sparse=False, - output_type=None, - name=None): +def matmul( + a, + b, + transpose_a=False, + transpose_b=False, + adjoint_a=False, + adjoint_b=False, + a_is_sparse=False, + b_is_sparse=False, + output_type=None, + grad_a=False, + grad_b=False, + name=None, +): """Multiplies matrix `a` by matrix `b`, producing `a` * `b`. 
The inputs must, following any transpositions, be tensors of rank >= 2 @@ -3711,17 +3492,19 @@ def matmul(a, multiplication. a_is_sparse: If `True`, `a` is treated as a sparse matrix. Notice, this **does not support `tf.sparse.SparseTensor`**, it just makes optimizations - that assume most values in `a` are zero. - See `tf.sparse.sparse_dense_matmul` - for some support for `tf.sparse.SparseTensor` multiplication. + that assume most values in `a` are zero. See + `tf.sparse.sparse_dense_matmul` for some support for + `tf.sparse.SparseTensor` multiplication. b_is_sparse: If `True`, `b` is treated as a sparse matrix. Notice, this **does not support `tf.sparse.SparseTensor`**, it just makes optimizations - that assume most values in `b` are zero. - See `tf.sparse.sparse_dense_matmul` - for some support for `tf.sparse.SparseTensor` multiplication. + that assume most values in `b` are zero. See + `tf.sparse.sparse_dense_matmul` for some support for + `tf.sparse.SparseTensor` multiplication. output_type: The output datatype if needed. Defaults to None in which case the output_type is the same as input type. Currently only works when input tensors are type (u)int8 and output_type can be int32. + grad_a: Set it to `True` to hint that Tensor `a` is for the backward pass. + grad_b: Set it to `True` to hint that Tensor `b` is for the backward pass. name: Name for the operation (optional). Returns: @@ -3755,9 +3538,12 @@ def matmul(a, f"`adjoint_b`={adjoint_b}.") if context.executing_eagerly(): - if not isinstance(a, (ops.EagerTensor, _resource_variable_type)): + if not ( + isinstance(a, ops.EagerTensor) or _pywrap_utils.IsResourceVariable(a) + ): a = ops.convert_to_tensor(a, name="a") - if not isinstance(b, (ops.EagerTensor, _resource_variable_type)): + if not isinstance(b, ops.EagerTensor) or _pywrap_utils.IsResourceVariable( + b): b = ops.convert_to_tensor(b, dtype_hint=a.dtype.base_dtype, name="b") else: a = ops.convert_to_tensor(a, name="a") @@ -3790,10 +3576,25 @@ def matmul(a, adjoint_b = True if use_batch_matmul_v3: return gen_math_ops.batch_mat_mul_v3( - a, b, adj_x=adjoint_a, adj_y=adjoint_b, Tout=output_type, name=name) + a, + b, + adj_x=adjoint_a, + adj_y=adjoint_b, + Tout=output_type, + grad_x=grad_a, + grad_y=grad_b, + name=name, + ) else: return gen_math_ops.batch_mat_mul_v2( - a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name) + a, + b, + adj_x=adjoint_a, + adj_y=adjoint_b, + grad_x=grad_a, + grad_y=grad_b, + name=name, + ) # Neither matmul nor sparse_matmul support adjoint, so we conjugate # the matrix and use transpose instead. 
Conj() is a noop for real @@ -3837,10 +3638,25 @@ def matmul(a, adjoint_a = adjoint_a or transpose_a adjoint_b = adjoint_b or transpose_b return gen_math_ops.batch_mat_mul_v3( - a, b, adj_x=adjoint_a, adj_y=adjoint_b, Tout=output_type, name=name) + a, + b, + adj_x=adjoint_a, + adj_y=adjoint_b, + Tout=output_type, + grad_x=grad_a, + grad_y=grad_b, + name=name, + ) else: return gen_math_ops.mat_mul( - a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name) + a, + b, + transpose_a=transpose_a, + transpose_b=transpose_b, + grad_a=grad_a, + grad_b=grad_b, + name=name, + ) @tf_export("linalg.matvec") @@ -3884,7 +3700,7 @@ def matvec(a, b = tf.constant([7, 9, 11], shape=[3]) # `a` * `b` - # [ 58, 64] + # [ 58, 139] c = tf.linalg.matvec(a, b) @@ -3950,7 +3766,6 @@ def matmul_wrapper(a, b, name=None): # pylint: disable=missing-function-docstri return a._matmul(b) return matmul(a, b, name=name) matmul_wrapper.__doc__ = matmul.__doc__ -_OverrideBinaryOperatorHelper(matmul_wrapper, "matmul") sparse_matmul = deprecation.deprecated(None, "Use `tf.linalg.matmul` instead")( gen_math_ops.sparse_mat_mul) diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py index 1131ec377fac18..e4599cbb83a5eb 100644 --- a/tensorflow/python/ops/nn_fused_batchnorm_test.py +++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py @@ -406,7 +406,7 @@ def _runtests(self, x_shape, is_training, gradient_test=False, else: data_format_list = ['NCDHW', 'NDHWC'] use_gpu_vals = [False] - if test.is_gpu_available(cuda_only=True) and not cpu_only: + if test.is_gpu_available() and not cpu_only: use_gpu_vals += [True] factors = [1.0, 0.6] for dtype in [np.float16, np.float32, dtypes.bfloat16.as_numpy_dtype]: @@ -594,7 +594,7 @@ def _testBatchNormGradGrad(self, config): data_format_nhwc, features_nhwc = 'NDHWC', shape[4] data_format_nchw, features_nchw = 'NCDHW', shape[1] for is_training in [True, False]: - if test.is_gpu_available(cuda_only=True): + if test.is_gpu_available(): self._test_grad_grad( shape, dtype, [features_nhwc], diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index cee25369963fda..e62eb4c075fcf1 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -2773,9 +2773,6 @@ def loop_fn(i): (fft_ops.rfft2d,), (fft_ops.rfft3d,), ) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, - skip_message="Disable subtest on ROCm due to rocfft issues") def test_rfft(self, op_func): for dtype in (dtypes.float32, dtypes.float64): x = random_ops.random_uniform([2, 3, 4, 3, 4], dtype=dtype) @@ -2794,9 +2791,6 @@ def loop_fn(i): (fft_ops.irfft2d,), (fft_ops.irfft3d,), ) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, - skip_message="Disable subtest on ROCm due to rocfft issues") def test_irfft(self, op_func): if config.list_physical_devices("GPU"): # TODO(b/149957923): The test is flaky diff --git a/tensorflow/python/ops/ragged/__init__.py b/tensorflow/python/ops/ragged/__init__.py index c9d9a79dad753f..457e54641c6953 100644 --- a/tensorflow/python/ops/ragged/__init__.py +++ b/tensorflow/python/ops/ragged/__init__.py @@ -25,3 +25,4 @@ API docstring: tensorflow.ragged """ +from tensorflow.python.ops.ragged import ragged_tensor diff --git a/tensorflow/python/ops/ragged/ragged_cross_op_test.py b/tensorflow/python/ops/ragged/ragged_cross_op_test.py index 
c098c13644f342..ce3dc913f35e3d 100644 --- a/tensorflow/python/ops/ragged/ragged_cross_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_cross_op_test.py @@ -475,7 +475,7 @@ def testRaggedValuesAndSplitsMustMatch(self): def testRaggedCrossInvalidRaggedSplits(self, ragged_splits): # Test case in GitHub isseu 59114. with self.assertRaisesRegex( - (ValueError, errors.InvalidArgumentError), 'Invalid RaggedTensor' + (ValueError, errors.InvalidArgumentError), 'Invalid ragged splits' ): ragged_values_0_tensor = ops.convert_to_tensor(np.ones([3], dtype=str)) ragged_values_0 = array_ops.identity(ragged_values_0_tensor) diff --git a/tensorflow/python/ops/ragged/ragged_getitem_test.py b/tensorflow/python/ops/ragged/ragged_getitem_test.py index f707f9f5620e2e..9c46fb0c9c771e 100644 --- a/tensorflow/python/ops/ragged/ragged_getitem_test.py +++ b/tensorflow/python/ops/ragged/ragged_getitem_test.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import tensor_getitem_override from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor @@ -289,14 +290,16 @@ def testWithStridedSlices(self, start, stop): 'Cannot index into an inner ragged dimension'), # Tests for type errors - (SLICE_BUILDER[0.5], TypeError, re.escape(array_ops._SLICE_TYPE_ERROR)), + (SLICE_BUILDER[0.5], TypeError, re.escape( + tensor_getitem_override._SLICE_TYPE_ERROR)), (SLICE_BUILDER[1:3:0.5], TypeError, re.escape( - array_ops._SLICE_TYPE_ERROR)), + tensor_getitem_override._SLICE_TYPE_ERROR)), (SLICE_BUILDER[:, 1:3:0.5], TypeError, 'slice strides must be integers or None'), (SLICE_BUILDER[:, 0.5:1.5], TypeError, 'slice offsets must be integers or None'), - (SLICE_BUILDER['foo'], TypeError, re.escape(array_ops._SLICE_TYPE_ERROR)), + (SLICE_BUILDER['foo'], TypeError, re.escape( + tensor_getitem_override._SLICE_TYPE_ERROR)), (SLICE_BUILDER[:, 'foo':'foo'], TypeError, 'slice offsets must be integers or None'), diff --git a/tensorflow/python/ops/ragged/ragged_math_ops.py b/tensorflow/python/ops/ragged/ragged_math_ops.py index ef98fb344aed2d..fac49983845728 100644 --- a/tensorflow/python/ops/ragged/ragged_math_ops.py +++ b/tensorflow/python/ops/ragged/ragged_math_ops.py @@ -37,9 +37,9 @@ from tensorflow.python.util.tf_export import tf_export -#=============================================================================== +# =============================================================================== # ragged.range -#=============================================================================== +# =============================================================================== # pylint: disable=redefined-builtin @tf_export('ragged.range') @dispatch.add_dispatch_support @@ -124,9 +124,9 @@ def _infer_matching_dtype(tensors, dtype_hierarchy): ops.no_gradient('RaggedRange') -#=============================================================================== +# =============================================================================== # ragged_segment_ -#=============================================================================== +# =============================================================================== # Docstring template used for the raggged_segment_ ops. 
_RAGGED_SEGMENT_DOCSTRING = """\ @@ -374,9 +374,9 @@ def _set_ragged_segment_docstring(func, combination, combined): _set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)', 'summed') -#=============================================================================== +# =============================================================================== # ragged_reduce_ -#=============================================================================== +# =============================================================================== # Docstring template used for ragged_reduce_ ops. _RAGGED_REDUCE_DOCSTRING = """\ @@ -707,7 +707,8 @@ def reduce_variance(input_tensor: ragged_tensor.Ragged, input_tensor, name='input_tensor') if input_tensor.dtype.is_complex: raise ValueError( - 'reduce_variance is not supported for RaggedTensors with complex dtypes.' + 'reduce_variance is not supported for RaggedTensors with complex' + ' dtypes.' ) square_of_input = math_ops.square(input_tensor) mean_of_square = reduce_mean(square_of_input, axis=axis, keepdims=keepdims) @@ -788,20 +789,24 @@ def _set_ragged_reduce_docstring(func, combination, combined, default, example): _RAGGED_REDUCE_ANY_EXAMPLE) -#=============================================================================== +# =============================================================================== # ragged.matmul -#=============================================================================== +# =============================================================================== @dispatch.dispatch_for_api(math_ops.matmul) -def matmul(a: ragged_tensor.RaggedOrDense, - b: ragged_tensor.RaggedOrDense, - transpose_a=False, - transpose_b=False, - adjoint_a=False, - adjoint_b=False, - a_is_sparse=False, - b_is_sparse=False, - output_type=None, - name=None): +def matmul( + a: ragged_tensor.RaggedOrDense, + b: ragged_tensor.RaggedOrDense, + transpose_a=False, + transpose_b=False, + adjoint_a=False, + adjoint_b=False, + a_is_sparse=False, + b_is_sparse=False, + output_type=None, + grad_a=False, + grad_b=False, + name=None, +): """Multiplies matrix `a` by matrix `b`. If all transpose or adjoint attributes are `False` then: @@ -824,6 +829,8 @@ def matmul(a: ragged_tensor.RaggedOrDense, a_is_sparse: If `True`, optimize assuming `a` is mostly zero. b_is_sparse: If `True`, optimize assuming `b` is mostly zero. output_type: The output datatype (optional). + grad_a: Unused. + grad_b: Unused. name: Name for the operation (optional). Returns: @@ -831,6 +838,8 @@ def matmul(a: ragged_tensor.RaggedOrDense, each inner-most matrix is the product of the corresponding matrices in `a` and `b`. """ + del grad_a + del grad_b if transpose_a and adjoint_a: raise ValueError('Only one of transpose_a and adjoint_a can be True.') if transpose_b and adjoint_b: @@ -1029,9 +1038,9 @@ def _matmul_3d_with_batch_dim_folding(a, b, **kwargs): return a.with_values(array_ops.squeeze(flat_result, axis=1)) -#=============================================================================== +# =============================================================================== # ragged.softmax -#=============================================================================== +# =============================================================================== @dispatch.dispatch_for_api(nn_ops.softmax_v2) def softmax(logits: ragged_tensor.Ragged, axis=None, name=None): """Computes softmax activations. 
@@ -1076,9 +1085,9 @@ def softmax(logits: ragged_tensor.Ragged, axis=None, name=None): return math_ops.divide(logits_exp, denominator) -#=============================================================================== +# =============================================================================== # ragged.add_n -#=============================================================================== +# =============================================================================== @dispatch.dispatch_for_api(math_ops.add_n) def add_n(inputs: typing.List[ragged_tensor.RaggedOrDense], name=None): """RaggedTensor implementation for tf.math.add_n.""" @@ -1088,9 +1097,9 @@ def add_n(inputs: typing.List[ragged_tensor.RaggedOrDense], name=None): return ragged_functional_ops.map_flat_values(math_ops.add_n, inputs) -#=============================================================================== +# =============================================================================== # Ragged version of nn_ops.dropout -#=============================================================================== +# =============================================================================== @dispatch.dispatch_for_api(nn_ops.dropout) def dropout_v1(x: ragged_tensor.Ragged, keep_prob=None, @@ -1140,9 +1149,9 @@ def stateless_dropout(x: ragged_tensor.Ragged, x.flat_values, rate=rate, seed=seed, rng_alg=rng_alg)) -#=============================================================================== +# =============================================================================== # Ragged version of Tensor.__eq__ and Tensor.__ne__ -#=============================================================================== +# =============================================================================== @dispatch.dispatch_for_api(math_ops.tensor_equals) def tensor_equals(self: ragged_tensor.RaggedOrDense, other: ragged_tensor.RaggedOrDense): diff --git a/tensorflow/python/ops/ref_variable.py b/tensorflow/python/ops/ref_variable.py index 7e51288b48ef9d..241275b44da30f 100644 --- a/tensorflow/python/ops/ref_variable.py +++ b/tensorflow/python/ops/ref_variable.py @@ -97,9 +97,6 @@ def default_variable_creator(next_creator=None, **kwargs): shape=shape) -variable_v1.default_variable_creator = default_variable_creator - - def _to_proto_fn(v, export_scope=None): """Converts Variable and ResourceVariable to VariableDef for collections.""" return v.to_proto(export_scope=export_scope) @@ -1346,6 +1343,3 @@ def _restore_from_tensors(self, restored_tensors): # allowing instances of the class to be used as tensors. 
tensor_conversion_registry.register_tensor_conversion_function( RefVariable, RefVariable._TensorConversionFunction) # pylint: disable=protected-access - - -variable_v1.set_variable_from_proto_fn(RefVariable) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 6e1a6b6280b10a..bc5011178cef7c 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -49,7 +49,6 @@ from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import handle_data_util -from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables # go/tf-wildcard-import @@ -59,7 +58,6 @@ from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.trackable import base as trackable from tensorflow.python.types import core -from tensorflow.python.util import _pywrap_utils from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export @@ -372,9 +370,6 @@ def default_variable_creator_v2(next_creator=None, **kwargs): ) -variables.default_variable_creator_v2 = default_variable_creator_v2 - - class BaseResourceVariable(variables.Variable, core.Tensor): """A python variable from an existing handle.""" @@ -2332,10 +2327,6 @@ def __init__( # pylint: disable=super-init-not-called in_graph_mode=self._in_graph_mode, **unused_kwargs) -_pywrap_utils.RegisterType("ResourceVariable", ResourceVariable) -math_ops._resource_variable_type = ResourceVariable # pylint: disable=protected-access - - def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False): return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access @@ -2772,9 +2763,6 @@ def __eq__(self, other): ) -_pywrap_utils.RegisterType("VariableSpec", VariableSpec) - - def write_object_proto_for_resource_variable(resource_variable, proto, options, diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index fcfa8a8b18c260..761f42885ada59 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -28,7 +28,6 @@ from tensorflow.dtensor.python import api as dtensor_api from tensorflow.dtensor.python import layout as layout_lib from tensorflow.python.eager import context -from tensorflow.python.eager import profiler as _profiler from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -43,6 +42,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import summary_op_util from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.profiler import profiler_v2 as _profiler from tensorflow.python.trackable import resource from tensorflow.python.training import training_util from tensorflow.python.util import deprecation @@ -1327,7 +1327,7 @@ def run_metadata_graphs(name, data, step=None): @tf_export("summary.trace_on", v1=[]) -def trace_on(graph=True, profiler=False): # pylint: disable=redefined-outer-name +def trace_on(graph=True, profiler=False, profiler_outdir=None): # pylint: disable=redefined-outer-name """Starts a trace to record computation graphs and profiling information. Must be invoked in eager mode. 
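The change above adds a `profiler_outdir` argument to `tf.summary.trace_on()`; the hunks that follow route profiler start/stop through it and turn the `profiler_outdir` argument of `trace_export()` into a no-op. A rough usage sketch under the new signature (the log directories and step value are illustrative placeholders, not part of this change):

import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/tf_logs")  # placeholder logdir

@tf.function
def square(x):
  return x * x

# The profiler output directory is now given to trace_on() rather than
# trace_export().
tf.summary.trace_on(graph=True, profiler=True, profiler_outdir="/tmp/tf_profile")
square(tf.constant(2.0))
with writer.as_default():
  tf.summary.trace_export(name="square_trace", step=0)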
@@ -1342,12 +1342,13 @@ def trace_on(graph=True, profiler=False): # pylint: disable=redefined-outer-nam Args: graph: If True, enables collection of executed graphs. It includes ones from - tf.function invocation and ones from the legacy graph mode. The default - is True. + tf.function invocation and ones from the legacy graph mode. The default is + True. profiler: If True, enables the advanced profiler. Enabling profiler - implicitly enables the graph collection. The profiler may incur a high - memory overhead. The default is False. - + implicitly enables the graph collection. The profiler may incur a high + memory overhead. The default is False. + profiler_outdir: Output directory for the profiler. It is required when + `profiler` is True, and it is ignored otherwise. """ if ops.inside_function(): logging.warn("Cannot enable trace inside a tf.function.") @@ -1365,12 +1366,22 @@ def trace_on(graph=True, profiler=False): # pylint: disable=redefined-outer-nam if graph and not profiler: context.context().enable_graph_collection() if profiler: - context.context().enable_run_metadata() - _profiler.start() + if profiler_outdir is None: + # TODO(b/149431324): Change this to throw a ValueError when TensorFlow + # major version advances. (current version is 2.15) + logging.warn( + "No `profiler_outdir` passed to trace_on(). Profiler won't be" + " enabled." + ) + else: + context.context().enable_run_metadata() + _profiler.start(profiler_outdir) _current_trace_context = _TraceContext(graph=graph, profiler=profiler) +# TODO(b/149431324): Delete `profiler_outdir` arg when TensorFlow major version +# advances. (current version is 2.15) @tf_export("summary.trace_export", v1=[]) def trace_export(name, step=None, profiler_outdir=None): """Stops and exports the active trace as a Summary and/or profile file. @@ -1383,8 +1394,7 @@ def trace_export(name, step=None, profiler_outdir=None): step: Explicit `int64`-castable monotonic step value for this summary. If omitted, this defaults to `tf.summary.experimental.get_step()`, which must not be None. - profiler_outdir: Output directory for profiler. It is required when profiler - is enabled when trace was started. Otherwise, it is ignored. + profiler_outdir: This arg is a no-op. Please set this in trace_on(). Raises: ValueError: if a default writer exists, but no step was provided and @@ -1406,8 +1416,6 @@ def trace_export(name, step=None, profiler_outdir=None): raise ValueError("Must enable trace before export through " "tf.summary.trace_on.") graph, profiler = _current_trace_context # pylint: disable=redefined-outer-name - if profiler and profiler_outdir is None: - raise ValueError("Argument `profiler_outdir` is not specified.") run_meta = context.context().export_run_metadata() @@ -1417,7 +1425,12 @@ def trace_export(name, step=None, profiler_outdir=None): run_metadata(name, run_meta, step) if profiler: - _profiler.save(profiler_outdir, _profiler.stop()) + if profiler_outdir: + logging.warn( + "Ignoring `profiler_outdir` passed to trace_export(). Please pass it" + " to trace_on() instead."
+ ) + _profiler.stop() trace_off() @@ -1439,7 +1452,8 @@ def trace_off(): if profiler: try: _profiler.stop() - except _profiler.ProfilerNotRunningError: + except Exception as e: # pylint: disable=broad-except + logging.warn("Error while stopping profiler: %s", e) pass diff --git a/tensorflow/python/ops/tensor_getitem_override.py b/tensorflow/python/ops/tensor_getitem_override.py new file mode 100644 index 00000000000000..67d71ae4a3c8f4 --- /dev/null +++ b/tensorflow/python/ops/tensor_getitem_override.py @@ -0,0 +1,314 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Tests for this file live in python/kernel_tests/array_ops_test.py +"""Tensor __getitem__ override logic.""" + +import numbers +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.util import dispatch +from tensorflow.python.util.tf_export import tf_export + + +# We override the 'slice' for the "slice" op, so we keep Python's +# existing 'slice' for later use in this module. +_BaseSlice = slice + + +_SLICE_TYPE_ERROR = ( + "Only integers, slices (`:`), ellipsis (`...`), " + "tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid " + "indices") + + +_SUPPORTED_SLICE_DTYPES = (dtypes.int16, dtypes.int32, dtypes.int32_ref, + dtypes.int64, dtypes.int64_ref) + + +def _is_undefined_dimension(d): + return isinstance(d, tensor_shape.Dimension) and d.value is None + + +def _check_index(idx): + """Check if a given value is a valid index into a tensor.""" + if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)): + return + + # Optimistic check. Assumptions: + # * any object with a dtype is supported + # * any object with a dtype has a sizeable shape attribute. + dtype = getattr(idx, "dtype", None) + if (dtype is None or dtypes.as_dtype(dtype) not in _SUPPORTED_SLICE_DTYPES or + idx.shape and len(idx.shape) == 1): + # TODO(slebedev): IndexError seems more appropriate here, but it + # will break `_slice_helper` contract. + raise TypeError(_SLICE_TYPE_ERROR + ", got {!r}".format(idx)) + + +@tf_export("__operators__.getitem", v1=[]) +@dispatch.add_dispatch_support +def _slice_helper(tensor, slice_spec, var=None): + """Overload for Tensor.__getitem__. + + This operation extracts the specified region from the tensor. + The notation is similar to NumPy with the restriction that + currently only support basic indexing. That means that + using a non-scalar tensor as input is not currently allowed. 
+ + Some useful examples: + + ```python + # Strip leading and trailing 2 elements + foo = tf.constant([1,2,3,4,5,6]) + print(foo[2:-2]) # => [3,4] + + # Skip every other row and reverse the order of the columns + foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]]) + print(foo[::2,::-1]) # => [[3,2,1], [9,8,7]] + + # Use scalar tensors as indices on both dimensions + print(foo[tf.constant(0), tf.constant(2)]) # => 3 + + # Insert another dimension + foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]]) + print(foo[tf.newaxis, :, :]) # => [[[1,2,3], [4,5,6], [7,8,9]]] + print(foo[:, tf.newaxis, :]) # => [[[1,2,3]], [[4,5,6]], [[7,8,9]]] + print(foo[:, :, tf.newaxis]) # => [[[1],[2],[3]], [[4],[5],[6]], + [[7],[8],[9]]] + + # Ellipses (3 equivalent operations) + foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]]) + print(foo[tf.newaxis, :, :]) # => [[[1,2,3], [4,5,6], [7,8,9]]] + print(foo[tf.newaxis, ...]) # => [[[1,2,3], [4,5,6], [7,8,9]]] + print(foo[tf.newaxis]) # => [[[1,2,3], [4,5,6], [7,8,9]]] + + # Masks + foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]]) + print(foo[foo > 2]) # => [3, 4, 5, 6, 7, 8, 9] + ``` + + Notes: + - `tf.newaxis` is `None` as in NumPy. + - An implicit ellipsis is placed at the end of the `slice_spec` + - NumPy advanced indexing is currently not supported. + + Purpose in the API: + + This method is exposed in TensorFlow's API so that library developers + can register dispatching for `Tensor.__getitem__` to allow it to handle + custom composite tensors & other custom objects. + + The API symbol is not intended to be called by users directly and does + appear in TensorFlow's generated documentation. + + Args: + tensor: An tensor.Tensor object. + slice_spec: The arguments to Tensor.__getitem__. + var: In the case of variable slice assignment, the Variable object to slice + (i.e. tensor is the read-only view of this variable). + + Returns: + The appropriate slice of "tensor", based on "slice_spec". + + Raises: + ValueError: If a slice range is negative size. + TypeError: If the slice indices aren't int, slice, ellipsis, + tf.newaxis or scalar int32/int64 tensors. + """ + from tensorflow.python.framework import constant_op # pylint: disable=g-import-not-at-top + from tensorflow.python.ops import array_ops # pylint: disable=g-import-not-at-top + tensor = ops.convert_to_tensor(tensor) + # TODO(wangpeng): Consider supporting var + if var is None and ops._numpy_style_slicing: # pylint: disable=protected-access + return tensor._numpy_style_getitem(slice_spec) # pylint: disable=protected-access + + if (isinstance(slice_spec, bool) + or (isinstance(slice_spec, tensor_lib.Tensor) + and slice_spec.dtype == dtypes.bool) + or (isinstance(slice_spec, np.ndarray) + and slice_spec.dtype == bool)): + return array_ops.boolean_mask(tensor=tensor, mask=slice_spec) + + if not isinstance(slice_spec, (list, tuple)): + slice_spec = [slice_spec] + + begin, end, strides = [], [], [] + index = 0 + + new_axis_mask, shrink_axis_mask = 0, 0 + begin_mask, end_mask = 0, 0 + ellipsis_mask = 0 + for s in slice_spec: + if isinstance(s, _BaseSlice): + # Finds the best dtype for begin, end, and strides. 
+ dtype = None + for t in [s.start, s.stop, s.step]: + if t is None or not isinstance(t, tensor_lib.Tensor): + continue + if t.dtype == dtypes.int64: + dtype = dtypes.int64 + elif t.dtype == dtypes.int32 and dtype != dtypes.int64: + dtype = dtypes.int32 + elif t.dtype == dtypes.int16 and dtype is None: + dtype = dtypes.int16 + + if s.start is not None and not _is_undefined_dimension(s.start): + _check_index(s.start) + begin.append(s.start) + else: + if dtype is not None: + begin.append(constant_op.constant(0, dtype=dtype)) + else: + begin.append(0) + begin_mask |= (1 << index) + if s.stop is not None and not _is_undefined_dimension(s.stop): + _check_index(s.stop) + end.append(s.stop) + else: + if dtype is not None: + end.append(constant_op.constant(0, dtype=dtype)) + else: + end.append(0) + end_mask |= (1 << index) + if s.step is not None and not _is_undefined_dimension(s.step): + _check_index(s.step) + strides.append(s.step) + else: + if dtype is not None: + strides.append(constant_op.constant(1, dtype=dtype)) + else: + strides.append(1) + elif s is Ellipsis: + begin.append(0) + end.append(0) + strides.append(1) + ellipsis_mask |= (1 << index) + elif s is array_ops.newaxis: + begin.append(0) + end.append(0) + strides.append(1) + new_axis_mask |= (1 << index) + else: + _check_index(s) + begin.append(s) + end.append(s + 1) + # TODO(mdan): Investigate why we can't set int32 here. + if ( + isinstance(s, tensor_lib.Tensor) + and (s.dtype == dtypes.int16 or s.dtype == dtypes.int64)): + strides.append(constant_op.constant(1, dtype=s.dtype)) + else: + strides.append(1) + shrink_axis_mask |= (1 << index) + index += 1 + + # stack possibly involves no tensors, so we must use op_scope correct graph. + with ops.name_scope( + None, + "strided_slice", [tensor] + begin + end + strides, + skip_on_eager=False) as name: + if begin: + from tensorflow.python.ops import array_ops_stack # pylint: disable=g-import-not-at-top + packed_begin, packed_end, packed_strides = ( + array_ops_stack.stack(begin), + array_ops_stack.stack(end), + array_ops_stack.stack(strides)) + # TODO(mdan): Instead of implicitly casting, it's better to enforce the + # same dtypes. + if (packed_begin.dtype == dtypes.int64 or + packed_end.dtype == dtypes.int64 or + packed_strides.dtype == dtypes.int64): + if packed_begin.dtype != dtypes.int64: + packed_begin = gen_math_ops.cast(packed_begin, dtypes.int64) + if packed_end.dtype != dtypes.int64: + packed_end = gen_math_ops.cast(packed_end, dtypes.int64) + if packed_strides.dtype != dtypes.int64: + packed_strides = gen_math_ops.cast(packed_strides, dtypes.int64) + elif (packed_begin.dtype == dtypes.int16 and + packed_end.dtype == dtypes.int16 and + packed_strides.dtype == dtypes.int16): + if packed_begin.dtype != dtypes.int16: + packed_begin = gen_math_ops.cast(packed_begin, dtypes.int16) + if packed_end.dtype != dtypes.int16: + packed_end = gen_math_ops.cast(packed_end, dtypes.int16) + if packed_strides.dtype != dtypes.int16: + packed_strides = gen_math_ops.cast(packed_strides, dtypes.int16) + else: + var_empty = constant_op.constant([], dtype=dtypes.int32) + packed_begin = packed_end = packed_strides = var_empty + return array_ops.strided_slice( + tensor, + packed_begin, + packed_end, + packed_strides, + begin_mask=begin_mask, + end_mask=end_mask, + shrink_axis_mask=shrink_axis_mask, + new_axis_mask=new_axis_mask, + ellipsis_mask=ellipsis_mask, + var=var, + name=name) + + +def _slice_helper_var(var, slice_spec): + """Creates a slice helper object given a variable. 
+ + This allows creating a sub-tensor from part of the current contents + of a variable. See `tf.Tensor.__getitem__` for detailed examples + of slicing. + + This function in addition also allows assignment to a sliced range. + This is similar to `__setitem__` functionality in Python. However, + the syntax is different so that the user can capture the assignment + operation for grouping or passing to `sess.run()` in TF1. + For example, + + ```python + import tensorflow as tf + A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32) + print(A[:2, :2]) # => [[1,2], [4,5]] + + A[:2,:2].assign(22. * tf.ones((2, 2)))) + print(A) # => [[22, 22, 3], [22, 22, 6], [7,8,9]] + ``` + + Note that assignments currently do not support NumPy broadcasting + semantics. + + Args: + var: An `ops.Variable` object. + slice_spec: The arguments to `Tensor.__getitem__`. + + Returns: + The appropriate slice of "tensor", based on "slice_spec". + As an operator. The operator also has a `assign()` method + that can be used to generate an assignment operator. + + Raises: + ValueError: If a slice range is negative size. + TypeError: TypeError: If the slice indices aren't int, slice, + ellipsis, tf.newaxis or int32/int64 tensors. + + """ + + return _slice_helper(var.value(), slice_spec, var) + + +tensor_lib.Tensor._override_operator("__getitem__", _slice_helper) # pylint: disable=protected-access diff --git a/tensorflow/python/ops/tensor_math_operator_overrides.py b/tensorflow/python/ops/tensor_math_operator_overrides.py new file mode 100644 index 00000000000000..f94d2a14da8faa --- /dev/null +++ b/tensorflow/python/ops/tensor_math_operator_overrides.py @@ -0,0 +1,168 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Overrides for Tensor operators.""" + + +from tensorflow.python.framework import override_binary_operator +from tensorflow.python.framework import tensor as tensor_lib +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.util import tf_decorator + + +# pylint: disable=g-import-not-at-top +def _add_dispatch_factory(x, y, name=None): + from tensorflow.python.ops import math_ops + + return math_ops._add_dispatch(x, y, name=name) # pylint: disable=protected-access + + +def _and_factory(x, y, name=None): + from tensorflow.python.ops import math_ops + + return math_ops.and_(x, y, name=name) + + +def _div_factory(x, y, name=None): + from tensorflow.python.ops import math_ops + + return math_ops.div(x, y, name=name) + + +def _floordiv_factory(x, y, name=None): + from tensorflow.python.ops import math_ops + + return math_ops.floordiv(x, y, name=name) + + +def _matmul_factory(a, b, name=None): + from tensorflow.python.ops import math_ops + + return math_ops.matmul_wrapper(a, b, name=name) + + +def _mod_factory(x, y, name=None): + from tensorflow.python.ops import math_ops + + return math_ops.mod(x, y, name=name) + + +def _mul_dispatch_factory(x, y, name=None): + from tensorflow.python.ops import math_ops + + return math_ops._mul_dispatch(x, y, name=name) # pylint: disable=protected-access + + +def _or_factory(x, y, name=None): + from tensorflow.python.ops import math_ops + + return math_ops.or_(x, y, name=name) + + +def _pow_factory(x, y, name=None): + from tensorflow.python.ops import math_ops + + return math_ops.pow(x, y, name=name) + + +def _subtract_factory(x, y, name=None): + from tensorflow.python.ops import math_ops + + return math_ops.subtract(x, y, name=name) + + +def _truediv_factory(x, y, name=None): + from tensorflow.python.ops import math_ops + + return math_ops.truediv(x, y, name=name) + + +def _xor_factory(x, y, name=None): + from tensorflow.python.ops import math_ops + + return math_ops.xor_(x, y, name=name) + + +override_binary_operator.override_binary_operator_helper( + _add_dispatch_factory, "add" +) +override_binary_operator.override_binary_operator_helper(_and_factory, "and") +override_binary_operator.override_binary_operator_helper(_div_factory, "div") +override_binary_operator.override_binary_operator_helper( + _floordiv_factory, "floordiv" +) +override_binary_operator.override_binary_operator_helper( + _matmul_factory, "matmul" +) +override_binary_operator.override_binary_operator_helper(_mod_factory, "mod") +override_binary_operator.override_binary_operator_helper( + _mul_dispatch_factory, "mul" +) +override_binary_operator.override_binary_operator_helper(_or_factory, "or") +override_binary_operator.override_binary_operator_helper(_pow_factory, "pow") +override_binary_operator.override_binary_operator_helper( + _subtract_factory, "sub" +) +override_binary_operator.override_binary_operator_helper( + _truediv_factory, "truediv" +) +override_binary_operator.override_binary_operator_helper(_xor_factory, "xor") + + +def _invert_factory(x, name=None): + from tensorflow.python.ops import math_ops + + return math_ops.invert_(x, name=name) + + +def _abs_factory(x, name=None): + from tensorflow.python.ops import math_ops + + return math_ops.abs(x, name=name) + + +def _tensor_equals_factory(self, other): + from tensorflow.python.ops import math_ops + + return math_ops.tensor_equals(self, other) + + +def _tensor_not_equals_factory(self, other): + from tensorflow.python.ops import 
math_ops + + return math_ops.tensor_not_equals(self, other) + + +def _promote_dtypes_decorator(fn): + def wrapper(x, y, *args, **kwargs): + x, y = override_binary_operator.maybe_promote_tensors(x, y) + return fn(x, y, *args, **kwargs) + + return tf_decorator.make_decorator(fn, wrapper) + + +# pylint: disable=protected-access +tensor_lib.Tensor._override_operator("__invert__", _invert_factory) +tensor_lib.Tensor._override_operator("__neg__", gen_math_ops.neg) +tensor_lib.Tensor._override_operator("__abs__", _abs_factory) +tensor_lib.Tensor._override_operator("__lt__", _promote_dtypes_decorator( + gen_math_ops.less)) +tensor_lib.Tensor._override_operator("__le__", _promote_dtypes_decorator( + gen_math_ops.less_equal)) +tensor_lib.Tensor._override_operator("__gt__", _promote_dtypes_decorator( + gen_math_ops.greater)) +tensor_lib.Tensor._override_operator("__ge__", _promote_dtypes_decorator( + gen_math_ops.greater_equal)) +tensor_lib.Tensor._override_operator("__eq__", _tensor_equals_factory) +tensor_lib.Tensor._override_operator("__ne__", _tensor_not_equals_factory) diff --git a/tensorflow/python/ops/variable_v1.py b/tensorflow/python/ops/variable_v1.py index d7d4f0e5daeee9..f3cca80758e5cb 100644 --- a/tensorflow/python/ops/variable_v1.py +++ b/tensorflow/python/ops/variable_v1.py @@ -23,15 +23,6 @@ from tensorflow.python.util.tf_export import tf_export -_variable_from_proto_fn = None - - -def set_variable_from_proto_fn(variable_from_proto_fn): - """Set the variable class that variable proto defs will be converted to.""" - global _variable_from_proto_fn - _variable_from_proto_fn = variable_from_proto_fn - - @tf_export(v1=["is_variable_initialized"]) @tf_should_use.should_use_result def is_variable_initialized(variable): @@ -47,9 +38,12 @@ def is_variable_initialized(variable): return state_ops.is_variable_initialized(variable) -def default_variable_creator(_, **kwds): - del kwds - raise NotImplementedError("ref_variable needs to be imported") +def default_variable_creator(next_creator=None, **kwds): + from tensorflow.python.ops import ref_variable # pylint: disable=g-import-not-at-top + + return ref_variable.default_variable_creator( + next_creator=next_creator, **kwds + ) @tf_export(v1=["Variable"]) @@ -269,7 +263,8 @@ def initialized_value(self): @staticmethod def from_proto(variable_def, import_scope=None): - return _variable_from_proto_fn( + from tensorflow.python.ops import ref_variable # pylint: disable=g-import-not-at-top + return ref_variable.RefVariable( variable_def=variable_def, import_scope=import_scope) @classmethod diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 5208dd1c8229ae..49821d75da445d 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -34,8 +34,8 @@ from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import tensor_getitem_override from tensorflow.python.trackable import base as trackable -from tensorflow.python.util import _pywrap_utils from tensorflow.python.util import object_identity from tensorflow.python.util import tf_should_use from tensorflow.python.util import traceback_utils @@ -44,9 +44,11 @@ from tensorflow.python.util.tf_export import tf_export -def default_variable_creator_v2(_, **kwds): - del kwds - raise NotImplementedError("resource_variable_ops needs to be imported") +def default_variable_creator_v2(next_creator=None, **kwds): + from tensorflow.python.ops import 
resource_variable_ops # pylint: disable=g-import-not-at-top + + return resource_variable_ops.default_variable_creator_v2( + next_creator=next_creator, **kwds) def _make_getter(captured_getter, captured_previous): @@ -984,10 +986,10 @@ def _OverloadAllOperators(cls): # pylint: disable=invalid-name """Register overloads for all operators.""" for operator in tensor_lib.Tensor.OVERLOADABLE_OPERATORS: cls._OverloadOperator(operator) - # For slicing, bind getitem differently than a tensor (use SliceHelperVar + # For slicing, bind getitem differently than a tensor (use _slice_helper_var # instead) # pylint: disable=protected-access - setattr(cls, "__getitem__", array_ops._SliceHelperVar) + setattr(cls, "__getitem__", tensor_getitem_override._slice_helper_var) @classmethod def _OverloadOperator(cls, operator): # pylint: disable=invalid-name @@ -1324,7 +1326,6 @@ def to_proto(self, export_scope=None): Variable._OverloadAllOperators() # pylint: disable=protected-access -_pywrap_utils.RegisterType("Variable", Variable) def _try_guard_against_uninitialized_dependencies(name, initial_value): diff --git a/tensorflow/python/profiler/internal/run_metadata_test.py b/tensorflow/python/profiler/internal/run_metadata_test.py index d95dcb79d1e4fd..f5df743995fb86 100644 --- a/tensorflow/python/profiler/internal/run_metadata_test.py +++ b/tensorflow/python/profiler/internal/run_metadata_test.py @@ -112,9 +112,6 @@ class RunMetadataTest(test.TestCase): # work as expected. Since we now run this test with SOFTWARE_TRACE # (see _run_model routine above), this test will / should fail since # GPU device tracers are not enabled - @test.disable_with_predicate( - pred=test.is_built_with_rocm, - skip_message='Test fails on ROCm when run without FULL_TRACE') @test_util.run_deprecated_v1 def testGPU(self): if not test.is_gpu_available(cuda_only=True): diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 2b522cddc7ce22..3ae6d97eb18ef8 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -646,7 +646,7 @@ cuda_py_strict_test( "//tensorflow/python/trackable:resource", "//tensorflow/python/training:monitored_session", "//tensorflow/python/types:core", - "//tensorflow/python/util:tf_decorator", + "//tensorflow/python/util:tf_inspect", "@absl_py//absl/testing:parameterized", ] + if_google([ "//tensorflow/cc/experimental/tf2:runtime_pybind", @@ -770,7 +770,8 @@ py_strict_library( "//tensorflow/python/ops:resource_variable_ops", "//tensorflow/python/util:compat", "//tensorflow/python/util:nest", - "//tensorflow/python/util:tf_decorator", + "//tensorflow/python/util:tf_decorator_py", + "//tensorflow/python/util:tf_inspect", "@absl_py//absl/logging", ], ) @@ -914,7 +915,6 @@ tf_python_pybind_extension( # "//tensorflow:windows": [], # }), # static_deps = tf_python_pybind_static_deps(), - features = ["-layering_check"], pytype_srcs = [ "pywrap_saved_model/__init__.pyi", "pywrap_saved_model/constants.pyi", @@ -928,17 +928,27 @@ tf_python_pybind_extension( "//tensorflow/python/training:__subpackages__", ], deps = [ - ":pywrap_saved_model_headers", # placeholder for index annotation deps "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "//tensorflow/cc/experimental/libexport:save", + "//tensorflow/cc/saved_model:constants", + "//tensorflow/cc/saved_model:fingerprinting", + "//tensorflow/cc/saved_model:metrics", "//tensorflow/cc/saved_model:reader", + 
"//tensorflow/core:core_cpu_base", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:path", "//tensorflow/python/lib/core:pybind11_status", "@pybind11", + "@pybind11_abseil//pybind11_abseil:absl_casters", "@pybind11_abseil//pybind11_abseil:status_casters", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", - ], + ] + if_google([ + "//tensorflow/tools/proto_splitter:merge", + ]), ) tf_py_strict_test( diff --git a/tensorflow/python/saved_model/fingerprinting.py b/tensorflow/python/saved_model/fingerprinting.py index dd8be59cfaa694..30d9dc76987b88 100644 --- a/tensorflow/python/saved_model/fingerprinting.py +++ b/tensorflow/python/saved_model/fingerprinting.py @@ -18,13 +18,15 @@ fingerprint. """ +from typing import Any + from tensorflow.core.protobuf import fingerprint_pb2 from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting as fingerprinting_pywrap from tensorflow.python.util.tf_export import tf_export @tf_export("saved_model.experimental.Fingerprint", v1=[]) -class Fingerprint(object): +class Fingerprint: """The SavedModel fingerprint. Each attribute of this class is named after a field name in the @@ -42,12 +44,12 @@ class Fingerprint(object): def __init__( self, - saved_model_checksum=None, - graph_def_program_hash=None, - signature_def_hash=None, - saved_object_graph_hash=None, - checkpoint_hash=None, - version=None, + saved_model_checksum: int = None, + graph_def_program_hash: int = None, + signature_def_hash: int = None, + saved_object_graph_hash: int = None, + checkpoint_hash: int = None, + version: int = None, ): """Initializes the instance based on values in the SavedModel fingerprint. @@ -67,7 +69,7 @@ def __init__( self.version = version @classmethod - def from_proto(cls, proto): + def from_proto(cls, proto: fingerprint_pb2.FingerprintDef) -> "Fingerprint": """Constructs Fingerprint object from protocol buffer message.""" if isinstance(proto, bytes): proto = fingerprint_pb2.FingerprintDef.FromString(proto) @@ -84,7 +86,7 @@ def from_proto(cls, proto): f"Given proto could not be deserialized as fingerprint." f"{e}") from None - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if (isinstance(other, Fingerprint) or isinstance(other, fingerprint_pb2.FingerprintDef)): try: @@ -98,9 +100,9 @@ def __eq__(self, other): pass return False - def __str__(self): + def __str__(self) -> str: return "\n".join([ - f"SavedModel Fingerprint", + "SavedModel Fingerprint", f" saved_model_checksum: {self.saved_model_checksum}", f" graph_def_program_hash: {self.graph_def_program_hash}", f" signature_def_hash: {self.signature_def_hash}", @@ -108,14 +110,14 @@ def __str__(self): f" checkpoint_hash: {self.checkpoint_hash}" ]) - def __repr__(self): + def __repr__(self) -> str: return (f"Fingerprint({self.saved_model_checksum}, " f"{self.graph_def_program_hash}, " f"{self.signature_def_hash}, " f"{self.saved_object_graph_hash}, " f"{self.checkpoint_hash})") - def singleprint(self): + def singleprint(self) -> fingerprinting_pywrap.Singleprint: """Canonical fingerprinting ID for a SavedModel. Uniquely identifies a SavedModel based on the regularized fingerprint @@ -147,7 +149,7 @@ def singleprint(self): @tf_export("saved_model.experimental.read_fingerprint", v1=[]) -def read_fingerprint(export_dir): +def read_fingerprint(export_dir: str) -> Fingerprint: """Reads the fingerprint of a SavedModel in `export_dir`. 
Returns a `tf.saved_model.experimental.Fingerprint` object that contains diff --git a/tensorflow/python/saved_model/fingerprinting_utils.py b/tensorflow/python/saved_model/fingerprinting_utils.py index 67ebdd33cd7704..cb31860ed81bd3 100644 --- a/tensorflow/python/saved_model/fingerprinting_utils.py +++ b/tensorflow/python/saved_model/fingerprinting_utils.py @@ -32,7 +32,7 @@ FingerprintException = fingerprinting_pywrap.FingerprintException -def write_fingerprint(export_dir): +def write_fingerprint(export_dir: str) -> None: """Write fingerprint protobuf, if requested. Writes a `tf.saved_model.experimental.Fingerprint` object to a @@ -66,7 +66,7 @@ def write_fingerprint(export_dir): "Model saving will continue.") -def singleprint_from_saved_model_proto(export_dir): +def singleprint_from_saved_model_proto(export_dir: str) -> str: """Returns the singleprint of `saved_model.pb` in `export_dir`. Args: @@ -85,7 +85,7 @@ def singleprint_from_saved_model_proto(export_dir): raise ValueError(e) from None -def singleprint_from_fingerprint_proto(export_dir): +def singleprint_from_fingerprint_proto(export_dir: str) -> str: """Returns the singleprint of `fingerprint.pb` in `export_dir`. Args: @@ -104,7 +104,7 @@ def singleprint_from_fingerprint_proto(export_dir): raise ValueError(e) from None -def singleprint_from_saved_model(export_dir): +def singleprint_from_saved_model(export_dir: str) -> str: """Returns the singleprint of the SavedModel in `export_dir`. First tries to construct the singleprint from `fingerprint.pb`, then from @@ -141,9 +141,8 @@ def singleprint_from_saved_model(export_dir): raise ValueError(e) from None -def to_proto(fingerprint): - if not isinstance(fingerprint, fingerprinting.Fingerprint): - raise TypeError("Supplied value is not a Fingerprint.") +def to_proto( + fingerprint: fingerprinting.Fingerprint) -> fingerprint_pb2.FingerprintDef: return fingerprint_pb2.FingerprintDef( saved_model_checksum=fingerprint.saved_model_checksum, graph_def_program_hash=fingerprint.graph_def_program_hash, diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index 8f498936328e86..a08d41f5fcf499 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -2970,7 +2970,7 @@ def increment_v(x): # TODO(allenl, kkb): Use the new memory checker here once it's fast enough (3 # iterations took hundreds of seconds). It would be really nice to check # allocations at a lower level. 
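The fingerprinting helpers above (`read_fingerprint`, `Fingerprint.singleprint`) now carry explicit type hints. A minimal usage sketch, assuming a SavedModel with a `fingerprint.pb` already exists at the hypothetical path below:

```python
import tensorflow as tf

# Hypothetical export directory; assumes the SavedModel was written with
# fingerprinting enabled so that fingerprint.pb is present.
fp = tf.saved_model.experimental.read_fingerprint("/tmp/my_saved_model")
print(fp)                # readable dump of the fingerprint fields
print(fp.singleprint())  # canonical ID derived from the regularized fingerprint
```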
- @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def test_functions_cleaned(self, use_cpp_bindings): # TODO(b/264869753) Fix SingleCycleTest if use_cpp_bindings: diff --git a/tensorflow/python/saved_model/registration/registration_saving_test.py b/tensorflow/python/saved_model/registration/registration_saving_test.py index ec87f06c0a5e10..8e60cc2bf9e122 100644 --- a/tensorflow/python/saved_model/registration/registration_saving_test.py +++ b/tensorflow/python/saved_model/registration/registration_saving_test.py @@ -223,7 +223,7 @@ def test_registered_saver(self, cycles): class SingleCycleTest(test.TestCase): - @test_util.deprecated_graph_mode_only() + @test_util.deprecated_graph_mode_only def test_registered_saver_fails_in_saved_model_graph_mode(self): with context.eager_mode(): p1 = Part([1, 4]) diff --git a/tensorflow/python/summary/BUILD b/tensorflow/python/summary/BUILD index 5ed5b0f74dc2df..7af5c5cb277ae8 100644 --- a/tensorflow/python/summary/BUILD +++ b/tensorflow/python/summary/BUILD @@ -37,6 +37,7 @@ py_strict_library( srcs = ["summary.py"], visibility = ["//visibility:public"], deps = [ + ":tb_summary", "//tensorflow/core:protos_all_py", "//tensorflow/python/distribute:summary_op_util", "//tensorflow/python/eager:context", @@ -124,3 +125,10 @@ tf_py_strict_test( "@pypi_tb_nightly//:pkg", ], ) + +py_strict_library( + name = "tb_summary", + srcs = ["tb_summary.py"], + visibility = ["//tensorflow:internal"], + deps = ["//tensorflow/python/util:tf_export"], +) diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py index 161456a7aecae0..b6112b1d7db1d9 100644 --- a/tensorflow/python/summary/summary.py +++ b/tensorflow/python/summary/summary.py @@ -46,7 +46,7 @@ from tensorflow.python.ops import gen_summary_ops as _gen_summary_ops # pylint: disable=unused-import from tensorflow.python.ops import summary_op_util as _summary_op_util from tensorflow.python.ops import summary_ops_v2 as _summary_ops_v2 - +from tensorflow.python.summary import tb_summary # exports FileWriter, FileWriterCache # pylint: disable=unused-import from tensorflow.python.summary.writer.writer import FileWriter @@ -124,9 +124,8 @@ def scalar(name, tensor, collections=None, family=None): if _should_invoke_v2_op(): # Defer the import to happen inside the symbol to prevent breakage due to # missing dependency. - from tensorboard.summary.v2 import scalar as scalar_v2 # pylint: disable=g-import-not-at-top with _compat_summary_scope(name, family) as tag: - scalar_v2(name=tag, data=tensor, step=_get_step_for_v2()) + tb_summary.scalar(name=tag, data=tensor, step=_get_step_for_v2()) # Return an empty Tensor, which will be acceptable as an input to the # `tf.compat.v1.summary.merge()` API. return _constant_op.constant(b'') @@ -235,9 +234,8 @@ def image(name, tensor, max_outputs=3, collections=None, family=None): if _should_invoke_v2_op(): # Defer the import to happen inside the symbol to prevent breakage due to # missing dependency. - from tensorboard.summary.v2 import image as image_v2 # pylint: disable=g-import-not-at-top with _compat_summary_scope(name, family) as tag: - image_v2( + tb_summary.image( name=tag, data=tensor, step=_get_step_for_v2(), @@ -330,9 +328,8 @@ def histogram(name, values, collections=None, family=None): if _should_invoke_v2_op(): # Defer the import to happen inside the symbol to prevent breakage due to # missing dependency. 
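With the summary.py hunks above, the TF1 summary ops no longer import `tensorboard.summary.v2` inline; they forward through the new `tb_summary` shim when the v2 code path is taken. A rough usage sketch, assuming TF2 eager execution with a default writer and an explicit step (the conditions `_should_invoke_v2_op` checks); the log directory is hypothetical:

```python
import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/logs")
with writer.as_default(step=0):
    # Forwarded to tb_summary.scalar; the v1 op returns an empty string tensor
    # that tf.compat.v1.summary.merge() still accepts.
    tf.compat.v1.summary.scalar("loss", 0.5)
```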
- from tensorboard.summary.v2 import histogram as histogram_v2 # pylint: disable=g-import-not-at-top with _compat_summary_scope(name, family) as tag: - histogram_v2(name=tag, data=values, step=_get_step_for_v2()) + tb_summary.histogram(name=tag, data=values, step=_get_step_for_v2()) # Return an empty Tensor, which will be acceptable as an input to the # `tf.compat.v1.summary.merge()` API. return _constant_op.constant(b'') @@ -440,12 +437,11 @@ def audio(name, tensor, sample_rate, max_outputs=3, collections=None, if _should_invoke_v2_op(): # Defer the import to happen inside the symbol to prevent breakage due to # missing dependency. - from tensorboard.summary.v2 import audio as audio_v2 # pylint: disable=g-import-not-at-top if tensor.shape.rank == 2: # TF2 op requires 3-D tensor, add the `channels` dimension. tensor = _array_ops.expand_dims_v2(tensor, axis=2) with _compat_summary_scope(name, family) as tag: - audio_v2( + tb_summary.audio( name=tag, data=tensor, sample_rate=sample_rate, @@ -540,8 +536,7 @@ def text(name, tensor, collections=None): return _constant_op.constant('') # Defer the import to happen inside the symbol to prevent breakage due to # missing dependency. - from tensorboard.summary.v2 import text as text_v2 # pylint: disable=g-import-not-at-top - text_v2(name=name, data=tensor, step=_get_step_for_v2()) + tb_summary.text(name=name, data=tensor, step=_get_step_for_v2()) # Return an empty Tensor, which will be acceptable as an input to the # `tf.compat.v1.summary.merge()` API. return _constant_op.constant(b'') diff --git a/tensorflow/python/summary/summary_v2_test.py b/tensorflow/python/summary/summary_v2_test.py index d6454b46893f05..6e3721b311f209 100644 --- a/tensorflow/python/summary/summary_v2_test.py +++ b/tensorflow/python/summary/summary_v2_test.py @@ -43,7 +43,9 @@ def test_scalar_summary_v2__w_writer(self): # Returns empty string. self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) - mock_scalar_v2.assert_called_once_with('float', data=i, step=1) + mock_scalar_v2.assert_called_once_with( + name='float', data=i, step=1, description=test.mock.ANY + ) @test_util.run_v2_only def test_scalar_summary_v2__wo_writer(self): @@ -79,7 +81,11 @@ def test_scalar_summary_v2__family(self): self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) mock_scalar_v2.assert_called_once_with( - 'otter/otter/float', data=constant_op.constant(2.5), step=1) + name='otter/otter/float', + data=constant_op.constant(2.5), + step=1, + description=test.mock.ANY, + ) @test_util.run_v2_only def test_scalar_summary_v2__family_w_outer_scope(self): @@ -95,7 +101,11 @@ def test_scalar_summary_v2__family_w_outer_scope(self): self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) mock_scalar_v2.assert_called_once_with( - 'crabnet/sea/crabnet/float', data=constant_op.constant(3.5), step=1) + name='crabnet/sea/crabnet/float', + data=constant_op.constant(3.5), + step=1, + description=test.mock.ANY, + ) @test_util.run_v2_only def test_scalar_summary_v2__v1_set_step(self): @@ -111,7 +121,9 @@ def test_scalar_summary_v2__v1_set_step(self): # Returns empty string. 
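The surrounding summary_v2_test.py updates pass every argument to the mocked TensorBoard op by keyword and use `test.mock.ANY` as a wildcard for values the test does not pin down (for example `description`). A self-contained sketch of the same idiom with plain `unittest.mock`:

```python
from unittest import mock

summary_fn = mock.Mock()
summary_fn(name="loss", data=0.5, step=1, description=None)

# mock.ANY matches whatever value was actually passed for `description`.
summary_fn.assert_called_once_with(
    name="loss", data=0.5, step=1, description=mock.ANY
)
```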
self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) - mock_scalar_v2.assert_called_once_with('float', data=i, step=1024) + mock_scalar_v2.assert_called_once_with( + name='float', data=i, step=1024, description=test.mock.ANY + ) @test_util.run_v2_only def test_image_summary_v2(self): @@ -127,7 +139,12 @@ def test_image_summary_v2(self): self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) mock_image_v2.assert_called_once_with( - 'family/outer/family/image', data=i, step=2, max_outputs=3) + name='family/outer/family/image', + data=i, + step=2, + max_outputs=3, + description=test.mock.ANY, + ) @test_util.run_v2_only def test_histogram_summary_v2(self): @@ -142,7 +159,12 @@ def test_histogram_summary_v2(self): self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) mock_histogram_v2.assert_called_once_with( - 'family/family/histogram', data=i, step=3) + name='family/family/histogram', + data=i, + step=3, + buckets=test.mock.ANY, + description=test.mock.ANY, + ) @test_util.run_v2_only def test_audio_summary_v2(self): @@ -158,7 +180,14 @@ def test_audio_summary_v2(self): self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) mock_audio_v2.assert_called_once_with( - 'dolphin/wave', data=i, sample_rate=0.2, step=10, max_outputs=3) + name='dolphin/wave', + data=i, + sample_rate=0.2, + step=10, + max_outputs=3, + encoding=test.mock.ANY, + description=test.mock.ANY, + ) @test_util.run_v2_only def test_audio_summary_v2__2d_tensor(self): @@ -175,7 +204,14 @@ def test_audio_summary_v2__2d_tensor(self): self.assertEqual(tensor.dtype, dtypes.string) mock_audio_v2.assert_called_once_with( - 'wave', data=test.mock.ANY, sample_rate=0.2, step=11, max_outputs=3) + name='wave', + data=test.mock.ANY, + sample_rate=0.2, + step=11, + max_outputs=3, + encoding=test.mock.ANY, + description=test.mock.ANY, + ) input_3d = array_ops.ones((5, 3, 1)) # 3-D input tensor self.assertAllEqual(mock_audio_v2.call_args[1]['data'], input_3d) @@ -191,7 +227,9 @@ def test_text_summary_v2(self): # Returns empty string. self.assertEqual(tensor.numpy(), b'') self.assertEqual(tensor.dtype, dtypes.string) - mock_text_v2.assert_called_once_with('text', data=i, step=22) + mock_text_v2.assert_called_once_with( + name='text', data=i, step=22, description=test.mock.ANY + ) if __name__ == '__main__': diff --git a/tensorflow/python/summary/tb_summary.py b/tensorflow/python/summary/tb_summary.py new file mode 100644 index 00000000000000..682ca5a2b7e1dd --- /dev/null +++ b/tensorflow/python/summary/tb_summary.py @@ -0,0 +1,374 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Re-exports the APIs of TF2 summary that live in TensorBoard.""" + +from tensorflow.python.util.tf_export import tf_export + +_TENSORBOARD_NOT_INSTALLED_ERROR = ( + "TensorBoard is not installed, missing implementation for" +) + + +class TBNotInstalledError(Exception): + + def __init__(self, summary_api): + self.error_message = f"{_TENSORBOARD_NOT_INSTALLED_ERROR} {summary_api}" + super().__init__(self.error_message) + + +@tf_export("summary.audio", v1=[]) +def audio( + name, + data, + sample_rate, + step=None, + max_outputs=3, + encoding=None, + description=None, +): + """Write an audio summary. + + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will be + this name prefixed by any active name scopes. + data: A `Tensor` representing audio data with shape `[k, t, c]`, where `k` + is the number of audio clips, `t` is the number of frames, and `c` is the + number of channels. Elements should be floating-point values in `[-1.0, + 1.0]`. Any of the dimensions may be statically unknown (i.e., `None`). + sample_rate: An `int` or rank-0 `int32` `Tensor` that represents the sample + rate, in Hz. Must be positive. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this many + audio clips will be emitted at each step. When more than `max_outputs` + many clips are provided, the first `max_outputs` many clips will be used + and the rest silently discarded. + encoding: Optional constant `str` for the desired encoding. Only "wav" is + currently supported, but this is not guaranteed to remain the default, so + if you want "wav" in particular, set this explicitly. + description: Optional long-form description for this summary, as a constant + `str`. Markdown is supported. Defaults to empty. + + Returns: + True on success, or false if no summary was emitted because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. + """ + try: + from tensorboard.summary.v2 import audio as audio_v2 # pylint: disable=g-import-not-at-top, g-importing-member + except ImportError as exc: + raise TBNotInstalledError("tf.summary.audio") from exc + return audio_v2( + name=name, + data=data, + sample_rate=sample_rate, + step=step, + max_outputs=max_outputs, + encoding=encoding, + description=description, + ) + + +@tf_export("summary.histogram", v1=[]) +def histogram(name, data, step=None, buckets=None, description=None): + """Write a histogram summary. + + See also `tf.summary.scalar`, `tf.summary.SummaryWriter`. + + Writes a histogram to the current default summary writer, for later analysis + in TensorBoard's 'Histograms' and 'Distributions' dashboards (data written + using this API will appear in both places). Like `tf.summary.scalar` points, + each histogram is associated with a `step` and a `name`. All the histograms + with the same `name` constitute a time series of histograms. + + The histogram is calculated over all the elements of the given `Tensor` + without regard to its shape or rank. 
+ + This example writes 2 histograms: + + ```python + w = tf.summary.create_file_writer('test/logs') + with w.as_default(): + tf.summary.histogram("activations", tf.random.uniform([100, 50]), step=0) + tf.summary.histogram("initial_weights", tf.random.normal([1000]), step=0) + ``` + + A common use case is to examine the changing activation patterns (or lack + thereof) at specific layers in a neural network, over time. + + ```python + w = tf.summary.create_file_writer('test/logs') + with w.as_default(): + for step in range(100): + # Generate fake "activations". + activations = [ + tf.random.normal([1000], mean=step, stddev=1), + tf.random.normal([1000], mean=step, stddev=10), + tf.random.normal([1000], mean=step, stddev=100), + ] + + tf.summary.histogram("layer1/activate", activations[0], step=step) + tf.summary.histogram("layer2/activate", activations[1], step=step) + tf.summary.histogram("layer3/activate", activations[2], step=step) + ``` + + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will be + this name prefixed by any active name scopes. + data: A `Tensor` of any shape. The histogram is computed over its elements, + which must be castable to `float64`. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + buckets: Optional positive `int`. The output will have this many buckets, + except in two edge cases. If there is no data, then there are no buckets. + If there is data but all points have the same value, then all buckets' + left and right endpoints are the same and only the last bucket has nonzero + count. Defaults to 30 if not specified. + description: Optional long-form description for this summary, as a constant + `str`. Markdown is supported. Defaults to empty. + + Returns: + True on success, or false if no summary was emitted because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. + """ + try: + from tensorboard.summary.v2 import histogram as histogram_v2 # pylint: disable=g-import-not-at-top, g-importing-member + except ImportError as exc: + raise TBNotInstalledError("tf.summary.histogram") from exc + return histogram_v2( + name=name, data=data, step=step, buckets=buckets, description=description + ) + + +@tf_export("summary.image", v1=[]) +def image(name, data, step=None, max_outputs=3, description=None): + """Write an image summary. + + See also `tf.summary.scalar`, `tf.summary.SummaryWriter`. + + Writes a collection of images to the current default summary writer. Data + appears in TensorBoard's 'Images' dashboard. Like `tf.summary.scalar` points, + each collection of images is associated with a `step` and a `name`. All the + image collections with the same `name` constitute a time series of image + collections. + + This example writes 2 random grayscale images: + + ```python + w = tf.summary.create_file_writer('test/logs') + with w.as_default(): + image1 = tf.random.uniform(shape=[8, 8, 1]) + image2 = tf.random.uniform(shape=[8, 8, 1]) + tf.summary.image("grayscale_noise", [image1, image2], step=0) + ``` + + To avoid clipping, data should be converted to one of the following: + + - floating point values in the range [0,1], or + - uint8 values in the range [0,255] + + ```python + # Convert the original dtype=int32 `Tensor` into `dtype=float64`. 
+ rgb_image_float = tf.constant([ + [[1000, 0, 0], [0, 500, 1000]], + ]) / 1000 + tf.summary.image("picture", [rgb_image_float], step=0) + + # Convert original dtype=uint8 `Tensor` into proper range. + rgb_image_uint8 = tf.constant([ + [[1, 1, 0], [0, 0, 1]], + ], dtype=tf.uint8) * 255 + tf.summary.image("picture", [rgb_image_uint8], step=1) + ``` + + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will be + this name prefixed by any active name scopes. + data: A `Tensor` representing pixel data with shape `[k, h, w, c]`, where + `k` is the number of images, `h` and `w` are the height and width of the + images, and `c` is the number of channels, which should be 1, 2, 3, or 4 + (grayscale, grayscale with alpha, RGB, RGBA). Any of the dimensions may be + statically unknown (i.e., `None`). Floating point data will be clipped to + the range [0,1]. Other data types will be clipped into an allowed range + for safe casting to uint8, using `tf.image.convert_image_dtype`. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this many + images will be emitted at each step. When more than `max_outputs` many + images are provided, the first `max_outputs` many images will be used and + the rest silently discarded. + description: Optional long-form description for this summary, as a constant + `str`. Markdown is supported. Defaults to empty. + + Returns: + True on success, or false if no summary was emitted because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. + """ + try: + from tensorboard.summary.v2 import image as image_v2 # pylint: disable=g-import-not-at-top, g-importing-member + except ImportError as exc: + raise TBNotInstalledError("tf.summary.image") from exc + return image_v2( + name=name, + data=data, + step=step, + max_outputs=max_outputs, + description=description, + ) + + +@tf_export("summary.scalar", v1=[]) +def scalar(name, data, step=None, description=None): + """Write a scalar summary. + + See also `tf.summary.image`, `tf.summary.histogram`, + `tf.summary.SummaryWriter`. + + Writes simple numeric values for later analysis in TensorBoard. Writes go to + the current default summary writer. Each summary point is associated with an + integral `step` value. This enables the incremental logging of time series + data. A common usage of this API is to log loss during training to produce + a loss curve. + + For example: + + ```python + test_summary_writer = tf.summary.create_file_writer('test/logdir') + with test_summary_writer.as_default(): + tf.summary.scalar('loss', 0.345, step=1) + tf.summary.scalar('loss', 0.234, step=2) + tf.summary.scalar('loss', 0.123, step=3) + ``` + + Multiple independent time series may be logged by giving each series a unique + `name` value. + + See [Get started with + TensorBoard](https://www.tensorflow.org/tensorboard/get_started) + for more examples of effective usage of `tf.summary.scalar`. + + In general, this API expects that data points are logged with a monotonically + increasing step value. Duplicate points for a single step or points logged out + of order by step are not guaranteed to display as desired in TensorBoard. + + Arguments: + name: A name for this summary. 
The summary tag used for TensorBoard will be + this name prefixed by any active name scopes. + data: A real numeric scalar value, convertible to a `float32` Tensor. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + description: Optional long-form description for this summary, as a constant + `str`. Markdown is supported. Defaults to empty. + + Returns: + True on success, or false if no summary was written because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. + """ + try: + from tensorboard.summary.v2 import scalar as scalar_v2 # pylint: disable=g-import-not-at-top, g-importing-member + except ImportError as exc: + raise TBNotInstalledError("tf.summary.scalar") from exc + return scalar_v2(name=name, data=data, step=step, description=description) + + +@tf_export("summary.text", v1=[]) +def text(name, data, step=None, description=None): + r"""Write a text summary. + + See also `tf.summary.scalar`, `tf.summary.SummaryWriter`, `tf.summary.image`. + + Writes text Tensor values for later visualization and analysis in TensorBoard. + Writes go to the current default summary writer. Like `tf.summary.scalar` + points, text points are each associated with a `step` and a `name`. + All the points with the same `name` constitute a time series of text values. + + For Example: + ```python + test_summary_writer = tf.summary.create_file_writer('test/logdir') + with test_summary_writer.as_default(): + tf.summary.text('first_text', 'hello world!', step=0) + tf.summary.text('first_text', 'nice to meet you!', step=1) + ``` + + The text summary can also contain Markdown, and TensorBoard will render the + text + as such. + + ```python + with test_summary_writer.as_default(): + text_data = ''' + | *hello* | *there* | + |---------|---------| + | this | is | + | a | table | + ''' + text_data = '\n'.join(l.strip() for l in text_data.splitlines()) + tf.summary.text('markdown_text', text_data, step=0) + ``` + + Since text is Tensor valued, each text point may be a Tensor of string values. + rank-1 and rank-2 Tensors are rendered as tables in TensorBoard. For higher + ranked + Tensors, you'll see just a 2D slice of the data. To avoid this, reshape the + Tensor + to at most rank-2 prior to passing it to this function. + + Demo notebook at + ["Displaying text data in + TensorBoard"](https://www.tensorflow.org/tensorboard/text_summaries). + + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will be + this name prefixed by any active name scopes. + data: A UTF-8 string Tensor value. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + description: Optional long-form description for this summary, as a constant + `str`. Markdown is supported. Defaults to empty. + + Returns: + True on success, or false if no summary was emitted because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. 
+ """ + try: + from tensorboard.summary.v2 import text as text_v2 # pylint: disable=g-import-not-at-top, g-importing-member + except ImportError as exc: + raise TBNotInstalledError("tf.summary.text") from exc + return text_v2(name=name, data=data, step=step, description=description) diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index cae983b25dfb02..21fecca23371bc 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -858,6 +858,23 @@ PYBIND11_MODULE(_pywrap_tfe, m) { // TODO(b/309152522): Remove the switch once it works on Windows. #if !IS_OSS pybind11_protobuf::ImportNativeProtoCasters(); + m.def( + "TFE_ContextAddFunctionDefNoSerialization", + [](py::handle& ctx, tensorflow::FunctionDef function_def) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + // Annotate eager runtime construction context to the given + // `function_def` as an attribute. + tensorflow::AttrValue value; + SetAttrValue("kEagerRuntime", &value); + (*function_def.mutable_attr())["_construction_context"] = value; + status->status = tensorflow::unwrap(tensorflow::InputTFE_Context(ctx)) + ->AddFunctionDef(function_def); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return; + }, + pybind11::arg("ctx"), pybind11::arg("function_def")); + m.def("TFE_ContextGetFunctionDefNoSerialization", [](py::handle& ctx, const char* function_name) -> tensorflow::FunctionDef { @@ -885,6 +902,14 @@ PYBIND11_MODULE(_pywrap_tfe, m) { LOG(FATAL) << "This function cannot be called."; return -1; }); + m.def("TFE_ContextAddFunctionDefNoSerialization", + // Opensource fails whenever a protobuf is used as argument. The + // disrepency in the type is to make opensource tests pass. + [](py::handle& ctx, int function_def) { + LOG(FATAL) << "This function cannot be called."; + return -1; + }); + #endif m.def("TFE_ContextGetGraphDebugInfo", [](py::handle& ctx, const char* function_name, TF_Buffer& buf) { diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl index 7ea32b6cb51e57..763a5f241581b7 100644 --- a/tensorflow/python/tools/api/generator/api_gen.bzl +++ b/tensorflow/python/tools/api/generator/api_gen.bzl @@ -133,7 +133,7 @@ def gen_api_init_files( srcs_version = "PY3", visibility = ["//visibility:public"], deps = package_deps + [ - "//tensorflow/python/util:tf_decorator", + "//tensorflow/python/util:tf_decorator_py", "//tensorflow/python/util:tf_export", "//tensorflow/python/util:module_wrapper", "//tensorflow/python/tools/api/generator:doc_srcs", diff --git a/tensorflow/python/tools/api/generator2/generate_api.bzl b/tensorflow/python/tools/api/generator2/generate_api.bzl index 64e9b96276eebe..c2a96438576d22 100644 --- a/tensorflow/python/tools/api/generator2/generate_api.bzl +++ b/tensorflow/python/tools/api/generator2/generate_api.bzl @@ -1,5 +1,6 @@ """Rules to generate the TensorFlow public API from annotated files.""" +# Placeholder: load PyInfo load("@bazel_skylib//lib:paths.bzl", "paths") load("//tensorflow/python/tools/api/generator:api_init_files.bzl", "TENSORFLOW_API_INIT_FILES") load(":apis.bzl", _APIS = "APIS") diff --git a/tensorflow/python/tools/print_selective_registration_header.py b/tensorflow/python/tools/print_selective_registration_header.py index 8ae04c137e4eb4..6809ea62f51513 100644 --- a/tensorflow/python/tools/print_selective_registration_header.py +++ b/tensorflow/python/tools/print_selective_registration_header.py @@ -32,10 +32,17 @@ """ 
import argparse +import contextlib import sys from absl import app -from tensorflow.python.tools import selective_registration_header_lib + +# The import statement prints "Using TensorFlow backend", which gets piped to +# ops_to_register.h. Redirect the import's stdout to /dev/null to avoid this. +with open('/dev/null', 'w') as f, contextlib.redirect_stdout(f): + # pylint: disable=g-import-not-at-top + from tensorflow.python.tools import selective_registration_header_lib + # pylint: enable=g-import-not-at-top FLAGS = None diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD index d3d06e05502fdc..751078f59d56b9 100644 --- a/tensorflow/python/tpu/BUILD +++ b/tensorflow/python/tpu/BUILD @@ -99,7 +99,6 @@ tpu_py_strict_test( disable_mlir_bridge = False, deps = [ ":async_checkpoint", - ":tpu_estimator", ":tpu_lib", "//tensorflow/core:protos_all_py", "//tensorflow/python/compat:v2_compat", @@ -169,42 +168,6 @@ py_strict_library( ], ) -py_strict_library( - name = "tpu_estimator", - srcs = [ - "error_handling.py", - "tpu_config.py", - "tpu_context.py", - "tpu_estimator.py", - "util.py", - ], - srcs_version = "PY3", - deps = [ - ":async_checkpoint", - ":feature_column", - ":feature_column_v2", - ":functional", - ":preempted_hook_py", - ":tpu_embedding", - ":tpu_lib", "//tensorflow/core:protos_all_py", - "//tensorflow/python/client:session", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/estimator:util", - "//tensorflow/python/framework:for_generated_wrappers", - "//tensorflow/python/framework:function", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:control_flow_ops", - "//tensorflow/python/ops:init_ops", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:state_ops", - "//tensorflow/python/ops:summary_ops_v2", - "//tensorflow/python/ops:variable_scope", - "//tensorflow/python/ops:variables", - "//tensorflow/python/training", ], ) - py_strict_library( name = "functional", srcs = ["functional.py"], diff --git a/tensorflow/python/tpu/async_checkpoint_test.py b/tensorflow/python/tpu/async_checkpoint_test.py index 070eff0e20c60e..3601c5fad6cc0d 100644 --- a/tensorflow/python/tpu/async_checkpoint_test.py +++ b/tensorflow/python/tpu/async_checkpoint_test.py @@ -33,13 +33,13 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model.pywrap_saved_model import metrics from tensorflow.python.tpu import async_checkpoint -from tensorflow.python.tpu import tpu_config -from tensorflow.python.tpu import tpu_estimator from tensorflow.python.tpu import tpu_optimizer from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import training from tensorflow_estimator.python.estimator import estimator as estimator_lib from tensorflow_estimator.python.estimator import model_fn as model_fn_lib +from tensorflow_estimator.python.estimator.tpu import tpu_config +from tensorflow_estimator.python.estimator.tpu import tpu_estimator FLAGS = flags.FLAGS flags.DEFINE_string('tpu', '', 'TPU to use in this test.') diff --git a/tensorflow/python/tpu/error_handling.py b/tensorflow/python/tpu/error_handling.py deleted file mode 100644 index 1e6660af511bc1..00000000000000 --- a/tensorflow/python/tpu/error_handling.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License.
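The stdout redirection added to print_selective_registration_header.py above is a general trick for silencing modules that print at import time. A standalone sketch using only the standard library (`this` is used purely because it prints on first import):

```python
import contextlib
import os

# Send anything printed during the import to the null device.
with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
    import this  # prints the Zen of Python on first import; suppressed here

print("import finished without writing to stdout")
```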
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Stub file to maintain backwards compatibility.""" - -# pylint: disable=wildcard-import,unused-import -from tensorflow_estimator.python.estimator.tpu.error_handling import * -# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/python/tpu/tpu_context.py b/tensorflow/python/tpu/tpu_context.py deleted file mode 100644 index d1f3ee55723df3..00000000000000 --- a/tensorflow/python/tpu/tpu_context.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Stub file to maintain backwards compatibility.""" - -# pylint: disable=wildcard-import,unused-import -from tensorflow_estimator.python.estimator.tpu.tpu_context import * -# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/python/tpu/tpu_embedding_v3_utils.py b/tensorflow/python/tpu/tpu_embedding_v3_utils.py index ed30d9947c8842..276731051be54f 100644 --- a/tensorflow/python/tpu/tpu_embedding_v3_utils.py +++ b/tensorflow/python/tpu/tpu_embedding_v3_utils.py @@ -73,7 +73,8 @@ def unshuffle_from_sc_to_cpu( shards = shards_t[:, offset_in_shard : offset_in_shard + size_in_shard, :] # This table's shards were rotated by `shard_rotation`, so we need to rotate # the same amount in opposite direction - shards = manip_ops.roll(shards, -shard_rotation, axis=0) + if shard_rotation: + shards = manip_ops.roll(shards, -shard_rotation, axis=0) # Re-arrange (transpose and reshape) the shards to get the queried embedding # table. intermediate_tensor = array_ops.transpose(shards, (1, 0, 2)) @@ -169,6 +170,12 @@ def __init__(self, stacked_layouts, table_to_config): shape=variable_shape, dtype=dtypes.float32, ) + # TODO(b/312743130): This is a workaround. During checkpoint restoration + # optimizer expects the trackable to provide a `_unique_id` or equivalent. + # Remove this when the bug is fixed. + @property + def _unique_id(self): + return self.vars[self._stacked_layouts[0].table_name]._unique_id def _serialize_to_tensors(self) -> Any: return { diff --git a/tensorflow/python/tpu/tpu_estimator.py b/tensorflow/python/tpu/tpu_estimator.py deleted file mode 100644 index f28db848e56252..00000000000000 --- a/tensorflow/python/tpu/tpu_estimator.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Stub file to maintain backwards compatibility.""" - -# pylint: disable=wildcard-import,unused-import,redefined-builtin -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import * -# used by tests -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _clone_export_output_with_tensors -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _create_global_step -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _export_output_to_tensors -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _get_scaffold -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _Inputs -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _ITERATIONS_PER_LOOP_VAR -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_ENQUEUE_OPS -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_ESTIMATOR -from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_TRAIN_OP -# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/python/tpu/util.py b/tensorflow/python/tpu/util.py deleted file mode 100644 index c5b8964b20a6e2..00000000000000 --- a/tensorflow/python/tpu/util.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Stub file to maintain backwards compatibility.""" - -# pylint: disable=wildcard-import,unused-import -from tensorflow_estimator.python.estimator.tpu.util import * -# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/python/trackable/BUILD b/tensorflow/python/trackable/BUILD index 67d4811402c864..a0a315b85fd475 100644 --- a/tensorflow/python/trackable/BUILD +++ b/tensorflow/python/trackable/BUILD @@ -51,7 +51,8 @@ py_strict_library( "//tensorflow/python/framework:ops", "//tensorflow/python/ops:control_flow_ops_gen", "//tensorflow/python/training/saving:saveable_object", - "//tensorflow/python/util:tf_decorator", + "//tensorflow/python/util:tf_contextlib", + "//tensorflow/python/util:tf_decorator_py", "//tensorflow/python/util:tf_export", ], ) @@ -188,7 +189,7 @@ py_strict_library( "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor", - "//tensorflow/python/util:tf_decorator", + "//tensorflow/python/util:tf_contextlib", "//tensorflow/python/util:tf_export", ], ) diff --git a/tensorflow/python/training/BUILD b/tensorflow/python/training/BUILD index a041d50f4b61cb..847e93392b2c67 100644 --- a/tensorflow/python/training/BUILD +++ b/tensorflow/python/training/BUILD @@ -75,7 +75,6 @@ py_strict_library( visibility = [ "//tensorflow:internal", "//tensorflow_minigo:__subpackages__", - "//tensorflow_model_optimization:__subpackages__", "//tensorflow_models:__subpackages__", "//third_party/cloud_tpu/convergence_tools:__subpackages__", "//third_party/mlperf:__subpackages__", @@ -229,7 +228,6 @@ py_strict_library( srcs_version = "PY3", visibility = [ "//tensorflow:internal", - "//tensorflow_estimator/python/estimator:__pkg__", "//third_party/py/tf_slim/training:__pkg__", ], deps = [ @@ -340,7 +338,6 @@ py_strict_library( srcs_version = "PY3", visibility = [ "//tensorflow:internal", - "//tensorflow_model_optimization/python/core/quantization/keras:__pkg__", "//third_party/py/tf_slim/layers:__pkg__", ], deps = [ diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 92ec7cff402129..fd4243e4d3021e 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -1029,9 +1029,10 @@ def _RecordLastCheckpoint(self, latest_save_path): if not self.saver_def.max_to_keep: return # Remove first from list if the same name was used before. 
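For context on the saver.py hunk below: removing entries from a list while iterating over that same list can skip elements, so the loop iterates over a slice copy (`self._last_checkpoints[:]`) instead. A standalone illustration with made-up checkpoint entries:

```python
checkpoints = [("ckpt-1", 1.0), ("ckpt-1", 2.0), ("ckpt-2", 3.0)]

# Iterate over a copy so removals from the original list cannot skip items.
for entry in checkpoints[:]:
    if entry[0] == "ckpt-1":
        checkpoints.remove(entry)

assert checkpoints == [("ckpt-2", 3.0)]
```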
- for p in self._last_checkpoints: + for p in self._last_checkpoints[:]: if latest_save_path == self._CheckpointFilename(p): self._last_checkpoints.remove(p) + # Append new path to list self._last_checkpoints.append((latest_save_path, time.time())) diff --git a/tensorflow/python/types/BUILD b/tensorflow/python/types/BUILD index c04dc039153fa1..799ca38c72981a 100644 --- a/tensorflow/python/types/BUILD +++ b/tensorflow/python/types/BUILD @@ -21,7 +21,6 @@ pytype_strict_library( deps = [ ":doc_typealias", "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python/util:_pywrap_utils", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", "@pypi_typing_extensions//:pkg", diff --git a/tensorflow/python/types/core.py b/tensorflow/python/types/core.py index 16c9d24593e2ab..534211fd9d29ba 100644 --- a/tensorflow/python/types/core.py +++ b/tensorflow/python/types/core.py @@ -26,7 +26,6 @@ from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import, g-bad-import-order -from tensorflow.python.util import _pywrap_utils from tensorflow.python.util.tf_export import tf_export # pylint:disable=g-import-not-at-top @@ -385,10 +384,6 @@ def __tf_tensor__(self, dtype=None, name=None): pass -_pywrap_utils.RegisterType("TensorProtocol", TensorProtocol) -_pywrap_utils.RegisterType("CoreTypeValue", Value) - - # TODO(rahulkamat): Add missing types that are convertible to Tensor. TensorLike = Union[Tensor, TensorProtocol, int, float, bool, str, bytes, complex, tuple, list, np.ndarray, np.generic] diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD index 7d2d9ef5398809..d8cf6ed3e5c7b1 100644 --- a/tensorflow/python/util/BUILD +++ b/tensorflow/python/util/BUILD @@ -43,8 +43,10 @@ package( py_strict_library( name = "core", deps = [ - ":tf_decorator", + ":tf_contextlib", + ":tf_decorator_py", ":tf_export", + ":tf_inspect", ":tf_stack", ], ) @@ -361,7 +363,7 @@ pytype_strict_library( ], ) -py_strict_library( +pytype_strict_library( name = "tf_contextlib", srcs = ["tf_contextlib.py"], compatible_with = get_compatible_with_portable(), @@ -413,29 +415,6 @@ tf_py_strict_test( ], ) -# Leaf library: may not depend on anything else inside TensorFlow. -# TODO(mdan): Move this utility outside of TF. -py_strict_library( - name = "tf_decorator", - compatible_with = get_compatible_with_portable(), - deprecation = "This target has been split. Depend on the sub-targets instead.", - srcs_version = "PY3", - visibility = [ - "//tensorflow:__subpackages__", - # TODO(mdan): Remove these dependencies. - "//third_party/py/tf_slim:__subpackages__", - "//learning/deepmind/research/language/translation/lm:__subpackages__", - "//learning/brain/analytics:__subpackages__", - "//tensorflow:__pkg__", - "//third_party/py/tensorflow_core:__subpackages__", - ], - deps = [ - ":tf_contextlib", - ":tf_decorator_py", - ":tf_inspect", - ], -) - py_strict_library( name = "tf_stack", srcs = ["tf_stack.py"], @@ -445,7 +424,6 @@ py_strict_library( deps = [ ":_tf_stack", "//tensorflow/core:protos_all_py", - "@six_archive//:six", ], ) @@ -770,7 +748,6 @@ py_strict_library( # library. It isn't possible to add these test dependencies via tensorflow.bzl's # py_test because not all tensorflow tests use tensorflow.bzl's py_test. "//tensorflow/python:global_test_configuration", - "@six_archive//:six", "@pypi_wrapt//:pkg", "//tensorflow/python:pywrap_tensorflow", ":_pywrap_utils", @@ -788,7 +765,6 @@ py_strict_library( # library. 
It isn't possible to add these test dependencies via tensorflow.bzl's # py_test because not all tensorflow tests use tensorflow.bzl's py_test. "//tensorflow/python:global_test_configuration", - "@six_archive//:six", ], ) @@ -887,7 +863,6 @@ py_strict_library( "//tensorflow/python:global_test_configuration", ":tf_export", "//third_party/py/numpy", - "@six_archive//:six", ], ) @@ -997,8 +972,6 @@ py_strict_library( # library. It isn't possible to add these test dependencies via tensorflow.bzl's # py_test because not all tensorflow tests use tensorflow.bzl's py_test. "//tensorflow/python:global_test_configuration", - ":tf_decorator", - "@six_archive//:six", ], ) @@ -1067,14 +1040,12 @@ py_strict_library( visibility = util_subpackage_visibility, deps = [ ":__init__", - ":compat", ":nest_util", # global_test_configuration is added here because all major tests depend on this # library. It isn't possible to add these test dependencies via tensorflow.bzl's # py_test because not all tensorflow tests use tensorflow.bzl's py_test. "//tensorflow/python:global_test_configuration", ":tf_export", - "@pypi_wrapt//:pkg", ":_pywrap_utils", ":_pywrap_nest", ], diff --git a/tensorflow/python/util/_pywrap_utils.pyi b/tensorflow/python/util/_pywrap_utils.pyi index c8e51ec4fa961f..f5c7af0c990e0a 100644 --- a/tensorflow/python/util/_pywrap_utils.pyi +++ b/tensorflow/python/util/_pywrap_utils.pyi @@ -32,5 +32,4 @@ def IsTensor(arg0: object) -> bool: ... def IsTypeSpec(arg0: object) -> bool: ... def IsVariable(arg0: object) -> bool: ... def RegisterPyObject(arg0: object, arg1: object) -> object: ... -def RegisterType(arg0: object, arg1: object) -> object: ... def SameNamedtuples(arg0: object, arg1: object) -> object: ... diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py index 0d3c1a2b3c6582..7a4659e0f62251 100644 --- a/tensorflow/python/util/compat.py +++ b/tensorflow/python/util/compat.py @@ -45,20 +45,14 @@ API docstring: tensorflow.compat """ +import codecs +import collections.abc as collections_abc # pylint: disable=unused-import import numbers as _numbers import numpy as _np -import six as _six -import codecs from tensorflow.python.util.tf_export import tf_export -try: - # This import only works on python 3.3 and above. - import collections.abc as collections_abc # pylint: disable=unused-import -except ImportError: - import collections as collections_abc # pylint: disable=unused-import - def as_bytes(bytes_or_text, encoding='utf-8'): """Converts `bytearray`, `bytes`, or unicode python input types to `bytes`. @@ -79,7 +73,7 @@ def as_bytes(bytes_or_text, encoding='utf-8'): encoding = codecs.lookup(encoding).name if isinstance(bytes_or_text, bytearray): return bytes(bytes_or_text) - elif isinstance(bytes_or_text, _six.text_type): + elif isinstance(bytes_or_text, str): return bytes_or_text.encode(encoding) elif isinstance(bytes_or_text, bytes): return bytes_or_text @@ -106,7 +100,7 @@ def as_text(bytes_or_text, encoding='utf-8'): """ # Validate encoding, a LookupError will be raised if invalid. encoding = codecs.lookup(encoding).name - if isinstance(bytes_or_text, _six.text_type): + if isinstance(bytes_or_text, str): return bytes_or_text elif isinstance(bytes_or_text, bytes): return bytes_or_text.decode(encoding) @@ -212,6 +206,6 @@ def path_to_bytes(path): tf_export('compat.complex_types').export_constant(__name__, 'complex_types') # Either bytes or text. 
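The `six` helpers being dropped in compat.py above (and in function_utils.py below) have direct Python 3 built-in equivalents. A quick sketch of the replacements this change relies on; the class `C` exists only for illustration:

```python
class C:
    def m(self):
        return 1

bound = C().m

assert isinstance("text", str)        # was: isinstance(..., six.text_type)
assert bound.__self__.__class__ is C  # was: six.get_method_self(bound).__class__
assert bound.__func__ is C.m          # was: six.get_method_function(bound)
assert bound.__func__.__code__        # was: six.get_function_code(bound)
```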
-bytes_or_text_types = (bytes, _six.text_type) +bytes_or_text_types = (bytes, str) tf_export('compat.bytes_or_text_types').export_constant(__name__, 'bytes_or_text_types') diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py index fa978fe12d56ac..743a81343240c1 100644 --- a/tensorflow/python/util/function_utils.py +++ b/tensorflow/python/util/function_utils.py @@ -16,8 +16,6 @@ import functools -import six - from tensorflow.core.protobuf import config_pb2 from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -89,8 +87,10 @@ def get_func_name(func): if tf_inspect.isfunction(func): return func.__name__ elif tf_inspect.ismethod(func): - return '%s.%s' % (six.get_method_self(func).__class__.__name__, - six.get_method_function(func).__name__) + return '%s.%s' % ( + func.__self__.__class__.__name__, + func.__func__.__name__, + ) else: # Probably a class instance with __call__ return str(type(func)) else: @@ -104,13 +104,13 @@ def get_func_code(func): _, func = tf_decorator.unwrap(func) if callable(func): if tf_inspect.isfunction(func) or tf_inspect.ismethod(func): - return six.get_function_code(func) + return func.__code__ # Since the object is not a function or method, but is a callable, we will # try to access the __call__method as a function. This works with callable # classes but fails with functool.partial objects despite their __call__ # attribute. try: - return six.get_function_code(func.__call__) + return func.__call__.__code__ except AttributeError: return None else: diff --git a/tensorflow/python/util/lazy_loader.py b/tensorflow/python/util/lazy_loader.py index 717965d0123614..7d8c186677583f 100644 --- a/tensorflow/python/util/lazy_loader.py +++ b/tensorflow/python/util/lazy_loader.py @@ -106,6 +106,9 @@ def __dir__(self): module = self._load() return dir(module) + def __reduce__(self): + return importlib.import_module, (self.__name__,) + class KerasLazyLoader(LazyLoader): """LazyLoader that handles routing to different Keras version.""" diff --git a/tensorflow/python/util/lazy_loader_test.py b/tensorflow/python/util/lazy_loader_test.py index 94f258131772c1..e59ef2c888edc1 100644 --- a/tensorflow/python/util/lazy_loader_test.py +++ b/tensorflow/python/util/lazy_loader_test.py @@ -17,6 +17,7 @@ # pylint: disable=unused-import import doctest import inspect +import pickle import types from tensorflow.python.platform import test @@ -54,5 +55,16 @@ def testLazyLoaderMock(self, mock_warning): self.assertEqual(lazy_loader_module.foo, foo) +class PickleTest(test.TestCase): + + def testPickleLazyLoader(self): + name = PickleTest.__module__ # Try to pickle current module. 
+ lazy_loader_module = lazy_loader.LazyLoader( + "lazy_loader_module", globals(), name) + restored = pickle.loads(pickle.dumps(lazy_loader_module)) + self.assertEqual(restored.__name__, name) + self.assertIsNotNone(restored.PickleTest) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index d7acf836ce5e1f..748fc3b167f5c8 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -90,12 +90,9 @@ API docstring: tensorflow.nest """ -import wrapt as _wrapt - from tensorflow.python.util import _pywrap_nest from tensorflow.python.util import _pywrap_utils from tensorflow.python.util import nest_util -from tensorflow.python.util.compat import collections_abc as _collections_abc from tensorflow.python.util.tf_export import tf_export @@ -1315,10 +1312,3 @@ def sequence_fn(instance, args): False, sequence_fn=sequence_fn, ) - - -_pywrap_utils.RegisterType("Mapping", _collections_abc.Mapping) -_pywrap_utils.RegisterType("MutableMapping", _collections_abc.MutableMapping) -_pywrap_utils.RegisterType("Sequence", _collections_abc.Sequence) -_pywrap_utils.RegisterType("MappingView", _collections_abc.MappingView) -_pywrap_utils.RegisterType("ObjectProxy", _wrapt.ObjectProxy) diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 26341624c06619..0378076cba247b 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -154,24 +154,24 @@ class UnsortedSampleAttr(object): field1 = attr.ib() field2 = attr.ib() - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassCustomProtocol(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) self.assertIsInstance(mt, CustomNestProtocol) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassIsNested(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) self.assertTrue(nest.is_nested(mt)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlatten(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) leaves = nest.flatten(mt) self.assertLen(leaves, 1) self.assertAllEqual(leaves[0], [1]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenUpToCompatible(self): simple_list = [2] mt = MaskedTensor(mask=True, value=constant_op.constant([1])) @@ -200,7 +200,7 @@ def testDataclassFlattenUpToCompatible(self): ) self.assertAllEqual(flat_path_nested_list, [2]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenUpToIncompatible(self): simple_list = [2] mt = MaskedTensor(mask=True, value=constant_op.constant([1])) @@ -239,7 +239,7 @@ def testDataclassFlattenUpToIncompatible(self): shallow_tree=nested_list, input_tree=mt, check_types=False ) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenWithTuplePathsUpToCompatible(self): simple_list = [2] mt = MaskedTensor(mask=True, value=constant_op.constant([1])) @@ -271,7 +271,7 @@ def testDataclassFlattenWithTuplePathsUpToCompatible(self): ) self.assertAllEqual(flat_path_nested_list, [[(0, 0), 2]]) - @test_util.assert_no_new_pyobjects_executing_eagerly + 
@test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenWithTuplePathsUpToIncompatible(self): simple_list = [2] mt = MaskedTensor(mask=True, value=constant_op.constant([1])) @@ -311,7 +311,7 @@ def testDataclassFlattenWithTuplePathsUpToIncompatible(self): shallow_tree=nested_list2, input_tree=nmt, check_types=False ) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenAndPack(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) leaves = nest.flatten(mt) @@ -319,7 +319,7 @@ def testDataclassFlattenAndPack(self): self.assertIsInstance(reconstructed_mt, MaskedTensor) self.assertEqual(reconstructed_mt, mt) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassMapStructure(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) mt_doubled = nest.map_structure(lambda x: x * 2, mt) @@ -327,7 +327,7 @@ def testDataclassMapStructure(self): self.assertEqual(mt_doubled.mask, True) self.assertAllEqual(mt_doubled.value, [2]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassMapStructureWithPaths(self): mt = MaskedTensor(mask=False, value=constant_op.constant([1])) mt2 = MaskedTensor(mask=True, value=constant_op.constant([2])) @@ -360,7 +360,7 @@ def path_sum(path, *tensors): self.assertAllEqual(nmt_combined_with_path.value.value[0], "0/0") self.assertAllEqual(nmt_combined_with_path.value.value[1], [9]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassMapStructureWithTuplePaths(self): mt = MaskedTensor(mask=False, value=constant_op.constant([1])) mt2 = MaskedTensor(mask=True, value=constant_op.constant([2])) @@ -395,7 +395,7 @@ def tuple_path_sum(tuple_path, *tensors): self.assertAllEqual(nmt_combined_with_path.value.value[0], (0, 0)) self.assertAllEqual(nmt_combined_with_path.value.value[1], [9]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassMapStructureUpTo(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) mt2 = MaskedTensor(mask=True, value=constant_op.constant([2])) @@ -431,7 +431,7 @@ def sum_tensors(*tensors): self.assertEqual(nmt_combined_with_path.value.mask, True) self.assertAllEqual(nmt_combined_with_path.value.value, [9]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassMapStructureWithTuplePathsUoTo(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) mt2 = MaskedTensor(mask=True, value=constant_op.constant([2])) @@ -470,7 +470,7 @@ def tuple_path_sum(tuple_path, *tensors): self.assertAllEqual(nmt_combined_with_path.value.value[0], (0, 0)) self.assertAllEqual(nmt_combined_with_path.value.value[1], [9]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNestedDataclassIsNested(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) self.assertTrue(nest.is_nested(mt)) @@ -480,7 +480,7 @@ def testNestedDataclassIsNested(self): ) self.assertTrue(nest.is_nested(nmt)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassAssertShallowStructure(self): # These assertions are expected to 
pass: two dataclasses with the same # component size are considered to have the same shallow structure. @@ -535,7 +535,7 @@ def testDataclassAssertShallowStructure(self): shallow_tree=nmt, input_tree=mt, check_types=False ) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassGetTraverseShallowStructure(self): nmt = NestedMaskedTensor.nested_masked_tensor_with_opposite_masks( mask=True, inner_value=constant_op.constant([1]) @@ -568,7 +568,7 @@ def testDataclassGetTraverseShallowStructure(self): self.assertEqual(traverse_result3, False) nest.assert_shallow_structure(traverse_result3, nmt) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNestedDataclassFlatten(self): nmt = NestedMaskedTensor.nested_masked_tensor_with_opposite_masks( mask=True, inner_value=constant_op.constant([1]) @@ -577,7 +577,7 @@ def testNestedDataclassFlatten(self): self.assertLen(leaves, 1) self.assertAllEqual(leaves[0], [1]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNestedDataclassFlattenAndPack(self): nmt = NestedMaskedTensor.nested_masked_tensor_with_opposite_masks( mask=True, inner_value=constant_op.constant([1]) @@ -587,7 +587,7 @@ def testNestedDataclassFlattenAndPack(self): self.assertIsInstance(reconstructed_mt, NestedMaskedTensor) self.assertEqual(reconstructed_mt, nmt) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNestedDataclassMapStructure(self): nmt = NestedMaskedTensor.nested_masked_tensor_with_opposite_masks( mask=True, inner_value=constant_op.constant([1]) @@ -602,7 +602,7 @@ def testNestedDataclassMapStructure(self): self.assertEqual(mt_doubled.value.mask, expected.value.mask) self.assertAllEqual(mt_doubled.value.value, expected.value.value) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassYieldFlatPaths(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) mt_flat_paths = list(nest.yield_flat_paths(mt)) @@ -626,7 +626,7 @@ def testDataclassYieldFlatPaths(self): ], ) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenWithStringPaths(self): sep = "/" mt = MaskedTensor(mask=True, value=constant_op.constant([1])) @@ -650,7 +650,7 @@ def testDataclassFlattenWithStringPaths(self): self.assertEqual(dict_mt_nmt_flat_paths[1][0], "nmt/0/0") self.assertAllEqual(dict_mt_nmt_flat_paths[1][1], [2]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenWithTuplePaths(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) mt_flat_paths = nest.flatten_with_tuple_paths(mt) @@ -671,7 +671,7 @@ def testDataclassFlattenWithTuplePaths(self): self.assertEqual(dict_mt_nmt_flat_paths[1][0], ("nmt", 0, 0)) self.assertAllEqual(dict_mt_nmt_flat_paths[1][1], [2]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassListToTuple(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) nmt = NestedMaskedTensor.nested_masked_tensor_with_opposite_masks( @@ -690,7 +690,7 @@ def testDataclassListToTuple(self): ) nest.assert_same_structure(results, expected) - 
@test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testAttrsFlattenAndPack(self): if attr is None: self.skipTest("attr module is unavailable.") @@ -715,7 +715,7 @@ def testAttrsFlattenAndPack(self): {"values": [(1, 2), [3, 4], 5]}, {"values": [PointXY(1, 2), 3, 4]}, ) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testAttrsMapStructure(self, values): if attr is None: self.skipTest("attr module is unavailable.") @@ -724,7 +724,7 @@ def testAttrsMapStructure(self, values): new_structure = nest.map_structure(lambda x: x, structure) self.assertEqual(structure, new_structure) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFlattenAndPack(self): structure = ((3, 4), 5, (6, 7, (9, 10), 8)) flat = ["a", "b", "c", "d", "e", "f", "g", "h"] @@ -761,7 +761,7 @@ def testFlattenAndPack(self): @parameterized.parameters({"mapping_type": collections.OrderedDict}, {"mapping_type": _CustomMapping}) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFlattenDictOrder(self, mapping_type): """`flatten` orders dicts by key, including OrderedDicts.""" ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) @@ -787,7 +787,7 @@ def testPackDictOrder(self, mapping_type): custom_reconstruction) self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFlattenAndPackMappingViews(self): """`flatten` orders dicts by key, including OrderedDicts.""" ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) @@ -806,7 +806,7 @@ def testFlattenAndPackMappingViews(self): Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFlattenAndPack_withDicts(self): # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. 
mess = [ @@ -889,7 +889,7 @@ def testPackSequenceAs_CompositeTensor(self): ValueError, "Structure had 2 atoms, but flat_sequence had 1 items."): nest.pack_sequence_as(val, [val], expand_composites=True) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testIsNested(self): self.assertFalse(nest.is_nested("1234")) self.assertTrue(nest.is_nested([1, 3, [4, 5]])) @@ -942,7 +942,7 @@ def testFlattenDictItems(self, mapping_type): class SameNamedType1(SameNameab): pass - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testAssertSameStructure(self): structure1 = (((1, 2), 3), 4, (5, 6)) structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) @@ -1053,7 +1053,7 @@ def testHeterogeneousComparison(self): nest.assert_same_structure({"a": 4}, _CustomMapping(a=3)) nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testMapStructure(self): structure1 = (((1, 2), 3), 4, (5, 6)) structure2 = (((7, 8), 9), 10, (11, 12)) @@ -1129,7 +1129,7 @@ def testMapStructure(self): ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testMapStructureWithStrings(self): inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) inp_b = NestTest.ABTuple(a=2, b=(1, 3)) diff --git a/tensorflow/python/util/nest_util.py b/tensorflow/python/util/nest_util.py index f40cc2d3642341..c53042f7dc11ab 100644 --- a/tensorflow/python/util/nest_util.py +++ b/tensorflow/python/util/nest_util.py @@ -27,7 +27,6 @@ import collections as _collections import enum -import six as _six import wrapt as _wrapt from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import @@ -236,7 +235,7 @@ def sequence_like(instance, args): # Pack a CompositeTensor's components according to a TypeSpec. assert len(args) == 1 return instance._from_components(args[0]) # pylint: disable=protected-access - elif isinstance(instance, _six.moves.range): + elif isinstance(instance, range): return sequence_like(list(instance), args) elif isinstance(instance, _wrapt.ObjectProxy): # For object proxies, first create the underlying type and then re-wrap it diff --git a/tensorflow/python/util/protobuf/BUILD b/tensorflow/python/util/protobuf/BUILD index 85c44a9eedccfa..585c51706e271f 100644 --- a/tensorflow/python/util/protobuf/BUILD +++ b/tensorflow/python/util/protobuf/BUILD @@ -48,7 +48,6 @@ tf_py_strict_test( ":compare_test_proto_py", ":protobuf", "//tensorflow/python/platform:test", - "@six_archive//:six", ], ) @@ -88,7 +87,6 @@ py_strict_library( # py_test because not all tensorflow tests use tensorflow.bzl's py_test. 
"//tensorflow/python:global_test_configuration", "@com_google_protobuf//:protobuf_python", - "@six_archive//:six", "//tensorflow/python/util:compat", ], ) diff --git a/tensorflow/python/util/protobuf/compare.py b/tensorflow/python/util/protobuf/compare.py index 44a9bfd15b3b75..dbc61ae28f4674 100644 --- a/tensorflow/python/util/protobuf/compare.py +++ b/tensorflow/python/util/protobuf/compare.py @@ -58,12 +58,10 @@ def testXXX(self): self.assertProtoEqual(a, b) """ +import collections.abc as collections_abc import difflib import math -from ..compat import collections_abc -import six - from google.protobuf import descriptor from google.protobuf import descriptor_pool from google.protobuf import message @@ -147,7 +145,7 @@ def checkFloatEqAndReplace(self, expected, actual, relative_tolerance): # pylin == descriptor.FieldDescriptor.TYPE_MESSAGE ): for e_v, a_v in zip( - six.itervalues(expected_values), six.itervalues(actual_values) + iter(expected_values.values()), iter(actual_values.values()) ): checkFloatEqAndReplace( self, @@ -191,7 +189,7 @@ def assertProtoEqual( comparisons are done using the relative tolerance provided. """ pool = descriptor_pool.Default() - if isinstance(a, six.string_types): + if isinstance(a, str): a = text_format.Parse(a, b.__class__(), descriptor_pool=pool) for pb in a, b: @@ -281,7 +279,7 @@ def NormalizeNumberFields(pb): # This is a map, only recurse if the values have a message type. if (desc.message_type.fields_by_number[2].type == descriptor.FieldDescriptor.TYPE_MESSAGE): - for v in six.itervalues(values): + for v in iter(values.values()): NormalizeNumberFields(v) else: for v in values: @@ -296,7 +294,7 @@ def _IsMap(value): def _IsRepeatedContainer(value): - if isinstance(value, six.string_types): + if isinstance(value, str): return False try: iter(value) diff --git a/tensorflow/python/util/protobuf/compare_test.py b/tensorflow/python/util/protobuf/compare_test.py index 96484c5df87856..ef521baf2807b6 100644 --- a/tensorflow/python/util/protobuf/compare_test.py +++ b/tensorflow/python/util/protobuf/compare_test.py @@ -19,7 +19,6 @@ import sys import textwrap -import six from google.protobuf import text_format @@ -30,13 +29,7 @@ def LargePbs(*args): """Converts ASCII string Large PBs to messages.""" - pbs = [] - for arg in args: - pb = compare_test_pb2.Large() - text_format.Merge(arg, pb) - pbs.append(pb) - - return pbs + return [text_format.Merge(arg, compare_test_pb2.Large()) for arg in args] class ProtoEqTest(googletest.TestCase): @@ -267,49 +260,44 @@ class NormalizeNumbersTest(googletest.TestCase): """Tests for NormalizeNumberFields().""" def testNormalizesInts(self): - pb = compare_test_pb2.Large() - pb.int64_ = 4 + pb = compare_test_pb2.Large(int64_=4) compare.NormalizeNumberFields(pb) - self.assertTrue(isinstance(pb.int64_, six.integer_types)) + self.assertIsInstance(pb.int64_, int) pb.int64_ = 4 compare.NormalizeNumberFields(pb) - self.assertTrue(isinstance(pb.int64_, six.integer_types)) + self.assertIsInstance(pb.int64_, int) pb.int64_ = 9999999999999999 compare.NormalizeNumberFields(pb) - self.assertTrue(isinstance(pb.int64_, six.integer_types)) + self.assertIsInstance(pb.int64_, int) def testNormalizesRepeatedInts(self): - pb = compare_test_pb2.Large() - pb.int64s.extend([1, 400, 999999999999999]) + pb = compare_test_pb2.Large(int64s=[1, 400, 999999999999999]) compare.NormalizeNumberFields(pb) - self.assertTrue(isinstance(pb.int64s[0], six.integer_types)) - self.assertTrue(isinstance(pb.int64s[1], six.integer_types)) - 
self.assertTrue(isinstance(pb.int64s[2], six.integer_types)) + self.assertIsInstance(pb.int64s[0], int) + self.assertIsInstance(pb.int64s[1], int) + self.assertIsInstance(pb.int64s[2], int) def testNormalizesFloats(self): - pb1 = compare_test_pb2.Large() - pb1.float_ = 1.2314352351231 - pb2 = compare_test_pb2.Large() - pb2.float_ = 1.231435 + pb1 = compare_test_pb2.Large(float_=1.2314352351231) + pb2 = compare_test_pb2.Large(float_=1.231435) self.assertNotEqual(pb1.float_, pb2.float_) compare.NormalizeNumberFields(pb1) compare.NormalizeNumberFields(pb2) self.assertEqual(pb1.float_, pb2.float_) def testNormalizesRepeatedFloats(self): - pb = compare_test_pb2.Large() - pb.medium.floats.extend([0.111111111, 0.111111]) + pb = compare_test_pb2.Large( + medium=compare_test_pb2.Medium(floats=[0.111111111, 0.111111]) + ) compare.NormalizeNumberFields(pb) for value in pb.medium.floats: self.assertAlmostEqual(0.111111, value) def testNormalizesDoubles(self): - pb1 = compare_test_pb2.Large() - pb1.double_ = 1.2314352351231 - pb2 = compare_test_pb2.Large() - pb2.double_ = 1.2314352 + pb1 = compare_test_pb2.Large(double_=1.2314352351231) + pb2 = compare_test_pb2.Large(double_=1.2314352) self.assertNotEqual(pb1.double_, pb2.double_) compare.NormalizeNumberFields(pb1) compare.NormalizeNumberFields(pb2) @@ -326,7 +314,7 @@ class AssertTest(googletest.TestCase): """Tests assertProtoEqual().""" def assertProtoEqual(self, a, b, **kwargs): - if isinstance(a, six.string_types) and isinstance(b, six.string_types): + if isinstance(a, str) and isinstance(b, str): a, b = LargePbs(a, b) compare.assertProtoEqual(self, a, b, **kwargs) @@ -346,8 +334,7 @@ def assertNone(self, a, b, message, **kwargs): def testCheckInitialized(self): # neither is initialized - a = compare_test_pb2.Labeled() - a.optional = 1 + a = compare_test_pb2.Labeled(optional=1) self.assertNone(a, a, 'Initialization errors: ', check_initialized=True) self.assertAll(a, check_initialized=False) @@ -365,8 +352,7 @@ def testCheckInitialized(self): check_initialized=False) # both are initialized - a = compare_test_pb2.Labeled() - a.required = 2 + a = compare_test_pb2.Labeled(required=2) self.assertAll(a, check_initialized=True) self.assertAll(a, check_initialized=False) @@ -382,26 +368,20 @@ def testCheckInitialized(self): self.assertNone(a, b, message, check_initialized=False) def testAssertEqualWithStringArg(self): - pb = compare_test_pb2.Large() - pb.string_ = 'abc' - pb.float_ = 1.234 + pb = compare_test_pb2.Large(string_='abc', float_=1.234) compare.assertProtoEqual(self, """ string_: 'abc' float_: 1.234 """, pb) def testNormalizesNumbers(self): - pb1 = compare_test_pb2.Large() - pb1.int64_ = 4 - pb2 = compare_test_pb2.Large() - pb2.int64_ = 4 + pb1 = compare_test_pb2.Large(int64_=4) + pb2 = compare_test_pb2.Large(int64_=4) compare.assertProtoEqual(self, pb1, pb2) def testNormalizesFloat(self): - pb1 = compare_test_pb2.Large() - pb1.double_ = 4.0 - pb2 = compare_test_pb2.Large() - pb2.double_ = 4 + pb1 = compare_test_pb2.Large(double_=4.0) + pb2 = compare_test_pb2.Large(double_=4) compare.assertProtoEqual(self, pb1, pb2, normalize_numbers=True) def testLargeProtoData(self): @@ -542,9 +522,7 @@ def testRepeatedMessage(self): class MixinTests(compare.ProtoAssertions, googletest.TestCase): def testAssertEqualWithStringArg(self): - pb = compare_test_pb2.Large() - pb.string_ = 'abc' - pb.float_ = 1.234 + pb = compare_test_pb2.Large(string_='abc', float_=1.234) self.assertProtoEqual(""" string_: 'abc' float_: 1.234 diff --git 
a/tensorflow/python/util/tf_contextlib.py b/tensorflow/python/util/tf_contextlib.py index 06a947e26249bb..52f2c3d1c3e3fc 100644 --- a/tensorflow/python/util/tf_contextlib.py +++ b/tensorflow/python/util/tf_contextlib.py @@ -13,12 +13,19 @@ # limitations under the License. # ============================================================================== """TFDecorator-aware replacements for the contextlib module.""" +from collections.abc import Callable, Iterator import contextlib as _contextlib +from typing import ContextManager, TypeVar + from tensorflow.python.util import tf_decorator +_T = TypeVar('_T') + -def contextmanager(target): +def contextmanager( + target: Callable[..., Iterator[_T]], +) -> Callable[..., ContextManager[_T]]: """A tf_decorator-aware wrapper for `contextlib.contextmanager`. Usage is identical to `contextlib.contextmanager`. diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py index 781dcb2ae89ee6..a716f354ad415f 100644 --- a/tensorflow/python/util/tf_inspect.py +++ b/tensorflow/python/util/tf_inspect.py @@ -17,8 +17,6 @@ import functools import inspect as _inspect -import six - from tensorflow.python.util import tf_decorator @@ -235,7 +233,7 @@ def _get_argspec_for_partial(obj): all_defaults[-len(defaults):] = defaults # Fill in default values provided by partial function in all_defaults. - for kw, default in six.iteritems(partial_keywords): + for kw, default in iter(partial_keywords.items()): if kw in args: idx = args.index(kw) all_defaults[idx] = default diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index a537864036534c..b42e75b2ca2365 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -32,6 +32,26 @@ namespace tensorflow { namespace swig { namespace { +constexpr const char ITERATOR_OPS_MODULE[] = + "tensorflow.python.data.ops.iterator_ops"; +constexpr const char COMPOSITE_TENSOR_MODULE[] = + "tensorflow.python.framework.composite_tensor"; +constexpr const char INDEXED_SLICES_MODULE[] = + "tensorflow.python.framework.indexed_slices"; +constexpr const char OPS_MODULE[] = + "tensorflow.python.framework.ops"; +constexpr const char SPARSE_TENSOR_MODULE[] = + "tensorflow.python.framework.sparse_tensor"; +constexpr const char TENSOR_MODULE[] = + "tensorflow.python.framework.tensor"; +constexpr const char TYPE_SPEC_MODULE[] = + "tensorflow.python.framework.type_spec"; +constexpr const char RESOURCE_VAR_MODULE[] = + "tensorflow.python.ops.resource_variable_ops"; +constexpr const char VARIABLES_MODULE[] = + "tensorflow.python.ops.variables"; +constexpr const char CORE_TYPES_MODULE[] = + "tensorflow.python.types.core"; string PyObjectToString(PyObject* o); } // namespace @@ -53,17 +73,6 @@ PyObject* GetRegisteredPyObject(const string& name) { return it->second; } -PyObject* RegisterType(PyObject* type_name, PyObject* type) { - if (!PyType_Check(type)) { - PyErr_SetString(PyExc_TypeError, - tensorflow::strings::StrCat("Expecting a type, got ", - Py_TYPE(type)->tp_name) - .c_str()); - return nullptr; - } - return RegisterPyObject(type_name, type); -} - PyObject* RegisterPyObject(PyObject* name, PyObject* value) { string key; if (PyBytes_Check(name)) { @@ -212,22 +221,31 @@ class CachedTypeCheck { TF_GUARDED_BY(type_to_sequence_map_mu_); }; -// Returns 1 if 'obj' is an instance of 'type_name' -// Returns 0 otherwise. -// Returns -1 if an error occurred (e.g., if 'type_name' is not registered.) 
-int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) { - PyObject* type_obj = GetRegisteredPyObject(type_name); - if (TF_PREDICT_FALSE(type_obj == nullptr)) { - PyErr_SetString(PyExc_RuntimeError, - tensorflow::strings::StrCat( - type_name, - " type has not been set. " - "Please register the type with the identifier \"", - type_name, "\" using RegisterType.") - .c_str()); - return -1; +PyObject* ImportTypeFromModule(const char* module_name, const char* type_name) { + static PyObject* given_type; + given_type = [module_name, type_name]() { + PyObject* module = PyImport_ImportModule(module_name); + PyObject* attr = + module ? PyObject_GetAttrString(module, type_name) : nullptr; + if (attr == nullptr) { + PyErr_WriteUnraisable(nullptr); + PyErr_Clear(); + } + if (module) Py_DECREF(module); + return attr; + }(); + return given_type; +} + +// Returns true if 'obj' is an instance of 'type_name' +// Returns false otherwise. +int IsInstanceOfGivenType(PyObject* obj, const char* module_name, + const char* type_name) { + PyObject* given_type = ImportTypeFromModule(module_name, type_name); + if (TF_PREDICT_FALSE(given_type == nullptr)) { + return false; } - return PyObject_IsInstance(obj, type_obj); + return PyObject_IsInstance(obj, given_type); } // Returns 1 if `o` is considered a mapping for the purposes of Flatten(). @@ -235,7 +253,7 @@ int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) { // Returns -1 if an error occurred. int IsMappingHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return IsInstanceOfRegisteredType(to_check, "Mapping"); + return IsInstanceOfGivenType(to_check, "collections.abc", "Mapping"); }); if (PyDict_Check(o)) return true; return check_cache->CachedLookup(o); @@ -245,7 +263,7 @@ int IsMappingHelper(PyObject* o) { // Flatten(). Returns 0 otherwise. Returns -1 if an error occurred. int IsMutableMappingHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return IsInstanceOfRegisteredType(to_check, "MutableMapping"); + return IsInstanceOfGivenType(to_check, "collections.abc", "MutableMapping"); }); if (PyDict_Check(o)) return true; return check_cache->CachedLookup(o); @@ -256,7 +274,7 @@ int IsMutableMappingHelper(PyObject* o) { // Returns -1 if an error occurred. int IsMappingViewHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return IsInstanceOfRegisteredType(to_check, "MappingView"); + return IsInstanceOfGivenType(to_check, "collections.abc", "MappingView"); }); return check_cache->CachedLookup(o); } @@ -266,7 +284,7 @@ int IsMappingViewHelper(PyObject* o) { // Returns -1 if an error occurred. int IsObjectProxy(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return IsInstanceOfRegisteredType(to_check, "ObjectProxy"); + return IsInstanceOfGivenType(to_check, "wrapt", "ObjectProxy"); }); return check_cache->CachedLookup(o); } @@ -309,7 +327,8 @@ int IsCustomNestProtocolDefined(PyObject* o) { // Returns -1 if an error occurred. int IsIndexedSlicesHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return IsInstanceOfRegisteredType(to_check, "IndexedSlices"); + return IsInstanceOfGivenType(to_check, INDEXED_SLICES_MODULE, + "IndexedSlices"); }); return check_cache->CachedLookup(o); } @@ -319,7 +338,7 @@ int IsIndexedSlicesHelper(PyObject* o) { // Returns -1 if an error occurred. 
int IsTensorHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return IsInstanceOfRegisteredType(to_check, "Tensor"); + return IsInstanceOfGivenType(to_check, TENSOR_MODULE, "Tensor"); }); return check_cache->CachedLookup(o); } @@ -329,7 +348,7 @@ int IsTensorHelper(PyObject* o) { // Returns -1 if an error occurred. int IsTensorSpecHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return IsInstanceOfRegisteredType(to_check, "TensorSpec"); + return IsInstanceOfGivenType(to_check, TENSOR_MODULE, "TensorSpec"); }); return check_cache->CachedLookup(o); } @@ -339,21 +358,21 @@ int IsTensorSpecHelper(PyObject* o) { // Returns -1 if an error occurred. int IsEagerTensorHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return IsInstanceOfRegisteredType(to_check, "EagerTensor"); + return IsInstanceOfGivenType(to_check, OPS_MODULE, "EagerTensor"); }); return check_cache->CachedLookup(o); } int IsTensorProtocolHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return IsInstanceOfRegisteredType(to_check, "TensorProtocol"); + return IsInstanceOfGivenType(to_check, CORE_TYPES_MODULE, "TensorProtocol"); }); return check_cache->CachedLookup(o); } int IsCoreTypeValueHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return IsInstanceOfRegisteredType(to_check, "CoreTypeValue"); + return IsInstanceOfGivenType(to_check, CORE_TYPES_MODULE, "Value"); }); return check_cache->CachedLookup(o); } @@ -363,7 +382,8 @@ int IsCoreTypeValueHelper(PyObject* o) { // Returns -1 if an error occurred. int IsResourceVariableHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return IsInstanceOfRegisteredType(to_check, "ResourceVariable"); + return IsInstanceOfGivenType(to_check, RESOURCE_VAR_MODULE, + "ResourceVariable"); }); return check_cache->CachedLookup(o); } @@ -373,7 +393,8 @@ int IsResourceVariableHelper(PyObject* o) { // Returns -1 if an error occurred. int IsOwnedIteratorHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return IsInstanceOfRegisteredType(to_check, "OwnedIterator"); + return IsInstanceOfGivenType(to_check, ITERATOR_OPS_MODULE, + "OwnedIterator"); }); return check_cache->CachedLookup(o); } @@ -383,7 +404,7 @@ int IsOwnedIteratorHelper(PyObject* o) { // Returns -1 if an error occurred. int IsVariableHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return IsInstanceOfRegisteredType(to_check, "Variable"); + return IsInstanceOfGivenType(to_check, VARIABLES_MODULE, "Variable"); }); return check_cache->CachedLookup(o); } @@ -399,7 +420,8 @@ int IsNestedHelper(PyObject* o) { if (IsCustomNestProtocolDefined(o)) return true; static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - int is_instance = IsInstanceOfRegisteredType(to_check, "Sequence"); + int is_instance = + IsInstanceOfGivenType(to_check, "collections.abc", "Sequence"); // Don't cache a failed is_instance check. 
if (is_instance == -1) return -1; @@ -617,11 +639,10 @@ class CustomNestedIterator : public ValueIterator { bool IsSparseTensorValueType(PyObject* o) { PyObject* sparse_tensor_value_type = - GetRegisteredPyObject("SparseTensorValue"); + ImportTypeFromModule(SPARSE_TENSOR_MODULE, "SparseTensorValue"); if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) { return false; } - return PyObject_TypeCheck( o, reinterpret_cast(sparse_tensor_value_type)) == 1; } @@ -632,7 +653,8 @@ bool IsSparseTensorValueType(PyObject* o) { bool IsCompositeTensorHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { // TODO(b/246438937): Remove the ResourceVariable test. - return IsInstanceOfRegisteredType(to_check, "CompositeTensor") && + return IsInstanceOfGivenType(to_check, COMPOSITE_TENSOR_MODULE, + "CompositeTensor") && !IsResourceVariable(to_check); }); return check_cache->CachedLookup(o); @@ -644,10 +666,12 @@ bool IsCompositeTensorHelper(PyObject* o) { // Returns -1 if an error occurred. bool IsTypeSpecHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec"); + int is_type_spec = + IsInstanceOfGivenType(to_check, TYPE_SPEC_MODULE, "TypeSpec"); // TODO(b/246438937): Remove the VariableSpec special case. - int is_dense_spec = (IsInstanceOfRegisteredType(to_check, "TensorSpec") || - IsInstanceOfRegisteredType(to_check, "VariableSpec")); + int is_dense_spec = + (IsInstanceOfGivenType(to_check, TENSOR_MODULE, "TensorSpec") || + IsInstanceOfGivenType(to_check, RESOURCE_VAR_MODULE, "VariableSpec")); if ((is_type_spec == -1) || (is_dense_spec == -1)) return -1; return static_cast(is_type_spec && !is_dense_spec); }); @@ -1128,7 +1152,8 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) { } Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields")); - int is_instance = IsInstanceOfRegisteredType(fields.get(), "Sequence"); + int is_instance = + IsInstanceOfGivenType(fields.get(), "collections.abc", "Sequence"); if (is_instance == 0) { Py_RETURN_FALSE; } else if (is_instance == -1) { diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index fd58430cf8233d..903ddb0f4d1ea1 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -244,9 +244,6 @@ PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2, // the documentation for `RegisteredPyObjects`. Returns PyNone. PyObject* RegisterPyObject(PyObject* name, PyObject* value); -// Variant of RegisterPyObject that requires the object's value to be a type. -PyObject* RegisterType(PyObject* type_name, PyObject* type); - // Returns a borrowed reference to an object that was registered with // RegisterPyObject. (Do not call Py_DECREF on the result). 
PyObject* GetRegisteredPyObject(const std::string& name); diff --git a/tensorflow/python/util/util_wrapper.cc b/tensorflow/python/util/util_wrapper.cc index 48aa34e72a04a4..5e48eb594d39a1 100644 --- a/tensorflow/python/util/util_wrapper.cc +++ b/tensorflow/python/util/util_wrapper.cc @@ -26,11 +26,6 @@ PYBIND11_MODULE(_pywrap_utils, m) { _pywrap_utils ----- )pbdoc"; - m.def("RegisterType", - [](const py::handle& type_name, const py::handle& type) { - return tensorflow::PyoOrThrow( - tensorflow::swig::RegisterType(type_name.ptr(), type.ptr())); - }); m.def("RegisterPyObject", [](const py::handle& name, const py::handle& type) { return tensorflow::PyoOrThrow( tensorflow::swig::RegisterPyObject(name.ptr(), type.ptr())); diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt index a4fe30b11d1676..dd704da3a62d11 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt @@ -106,6 +106,30 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_BOOL } + field { + name: "enable_multi_host" + number: 27 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "backend_server_port" + number: 28 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "target_tpu" + number: 29 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "target_gpu" + number: 30 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } field { name: "disable_functional_ops_lowering" number: 21 diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt index 2f53abbe2b3953..c3f36236a34c8b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt @@ -235,6 +235,30 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_BOOL } + field { + name: "enable_multi_host" + number: 27 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "backend_server_port" + number: 28 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "target_tpu" + number: 29 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "target_gpu" + number: 30 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } field { name: "disable_functional_ops_lowering" number: 21 diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt index d517b4a6219751..e8b27d1124aff1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt @@ -186,7 +186,7 @@ tf_module { } member_method { name: "matmul" - argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'grad_a\', \'grad_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'False\', \'False\', \'None\'], " } member_method { name: "matrix_rank" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt 
b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index fbebe3b89e42f4..5987b21598a535 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1670,7 +1670,7 @@ tf_module { } member_method { name: "matmul" - argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'grad_a\', \'grad_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'False\', \'False\', \'None\'], " } member_method { name: "matrix_band_part" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index f78ba2e0839c78..80e84f38715742 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -398,15 +398,15 @@ tf_module { } member_method { name: "BatchMatMul" - argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'grad_x\', \'grad_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "BatchMatMulV2" - argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'grad_x\', \'grad_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "BatchMatMulV3" - argspec: "args=[\'x\', \'y\', \'Tout\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'x\', \'y\', \'Tout\', \'adj_x\', \'adj_y\', \'grad_x\', \'grad_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "BatchMatrixBandPart" @@ -1936,6 +1936,10 @@ tf_module { name: "GetSessionTensor" argspec: "args=[\'handle\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "GlobalIterId" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "Greater" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -2300,6 +2304,10 @@ tf_module { name: "ListDiff" argspec: "args=[\'x\', \'y\', \'out_idx\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " } + member_method { + name: "ListSnapshotChunksDataset" + argspec: "args=[\'snapshot_path\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "LoadAndRemapMatrix" argspec: "args=[\'ckpt_path\', \'old_tensor_name\', \'row_remapping\', \'col_remapping\', \'initializing_values\', \'num_rows\', \'num_cols\', \'max_rows_in_memory\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " @@ -2494,7 +2502,7 @@ tf_module { } member_method { name: "MatMul" - argspec: 
"args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'grad_a\', \'grad_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "MatchingFiles" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-options.pbtxt index e649623069d76e..63505344a89afc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-options.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-options.pbtxt @@ -14,12 +14,16 @@ tf_class { name: "experimental_io_device" mtype: "" } + member { + name: "experimental_sharding_callback" + mtype: "" + } member { name: "experimental_write_callbacks" mtype: "" } member_method { name: "__init__" - argspec: "args=[\'self\', \'experimental_io_device\', \'experimental_enable_async_checkpoint\', \'experimental_write_callbacks\', \'enable_async\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + argspec: "args=[\'self\', \'experimental_io_device\', \'experimental_enable_async_checkpoint\', \'experimental_write_callbacks\', \'enable_async\', \'experimental_sharding_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-max-shard-size-policy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-max-shard-size-policy.pbtxt new file mode 100644 index 00000000000000..eeb8a04569157a --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-max-shard-size-policy.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.train.experimental.MaxShardSizePolicy" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "description" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'max_shard_size\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-shard-by-task-policy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-shard-by-task-policy.pbtxt new file mode 100644 index 00000000000000..19c91cb1bc42f3 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-shard-by-task-policy.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.train.experimental.ShardByTaskPolicy" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "description" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-shardable-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-shardable-tensor.pbtxt new file mode 100644 index 00000000000000..6848e8565c4866 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-shardable-tensor.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.train.experimental.ShardableTensor" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'_tensor_save_spec\', \'tensor\', \'dtype\', \'device\', \'name\', \'shape\', \'slice_spec\', \'checkpoint_key\', \'trackable\'], varargs=None, keywords=None, defaults=None" + } +} diff --git 
a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-sharding-callback.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-sharding-callback.pbtxt new file mode 100644 index 00000000000000..583a7f7c3135e9 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-sharding-callback.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.train.experimental.ShardingCallback" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "description" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.pbtxt index fc07c4283256e8..c22cacc50d16e0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.pbtxt @@ -12,6 +12,10 @@ tf_module { name: "LossScale" mtype: "" } + member { + name: "MaxShardSizePolicy" + mtype: "" + } member { name: "MixedPrecisionLossScaleOptimizer" mtype: "" @@ -20,6 +24,18 @@ tf_module { name: "PythonState" mtype: "" } + member { + name: "ShardByTaskPolicy" + mtype: "" + } + member { + name: "ShardableTensor" + mtype: "" + } + member { + name: "ShardingCallback" + mtype: "" + } member_method { name: "disable_mixed_precision_graph_rewrite" argspec: "args=[], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.pbtxt index 1fbfb172de4394..598411258b41a2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.pbtxt @@ -90,15 +90,15 @@ tf_module { } member_method { name: "initialize_accelerator_system" - argspec: "args=[\'device_type\', \'enable_coordination_service\', \'num_logical_cpu_devices\', \'experimental_reset_context\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'False\'], " + argspec: "args=[\'device_type\', \'enable_coordination_service\', \'num_logical_cpu_devices\', \'experimental_reset_context\', \'experimental_enable_megcore\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'False\', \'False\'], " } member_method { name: "initialize_multi_client" - argspec: "args=[\'device_type\', \'enable_coordination_service\', \'num_logical_cpu_devices\', \'experimental_reset_context\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'False\'], " + argspec: "args=[\'device_type\', \'enable_coordination_service\', \'num_logical_cpu_devices\', \'experimental_reset_context\', \'experimental_enable_megcore\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'False\', \'False\'], " } member_method { name: "initialize_tpu_system" - argspec: "args=[\'device_type\', \'enable_coordination_service\', \'num_logical_cpu_devices\', \'experimental_reset_context\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'False\'], " + argspec: "args=[\'device_type\', \'enable_coordination_service\', \'num_logical_cpu_devices\', \'experimental_reset_context\', \'experimental_enable_megcore\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'False\', \'False\'], " } member_method { name: "is_dtensor" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt index 
2319f6abb046b6..b1861f63d55b8d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt @@ -198,7 +198,7 @@ tf_module { } member_method { name: "matmul" - argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'grad_a\', \'grad_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'False\', \'False\', \'None\'], " } member_method { name: "matrix_rank" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt index 675bb89d694de6..15cdd2e274e29b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt @@ -24,7 +24,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'amsgrad\', \'weight_decay\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'False\', \'None\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'Adam\'], " + argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'adaptive_epsilon\', \'amsgrad\', \'weight_decay\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'False\', \'False\', \'None\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'Adam\'], " } member_method { name: "add_variable" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam.pbtxt index d31bab3e3d8c7d..fb2ea437049b45 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam.pbtxt @@ -24,7 +24,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'amsgrad\', \'weight_decay\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'False\', \'None\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'Adam\'], " + argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'adaptive_epsilon\', \'amsgrad\', \'weight_decay\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'False\', \'False\', \'None\', \'None\', \'None\', \'None\', \'False\', 
\'0.99\', \'None\', \'True\', \'Adam\'], " } member_method { name: "add_variable" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index c514ae513bd6e3..60f091cdb9c303 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -830,7 +830,7 @@ tf_module { } member_method { name: "matmul" - argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'output_type\', \'grad_a\', \'grad_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\', \'False\', \'False\', \'None\'], " } member_method { name: "matrix_square_root" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index f78ba2e0839c78..80e84f38715742 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -398,15 +398,15 @@ tf_module { } member_method { name: "BatchMatMul" - argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'grad_x\', \'grad_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "BatchMatMulV2" - argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'grad_x\', \'grad_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "BatchMatMulV3" - argspec: "args=[\'x\', \'y\', \'Tout\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'x\', \'y\', \'Tout\', \'adj_x\', \'adj_y\', \'grad_x\', \'grad_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "BatchMatrixBandPart" @@ -1936,6 +1936,10 @@ tf_module { name: "GetSessionTensor" argspec: "args=[\'handle\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "GlobalIterId" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "Greater" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -2300,6 +2304,10 @@ tf_module { name: "ListDiff" argspec: "args=[\'x\', \'y\', \'out_idx\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " } + member_method { + name: "ListSnapshotChunksDataset" + argspec: "args=[\'snapshot_path\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "LoadAndRemapMatrix" argspec: "args=[\'ckpt_path\', \'old_tensor_name\', \'row_remapping\', \'col_remapping\', \'initializing_values\', \'num_rows\', \'num_cols\', \'max_rows_in_memory\', \'name\'], 
varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " @@ -2494,7 +2502,7 @@ tf_module { } member_method { name: "MatMul" - argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'grad_a\', \'grad_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " } member_method { name: "MatchingFiles" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt index 2d3ef0b3fbb669..1d36dacaff7eea 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt @@ -62,7 +62,7 @@ tf_module { } member_method { name: "trace_on" - argspec: "args=[\'graph\', \'profiler\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], " + argspec: "args=[\'graph\', \'profiler\', \'profiler_outdir\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], " } member_method { name: "write" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-options.pbtxt index e649623069d76e..63505344a89afc 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-options.pbtxt @@ -14,12 +14,16 @@ tf_class { name: "experimental_io_device" mtype: "" } + member { + name: "experimental_sharding_callback" + mtype: "" + } member { name: "experimental_write_callbacks" mtype: "" } member_method { name: "__init__" - argspec: "args=[\'self\', \'experimental_io_device\', \'experimental_enable_async_checkpoint\', \'experimental_write_callbacks\', \'enable_async\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + argspec: "args=[\'self\', \'experimental_io_device\', \'experimental_enable_async_checkpoint\', \'experimental_write_callbacks\', \'enable_async\', \'experimental_sharding_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-max-shard-size-policy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-max-shard-size-policy.pbtxt new file mode 100644 index 00000000000000..eeb8a04569157a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-max-shard-size-policy.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.train.experimental.MaxShardSizePolicy" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "description" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'max_shard_size\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-shard-by-task-policy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-shard-by-task-policy.pbtxt new file mode 100644 index 00000000000000..19c91cb1bc42f3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-shard-by-task-policy.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.train.experimental.ShardByTaskPolicy" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "description" + mtype: "" + } + member_method { + name: "__init__" + } +} 
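The golden-file changes above surface a new checkpoint-sharding hook: tf.train.CheckpointOptions now accepts an experimental_sharding_callback, and tf.train.experimental gains MaxShardSizePolicy (constructed with max_shard_size) and ShardByTaskPolicy. A minimal sketch of how these pieces are expected to fit together follows; the variable contents, the save path, and the assumption that max_shard_size is a byte count are illustrative and not taken from this patch.

    import tensorflow as tf

    # Illustrative checkpoint contents: a single variable.
    ckpt = tf.train.Checkpoint(v=tf.Variable([1.0, 2.0, 3.0]))

    # Cap each checkpoint shard at roughly 100 MiB
    # (max_shard_size is assumed here to be a byte count).
    policy = tf.train.experimental.MaxShardSizePolicy(max_shard_size=100 * 2**20)

    # Route the policy through the CheckpointOptions field added in this patch.
    options = tf.train.CheckpointOptions(experimental_sharding_callback=policy)
    ckpt.save("/tmp/sharded_ckpt", options=options)

    # ShardByTaskPolicy is the other policy exposed by this change; as its name
    # suggests, it groups saved tensors by the task that owns them.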
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-shardable-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-shardable-tensor.pbtxt new file mode 100644 index 00000000000000..6848e8565c4866 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-shardable-tensor.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.train.experimental.ShardableTensor" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'_tensor_save_spec\', \'tensor\', \'dtype\', \'device\', \'name\', \'shape\', \'slice_spec\', \'checkpoint_key\', \'trackable\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-sharding-callback.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-sharding-callback.pbtxt new file mode 100644 index 00000000000000..583a7f7c3135e9 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.-sharding-callback.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.train.experimental.ShardingCallback" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "description" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.pbtxt index 2761b489b965ad..1306e29aa98256 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.pbtxt @@ -1,7 +1,23 @@ path: "tensorflow.train.experimental" tf_module { + member { + name: "MaxShardSizePolicy" + mtype: "" + } member { name: "PythonState" mtype: "" } + member { + name: "ShardByTaskPolicy" + mtype: "" + } + member { + name: "ShardableTensor" + mtype: "" + } + member { + name: "ShardingCallback" + mtype: "" + } } diff --git a/tensorflow/tools/ci_build/Dockerfile.cpu.arm64 b/tensorflow/tools/ci_build/Dockerfile.cpu.arm64 index 5090e739c1daa7..7710379719be1d 100644 --- a/tensorflow/tools/ci_build/Dockerfile.cpu.arm64 +++ b/tensorflow/tools/ci_build/Dockerfile.cpu.arm64 @@ -1,4 +1,4 @@ -FROM linaro/tensorflow-arm64-build:2.15-multipython +FROM linaro/tensorflow-arm64-build:2.16-multipython ARG py_major_minor_version='3.10' diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython new file mode 100644 index 00000000000000..b99e17355729b8 --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython @@ -0,0 +1,44 @@ +# Dockerfile to build a manylinux 2010 compliant cross-compiler. +# +# Builds a devtoolset gcc/libstdc++ that targets manylinux 2010 compatible +# glibc (2.12) and system libstdc++ (4.4). +# +# To push a new version, run: +# $ docker build -f Dockerfile.rbe.cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython \ +# --tag "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython" . 
+# $ docker push gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython + +FROM gcr.io/tensorflow-sigs/build@sha256:1aa3486c05856d76810dc725a26fc9262ab75dd888169d101e5612bf0800c970 + +ENV DEBIAN_FRONTEND=noninteractive + +COPY install/install_bootstrap_deb_packages.sh /install/ +RUN /install/install_bootstrap_deb_packages.sh + +COPY install/install_deb_packages.sh /install/ +RUN /install/install_deb_packages.sh + +RUN apt-get update && apt-get install -y \ + libbz2-dev \ + libffi-dev \ + libgdbm-dev \ + libncurses5-dev \ + libnss3-dev \ + libreadline-dev \ + libsqlite3-dev \ + patchelf \ + && \ + rm -rf /var/lib/apt/lists/* + +COPY install/build_and_install_python.sh /install/ +RUN /install/build_and_install_python.sh "3.9.18" +RUN /install/build_and_install_python.sh "3.10.13" +RUN /install/build_and_install_python.sh "3.11.6" +RUN /install/build_and_install_python.sh "3.12.0" + +COPY install/install_pip_packages_by_version.sh /install/ +# https://github.com/numpy/numpy/issues/22623 for `SETUPTOOLS_USE_DISTUTILS`. +RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.9" "jax" +RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.10" "jax" +RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.11" "jax" +RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.12" "jax" diff --git a/tensorflow/tools/ci_build/a100/nightly.sh b/tensorflow/tools/ci_build/a100/nightly.sh index d2ca9a3ae86cdf..6914b0269bd564 100644 --- a/tensorflow/tools/ci_build/a100/nightly.sh +++ b/tensorflow/tools/ci_build/a100/nightly.sh @@ -18,4 +18,4 @@ set -e docker pull tensorflow/tensorflow:devel-gpu docker run --gpus all -w /tensorflow_src -v $PWD:/mnt -e HOST_PERMS="$(id -u):$(id -g)" \ - tensorflow/tensorflow:devel-gpu bash -c "git pull; bazel test --config=cuda -c opt --test_tag_filters=gpu,-no_gpu,-benchmark-test,-no_oss,-oss_excluded,-oss_serial,-v1only,-no_gpu_presubmit,-no_cuda11 -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/mlir/tosa/... -//tensorflow/compiler/xrt/... //tensorflow/compiler/mlir/lite/... -//tensorflow/lite/micro/examples/... -//tensorflow/core/tpu/... -//tensorflow/lite/..." + tensorflow/tensorflow:devel-gpu bash -c "git pull; bazel test --config=cuda -c opt --test_tag_filters=gpu,-no_gpu,-benchmark-test,-no_oss,-oss_excluded,-oss_serial,-v1only,-no_gpu_presubmit,-no_cuda11 -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/mlir/tosa/... //tensorflow/compiler/mlir/lite/... -//tensorflow/lite/micro/examples/... -//tensorflow/core/tpu/... -//tensorflow/lite/..." diff --git a/tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh b/tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh index 631c869c80decf..aca2745bd3e8cf 100755 --- a/tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh +++ b/tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh @@ -17,7 +17,6 @@ set -x DEFAULT_BAZEL_TARGETS="//tensorflow/... \ -//tensorflow/compiler/tf2tensorrt/... \ --//tensorflow/compiler/xrt/... \ -//tensorflow/core/tpu/... \ -//tensorflow/go/... \ -//tensorflow/java/... 
\ diff --git a/tensorflow/tools/ci_build/install/install_pip_packages_by_version.sh b/tensorflow/tools/ci_build/install/install_pip_packages_by_version.sh index 1b4cc0552274d9..1a3fec1e179f7f 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages_by_version.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages_by_version.sh @@ -42,6 +42,7 @@ JAX_PACKAGES=( "typing_extensions" "ml_dtypes>=0.3.0" "importlib_metadata>=4.6" + "flatbuffers" ) PACKAGES=( diff --git a/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc b/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc index c388f5322abf07..11e64b54f97100 100644 --- a/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc +++ b/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc @@ -37,4 +37,4 @@ test --build_tests_only --keep_going test:nonpip_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:nonpip_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:nonpip_filters --test_lang_filters=cc,py -test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... -//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test +test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... 
-//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test diff --git a/tensorflow/tools/ci_build/release/requirements_mac.txt b/tensorflow/tools/ci_build/release/requirements_mac.txt index 39349a3f3a6aa2..aa08a8c8db45e3 100644 --- a/tensorflow/tools/ci_build/release/requirements_mac.txt +++ b/tensorflow/tools/ci_build/release/requirements_mac.txt @@ -8,5 +8,5 @@ twine ~= 3.6.0 setuptools # Test dependencies which don't exist on Windows -jax ~= 0.3.24 +jax ~= 0.4.1 jaxlib ~= 0.4.1 diff --git a/tensorflow/tools/ci_build/release/requirements_ubuntu.txt b/tensorflow/tools/ci_build/release/requirements_ubuntu.txt index 8d7122076fcd91..db2e1ee8b47fca 100644 --- a/tensorflow/tools/ci_build/release/requirements_ubuntu.txt +++ b/tensorflow/tools/ci_build/release/requirements_ubuntu.txt @@ -5,5 +5,5 @@ PyYAML ~= 6.0 # Test dependencies which don't exist on Windows -jax ~= 0.3.14 +jax ~= 0.4.1 jaxlib ~= 0.4.1; platform.machine != 'aarch64' diff --git a/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh b/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh new file mode 100644 index 00000000000000..25a2f1d4cb44f7 --- /dev/null +++ b/tensorflow/tools/ci_build/windows/bazel/cpu_win_test.sh @@ -0,0 +1,257 @@ +#!/bin/bash +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# This script is a CI script maintained by Intel and is used to launch the nightly CI test +# build on the Windows platform. +# It assumes the standard setup on tensorflow Jenkins Windows machines. +# Update the flags/variables below to make it work on your local system. + +# REQUIREMENTS: +# * All installed in standard locations: +# - JDK8, and JAVA_HOME set. +# - Microsoft Visual Studio 2015 Community Edition +# - Msys2 +# - Python 3.x (with pip, setuptools, venv) +# * Bazel Windows executable copied as "bazel.exe" and included in PATH. + + +# All commands should be visible (-x). 
+set -x + +POSITIONAL_ARGS=() +XBF_ARGS="" +XTF_ARGS="" +while [[ $# -gt 0 ]]; do + case "$1" in + --extra_build_flags) + XBF_ARGS="$2" + shift # past argument + shift # past value + ;; + --extra_test_flags) + XTF_ARGS="$2" + shift # past argument + shift # past value + ;; + *) + POSITIONAL_ARGS+=("$1") # save positional arg + shift # past argument + ;; + esac +done + +# Bazelisk (renamed as bazel) is kept in C:\Tools +export PATH=/c/ProgramData/chocolatey/bin:/c/Tools/bazel:/c/Program\ Files/Git:/c/Program\ \ +Files/Git/cmd:/c/msys64:/c/msys64/usr/bin:/c/Windows/system32:/c/Windows:/c/Windows/System32/Wbem + +# Environment variables to be set by Jenkins before calling this script + +export PYTHON_VERSION=${PYTHON_VERSION:-"310"} +export TF_PYTHON_VERSION=${PYTHON_VERSION:0:1}.${PYTHON_VERSION:1} +# keep the tensorflow git repo clone under here as tensorflow subdir +MYTFWS_ROOT=${WORKSPACE:-"C:/Users/mlp_admin"} +MYTFWS_ROOT=`cygpath -m $MYTFWS_ROOT` +export MYTFWS_ROOT="$MYTFWS_ROOT" +export MYTFWS_NAME="tensorflow" +export MYTFWS="${MYTFWS_ROOT}/${MYTFWS_NAME}" +export MYTFWS_ARTIFACT="${MYTFWS_ROOT}/artifact" + + +# Import General Test Target +source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh + +# Environment variables specific to the system where this job is running, are to +# be set by a script for the specific system. This needs to be set here by sourcing a file. + +export TMP=${TMP:-"${MYTFWS_ROOT}/tmp"} +export TEMP="$TMP" +export TMPDIR=${TMPDIR:-"${MYTFWS}-build"} # used internally by TF build +export TEST_TARGET=${TEST_TARGET:-"${DEFAULT_BAZEL_TARGETS}"} +export MSYS_LOCATION='C:/msys64' +export GIT_LOCATION='C:/Program Files/Git' +export JAVA_LOCATION='C:/Program Files/Eclipse Adoptium/jdk-11.0.14.101-hotspot' +export VS_LOCATION='C:/Program Files (x86)/Microsoft Visual Studio/2019/BuildTools' +export NATIVE_PYTHON_LOCATION="C:/Python${PYTHON_VERSION}" +export PORTSERVER_LOCATION='C:/Program Files/python_portpicker/src/portserver.py' + + +echo "*** *** hostname is $(hostname) *** ***" +which bazel +which git +[[ -e "$NATIVE_PYTHON_LOCATION/python.exe" ]] || \ +{ echo "Specified Python path is incorrect: $NATIVE_PYTHON_LOCATION"; exit 1;} +[[ -e "$NATIVE_PYTHON_LOCATION/Scripts/pip.exe" ]] || \ +{ echo "Specified Python path has no pip: $NATIVE_PYTHON_LOCATION"; exit 1;} +[[ -e "$NATIVE_PYTHON_LOCATION/Lib/venv" ]] || \ +{ echo "Specified Python path has no venv: $NATIVE_PYTHON_LOCATION"; exit 1;} + +$NATIVE_PYTHON_LOCATION/python.exe -m pip list + +# =========================== Start of actual script ========================= +# This script sets necessary environment variables and runs TF-Windows build & unit tests +# We also assume a few Software components are also installed in the machine: MS VC++, +# MINGW SYS64, Python 3.x, JAVA, Git, Bazelisk etc. + +# Asuumptions +# 1) TF repo cloned into to %WORKSPACE%\tensorflow (aka %TF_LOCATION%) +# 2) Bazelisk is installed in "C:\Tools\Bazel" +# 3) The following jobs-specific env vars will be exported by the caller +# WORKSPACE (ex. C:\Jenkins\workspace\tensorflow-eigen-test-win) +# PYTHON_VERSION (ex. 38) +# PIP_MODULES (if set will contain any additional pip packages) +# 4) System-specific env variables for the location of different software +# components needed for building. 
+ +# Create Python virtual env +cd ${MYTFWS_ROOT} +export PYTHON_DIRECTORY="${MYTFWS_ROOT}"/venv_py${PYTHON_VERSION} +"${NATIVE_PYTHON_LOCATION}"/python.exe -mvenv --clear "${PYTHON_DIRECTORY}" + +#activate virtual env +source "${PYTHON_DIRECTORY}"/Scripts/activate + +which python +python --version + +# Install pip modules specs from tensorflow/tools/ci_build/release/requirements_common.txt +python -m pip install -r $MYTFWS/tensorflow/tools/ci_build/release/requirements_common.txt + +# set up other Variables required by Bazel. +export PYTHON_BIN_PATH="${PYTHON_DIRECTORY}"/Scripts/python.exe +export PYTHON_LIB_PATH="${PYTHON_DIRECTORY}"/Lib/site-packages +export BAZEL_VS=${VS_LOCATION} +export BAZEL_VC=${VS_LOCATION}/VC +export JAVA_HOME=${JAVA_LOCATION} +export BAZEL_SH="${MSYS_LOCATION}"/usr/bin/bash.exe + +cd ${MYTFWS_ROOT} +mkdir -p "$TMP" +mv summary.log summary.log.bak +mv test_failures.log test_failures.log.bak +mv test_run.log test_run.log.bak +rm -rf ${MYTFWS_ARTIFACT} +mkdir -p ${MYTFWS_ARTIFACT} + +cd $MYTFWS + +# All commands shall pass +set -e + +# Setting up the environment variables Bazel and ./configure needs +source "tensorflow/tools/ci_build/windows/bazel/common_env.sh" \ + || { echo "Failed to source common_env.sh" >&2; exit 1; } + +# load bazel_test_lib.sh +source "tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh" \ + || { echo "Failed to source bazel_test_lib.sh" >&2; exit 1; } + +# Recreate an empty bazelrc file under source root +export TMP_BAZELRC=.tmp.bazelrc +rm -f "${TMP_BAZELRC}" +touch "${TMP_BAZELRC}" + +function cleanup { + # Remove all options in .tmp.bazelrc + echo "" > "${TMP_BAZELRC}" +} +trap cleanup EXIT + +# Enable short object file path to avoid long path issues on Windows. +echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}" + +if ! grep -q "import %workspace%/${TMP_BAZELRC}" .bazelrc; then + echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc +fi + +run_configure_for_cpu_build + +# Unset so the script continues even if commands fail, needed to correctly process the logs +set +e + +# start the port server before testing so that each invocation of +# portpicker will defer to the single instance of portserver +# Define the batch script content +BATCH_SCRIPT_START=" +@echo off +set SCRIPT_PATH="${PORTSERVER_LOCATION}" +echo Starting the server... +start \"PORTSERVER\" \"%PYTHON_BIN_PATH%\" \"%SCRIPT_PATH%\" +echo Server started. +" +# Save the batch script content to a temporary batch file +BATCH_SCRIPT_FILE="temp_script.bat" +echo "$BATCH_SCRIPT_START" > "$BATCH_SCRIPT_FILE" + +# Run the batch script +cmd.exe /C "$BATCH_SCRIPT_FILE" + +# NUMBER_OF_PROCESSORS is predefined on Windows +N_JOBS="${NUMBER_OF_PROCESSORS}" +bazel --windows_enable_symlinks test \ + --action_env=TEMP=${TMP} --action_env=TMP=${TMP} ${XTF_ARGS} \ + --experimental_cc_shared_library --enable_runfiles --nodistinct_host_configuration \ + --build_tag_filters=-no_pip,-no_windows,-no_oss,-gpu,-tpu \ + --test_tag_filters=-no_windows,-no_oss,-gpu,-tpu \ + --build_tests_only --config=monolithic \ + --dynamic_mode=off --config=xla --config=opt \ + --build_tests_only -k \ + --test_env=PORTSERVER_ADDRESS=@unittest-portserver \ + --repo_env=TF_PYTHON_VERSION=${TF_PYTHON_VERSION} \ + --test_size_filters=small,medium --jobs="${N_JOBS}" --test_timeout=300,450,1200,3600 \ + --flaky_test_attempts=3 --verbose_failures \ + ${POSITIONAL_ARGS[@]} \ + -- ${TEST_TARGET} \ + > run.log 2>&1 + +build_ret_val=$? # Store the ret value + +BATCH_SCRIPT_STOP=" +echo Killing the server... 
+taskkill /FI \"WindowTitle eq PORTSERVER*\" /F /T +echo Server killed. +" +BATCH_SCRIPT_FILEl="temp_script.bat" +echo "$BATCH_SCRIPT_STOP" > "$BATCH_SCRIPT_FILEl" +cmd.exe /C "$BATCH_SCRIPT_FILEl" + +# Removing the temporary batch script +rm -f "$BATCH_SCRIPT_FILE" +rm -f "$BATCH_SCRIPT_FILEl" + +# process results +cd $MYTFWS_ROOT + +# Check to make sure the log was created +[ ! -f "${MYTFWS}"/run.log ] && exit 1 + +# handle logs for unit test +cd ${MYTFWS_ARTIFACT} +cp "${MYTFWS}"/run.log ./test_run.log + +fgrep "FAILED: Build did NOT complete" test_run.log > summary.log +fgrep "Executed" test_run.log >> summary.log + +[ $build_ret_val -eq 0 ] && exit 0 + +echo "FAILED TESTS:" > test_failures.log +fgrep "FAILED" test_run.log | grep " ms)" | sed -e 's/^.*\] //' -e 's/ .*$//' | sort | \ +uniq >> test_failures.log +echo >> test_failures.log +echo "SKIPPED TESTS:" >> test_failures.log +fgrep "SKIPPED" test_run.log | grep -v "listed below:" | sed -e 's/^.*\] //' | sort | \ +uniq >> test_failures.log + +exit 1 diff --git a/tensorflow/tools/docs/generate2.py b/tensorflow/tools/docs/generate2.py index d56e508f8bf2dc..ed8ffc015fb7df 100644 --- a/tensorflow/tools/docs/generate2.py +++ b/tensorflow/tools/docs/generate2.py @@ -324,7 +324,7 @@ def edit_yaml_file(path): expected_path_contents = { "tf/summary/audio.md": - "tensorboard/plugins/audio/summary_v2.py", + "python/summary/tb_summary.py", "tf/estimator/DNNClassifier.md": "tensorflow_estimator/python/estimator/canned/dnn.py", "tf/nn/sigmoid_cross_entropy_with_logits.md": diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index 37394c6eb9a010..aaf9b0f5f31bb1 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -218,11 +218,7 @@ filegroup( "transform_graph.h", "transform_utils.h", ], - visibility = [ - "//tensorflow/core:__pkg__", - "//tensorflow/python:__pkg__", - "//tensorflow/python/util:__pkg__", - ], + visibility = ["//tensorflow/python/util:__pkg__"], ) cc_library( diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 0a712456f4e609..513b271be55508 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -165,6 +165,7 @@ genrule( ], "//conditions:default": [], }) + if_cuda([ + "@cub_archive//:LICENSE.TXT", "@local_config_nccl//:LICENSE", ]) + if_mkl([ "//third_party/mkl_dnn:LICENSE", @@ -207,6 +208,7 @@ genrule( ], "//conditions:default": [], }) + if_cuda([ + "@cub_archive//:LICENSE.TXT", "@local_config_nccl//:LICENSE", ]) + if_mkl([ "//third_party/mkl_dnn:LICENSE", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index b926de8e53952a..8b83ce23ab5ef6 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -237,6 +237,7 @@ filegroup( ], "//conditions:default": [], }) + if_cuda([ + "@cub_archive//:LICENSE.TXT", "@local_config_nccl//:LICENSE", ]) + if_mkl([ "//third_party/mkl_dnn:LICENSE", diff --git a/tensorflow/tools/pip_package/THIRD_PARTY_NOTICES.txt b/tensorflow/tools/pip_package/THIRD_PARTY_NOTICES.txt index c0ecfe99bcefff..9ac7ee9b800fd7 100644 --- a/tensorflow/tools/pip_package/THIRD_PARTY_NOTICES.txt +++ b/tensorflow/tools/pip_package/THIRD_PARTY_NOTICES.txt @@ -315,7 +315,7 @@ record keeping.) * are met: * * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. 
+ * notice, this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in @@ -370,21 +370,21 @@ record keeping.) * This package is an SSL implementation written * by Eric Young (eay@cryptsoft.com). * The implementation was written so as to conform with Netscapes SSL. - * + * * This library is free for commercial and non-commercial use as long as * the following conditions are aheared to. The following conditions * apply to all code found in this distribution, be it the RC4, RSA, * lhash, DES, etc., code; not just the SSL code. The SSL documentation * included with this distribution is covered by the same copyright terms * except that the holder is Tim Hudson (tjh@cryptsoft.com). - * + * * Copyright remains Eric Young's, and as such any Copyright notices in * the code are not to be removed. * If this package is used in a product, Eric Young should be given attribution * as the author of the parts of the library used. * This can be in the form of a textual message at program startup or * in documentation (online or textual) provided with the package. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: @@ -399,10 +399,10 @@ record keeping.) * Eric Young (eay@cryptsoft.com)" * The word 'cryptographic' can be left out if the rouines from the library * being used are not cryptographic related :-). - * 4. If you include any Windows specific code (or a derivative thereof) from + * 4. If you include any Windows specific code (or a derivative thereof) from * the apps directory (application code) you must include an acknowledgement: * "This product includes software written by Tim Hudson (tjh@cryptsoft.com)" - * + * * THIS SOFTWARE IS PROVIDED BY ERIC YOUNG ``AS IS'' AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE @@ -414,7 +414,7 @@ record keeping.) * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF * SUCH DAMAGE. - * + * * The licence and distribution terms for any publically available version or * derivative of this code cannot be changed. i.e. this code cannot simply be * copied and put under another distribution licence @@ -557,7 +557,40 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER * DEALINGS IN THE SOFTWARE. - */ + */ + +-------------------------------------------------------------------------------- + +-------------------------------------------------------------------------------- +== cutlass + +Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. 
+ +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- @@ -931,9 +964,9 @@ the other COPYING.* files here. If you want to guarantee that the Eigen code that you are #including is licensed under the MPL2 and possibly more permissive licenses (like -BSD), #define this preprocessor symbol: EIGEN_MPL2_ONLY +BSD), #define this preprocessor symbol: EIGEN_MPL2_ONLY For example, with most compilers, you could add this to your project - CXXFLAGS: -DEIGEN_MPL2_ONLY + CXXFLAGS: -DEIGEN_MPL2_ONLY This will cause a compilation error to be generated if you #include any code that is covered by more restrictive licences than MPL2. @@ -1693,7 +1726,7 @@ Mozilla Public License Version 2.0 means any form of the work other than Source Code Form. 1.7. "Larger Work" - means a work that combines Covered Software with other material, in + means a work that combines Covered Software with other material, in a separate file or files, that is not Covered Software. 1.8. "License" @@ -3591,7 +3624,7 @@ Mozilla Public License Version 2.0 means any form of the work other than Source Code Form. 1.7. "Larger Work" - means a work that combines Covered Software with other material, in + means a work that combines Covered Software with other material, in a separate file or files, that is not Covered Software. 1.8. "License" @@ -3952,8 +3985,8 @@ Copyright Notice and Statement for the h5py Project documentation and/or other materials provided with the distribution. - c. Neither the name of the author nor the names of contributors may - be used to endorse or promote products derived from this software + c. Neither the name of the author nor the names of contributors may + be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS @@ -5242,7 +5275,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
============================================================================== -============================================================================== +============================================================================== Copied from llvm-project/llvm/lib/Support/COPYRIGHT.regex: $OpenBSD: COPYRIGHT,v 1.3 2003/06/02 20:18:36 millert Exp $ @@ -5300,7 +5333,7 @@ to the following restrictions: */ ============================================================================== -============================================================================== +============================================================================== License for third_party/llvm/llvm-project/llvm/cmake/config.guess: GNU GENERAL PUBLIC LICENSE @@ -5612,7 +5645,7 @@ exception to the GPL from your modified version. ============================================================================== -============================================================================== +============================================================================== Copied from llvm-project/llvm/-project/polly/lib/External/isl/LICENSE: MIT License (MIT) @@ -5636,7 +5669,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ============================================================================== -============================================================================== +============================================================================== Copied from llvm-project/llgo/third_party/gotools/LICENSE: Copyright (c) 2009 The Go Authors. All rights reserved. @@ -5668,7 +5701,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ============================================================================== -============================================================================== +============================================================================== Copied from llvm-project/llgo/third_party/gofrontend/libffi/LICENSE: libffi - Copyright (c) 1996-2014 Anthony Green, Red Hat, Inc and others. @@ -5694,7 +5727,7 @@ TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ============================================================================== -============================================================================== +============================================================================== Copied from llvm-project/lldb/third_party/Python/module/six/LICENSE: Copyright (c) 2010-2015 Benjamin Peterson @@ -5717,7 +5750,7 @@ IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ============================================================================== -============================================================================== +============================================================================== Copied from llvm-project/lldb/third_party/Python/module/pexpect-4.6/LICENSE and lldb/third_party/Python/module/ptyprocess-0.6.0/LICENSE. @@ -5732,7 +5765,7 @@ ISC LICENSE Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies. 
- + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR @@ -5742,7 +5775,7 @@ ISC LICENSE OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. ============================================================================== -============================================================================== +============================================================================== Copied from llvm-project/clang-tools-extra/clangd/clients/clangd-vscode/LICENSE: @@ -5769,7 +5802,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ============================================================================== -============================================================================== +============================================================================== Copied from llvm-project/llvm/include/llvm/Support/LICENSE.TXT: LLVM System Interface Library @@ -5780,7 +5813,7 @@ License and has the following additional copyright: Copyright (C) 2004 eXtensible Systems, Inc. ============================================================================== -============================================================================== +============================================================================== Copied from llvm-project/llvm/test/YAMLParser/LICENSE.txt: Copyright (c) 2006 Kirill Simonov @@ -5804,7 +5837,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ============================================================================== -============================================================================== +============================================================================== Copied from llvm-project/clang-tools-extra/clang-tidy/cert/LICENSE.TXT: ------------------------------------------------------------------------------ @@ -5831,7 +5864,7 @@ to reproduce the title of the content being linked to, nor to reproduce any de Minimis description of such content. ============================================================================== -============================================================================== +============================================================================== Copied from llvm-project/clang-tools-extra/clang-tidy/hicpp/LICENSE.TXT: ------------------------------------------------------------------------------ @@ -6108,21 +6141,21 @@ Copied from docker_kokoro/dockerfiles/scripts/google_packages/deb_packages/copyr Files: libcxx/utils/google-benchmark/* License: Apache 2.0 - + Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ - + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - + 1. Definitions. - + "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. - + "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. - + "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, @@ -6130,24 +6163,24 @@ License: Apache 2.0 direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. - + "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. - + "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. - + "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. - + "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). - + "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications @@ -6155,7 +6188,7 @@ License: Apache 2.0 of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. - + "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally @@ -6169,18 +6202,18 @@ License: Apache 2.0 Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." - + "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. - + 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. - + 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable @@ -6196,24 +6229,24 @@ License: Apache 2.0 or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. - + 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: - + (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and - + (b) You must cause any modified files to carry prominent notices stating that You changed the files; and - + (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and - + (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained @@ -6230,14 +6263,14 @@ License: Apache 2.0 or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. - + You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. - + 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of @@ -6245,12 +6278,12 @@ License: Apache 2.0 Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. - + 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. - + 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, @@ -6260,7 +6293,7 @@ License: Apache 2.0 PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. - + 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly @@ -6272,7 +6305,7 @@ License: Apache 2.0 work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. - + 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, @@ -6283,11 +6316,11 @@ License: Apache 2.0 defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. - + END OF TERMS AND CONDITIONS - + APPENDIX: How to apply the Apache License to your work. 
- + To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include @@ -6296,15 +6329,15 @@ License: Apache 2.0 file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. - + Copyright [yyyy] [name of copyright owner] - + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -7334,8 +7367,8 @@ THE POSSIBILITY OF SUCH DAMAGE.** # Components -Many parts of this module have been derived from original sources, -often the algorithm's designer. Component licenses are located with +Many parts of this module have been derived from original sources, +often the algorithm's designer. Component licenses are located with the component code. @@ -9264,33 +9297,33 @@ been taken from other projects or from the open internet. Every line of code can be traced back to its original author, and all of those authors have public domain dedications on file. So the SQLite code base is clean and is uncontaminated with licensed code from other projects. - + -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- == triton -/* +/* * Copyright 2018-2020 Philippe Tillet * Copyright 2020-2022 OpenAI -* -* Permission is hereby granted, free of charge, to any person obtaining -* a copy of this software and associated documentation files -* (the "Software"), to deal in the Software without restriction, -* including without limitation the rights to use, copy, modify, merge, -* publish, distribute, sublicense, and/or sell copies of the Software, -* and to permit persons to whom the Software is furnished to do so, +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files +* (the "Software"), to deal in the Software without restriction, +* including without limitation the rights to use, copy, modify, merge, +* publish, distribute, sublicense, and/or sell copies of the Software, +* and to permit persons to whom the Software is furnished to do so, * subject to the following conditions: -* -* The above copyright notice and this permission notice shall be +* +* The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index 8a626aa2f887a0..ff21aadba95be6 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -167,12 +167,12 @@ function prepare_src() { cp -L \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.exe.runfiles/org_tensorflow/LICENSE \ "${TMPDIR}" - + # Change the format of file path (TMPDIR-->TMPDIR_rsync) which is input to the rsync from - # Windows-compatible to Linux-compatible to resolve the error below - # error: ssh: Could not resolve hostname c: No such host is known. - - TMPDIR_rsync=`cygpath $TMPDIR` + # Windows-compatible to Linux-compatible to resolve the error below + # error: ssh: Could not resolve hostname c: No such host is known. + + TMPDIR_rsync=`cygpath $TMPDIR` rsync -a \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.exe.runfiles/org_tensorflow/tensorflow \ "${TMPDIR_rsync}" @@ -215,23 +215,6 @@ function prepare_src() { cp -L \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/LICENSE \ "${TMPDIR}" - # Check if it is a tpu build - if [[ ${TPU_BUILD} == "1" ]]; then - # Check if libtpu.so exists - if [[ -f "./tensorflow/lib/libtpu.so" ]]; then - if [[ ! -L "${RUNFILES}/tensorflow/lib/libtpu.so" ]]; then - mkdir "$(real_path ${RUNFILES}/tensorflow/lib)" - ln -s $(real_path ./tensorflow/lib/libtpu.so) $(real_path ${RUNFILES}/tensorflow/lib/libtpu.so) - echo "Created symlink: $(real_path ./tensorflow/lib/libtpu.so) -> \ - $(real_path ${RUNFILES}/tensorflow/lib/libtpu.so)" - else - echo "Symlink already exists: ${RUNFILES}/tensorflow/lib/libtpu.so" - fi - else - echo "Libtpu.so is not found in $(real_path ./tensorflow/lib/)" - exit 1 - fi - fi cp -LR \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow \ "${TMPDIR}" @@ -263,11 +246,17 @@ function prepare_src() { chmod +rw ${TMPDIR}/tensorflow/python/_pywrap_tensorflow_internal.so else chmod +rw ${TMPDIR}/tensorflow/python/_pywrap_tensorflow_internal.so + chmod +rw ${TMPDIR}/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.so chmod +rw ${TMPDIR}/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.so + chmod +rw ${TMPDIR}/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/pywrap_calibration.so patchelf --set-rpath $(patchelf --print-rpath ${TMPDIR}/tensorflow/python/_pywrap_tensorflow_internal.so):\$ORIGIN/../../tensorflow/tsl/python/lib/core ${TMPDIR}/tensorflow/python/_pywrap_tensorflow_internal.so + patchelf --set-rpath $(patchelf --print-rpath ${TMPDIR}/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.so):\$ORIGIN/../../../../../python ${TMPDIR}/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.so patchelf --set-rpath $(patchelf --print-rpath ${TMPDIR}/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.so):\$ORIGIN/../../../../../python 
${TMPDIR}/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.so + patchelf --set-rpath $(patchelf --print-rpath ${TMPDIR}/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/pywrap_calibration.so):\$ORIGIN/../../../../../python ${TMPDIR}/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/pywrap_calibration.so patchelf --shrink-rpath ${TMPDIR}/tensorflow/python/_pywrap_tensorflow_internal.so + patchelf --shrink-rpath ${TMPDIR}/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.so patchelf --shrink-rpath ${TMPDIR}/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.so + patchelf --shrink-rpath ${TMPDIR}/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/pywrap_calibration.so fi mkl_so_dir=$(ls ${RUNFILES}/${so_lib_dir} | grep mkl) || true if [ -n "${mkl_so_dir}" ]; then @@ -358,7 +347,7 @@ function build_wheel() { FULL_DIR="$(real_path "$PY_DIR")/bin/python3" export PYTHONPATH="$PYTHONPATH:$PWD/bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/pypi_wheel/site-packages/" fi - + pushd ${TMPDIR} > /dev/null rm -f MANIFEST diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index a5a20b8dcf78e7..57faa7fb4ae7f2 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -29,6 +29,7 @@ 2.0](https://github.com/tensorflow/tensorflow/blob/master/LICENSE). """ +import datetime import fnmatch import os import platform @@ -178,9 +179,6 @@ def standard_or_nightly(standard, nightly): 'nvidia-cusparse-cu12 == 12.1.2.141', 'nvidia-nccl-cu12 == 2.18.3', 'nvidia-nvjitlink-cu12 == 12.2.140', - 'tensorrt == 8.6.1.post1', - 'tensorrt-bindings == 8.6.1', - 'tensorrt-libs == 8.6.1', ] DOCLINES = __doc__.split('\n') @@ -322,9 +320,23 @@ def find_files(pattern, root): for path in so_lib_paths: matches.extend(['../' + x for x in find_files('*', path) if '.py' not in x]) -# If building a tpu package, bundle libtpu.so as part of the wheel +# If building a tpu package, LibTPU for Cloud TPU VM can be installed via: +# $ pip install -f https://storage.googleapis.com/libtpu-releases/index.html +# libtpu is built and uploaded to this link every night (PST). if '_tpu' in project_name: - matches.append('tensorflow/lib/libtpu.so') + # For tensorflow-tpu releases, use a set libtpu-nightly version; + # For tf-nightly-tpu, use the most recent libtpu-nightly. Because of the + # timing of these tests, the UTC date from eight hours ago is expected to be a + # valid version. 
+ _libtpu_version = standard_or_nightly( + '0.1.dev20231018', + '0.1.dev' + + ( + datetime.datetime.now(tz=datetime.timezone.utc) + - datetime.timedelta(hours=8) + ).strftime('%Y%m%d'), + ) + REQUIRED_PACKAGES.append([f'libtpu-nightly=={_libtpu_version}']) if os.name == 'nt': EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.pyd' @@ -422,6 +434,7 @@ def find_files(pattern, root): 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Programming Language :: Python :: 3 :: Only', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Mathematics', diff --git a/tensorflow/tools/proto_splitter/BUILD b/tensorflow/tools/proto_splitter/BUILD index 5a74f394eb9dd1..447bae692394e7 100644 --- a/tensorflow/tools/proto_splitter/BUILD +++ b/tensorflow/tools/proto_splitter/BUILD @@ -100,6 +100,7 @@ py_strict_test( ":chunk_proto_py", ":split", ":versions_proto_py", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", "//tensorflow/tools/proto_splitter/testdata:test_message_proto_py", "@riegeli_py//python/riegeli", @@ -136,6 +137,7 @@ py_strict_test( ":constants", ":split_graph_def", ":util", + #internal proto upb dep "//tensorflow/core:protos_all_py", "//tensorflow/python/platform:client_testlib", "//tensorflow/tools/proto_splitter/python:test_util", @@ -153,6 +155,7 @@ py_strict_test( srcs = ["util_test.py"], deps = [ ":util", + #internal proto upb dep "//tensorflow/python/platform:client_testlib", "//tensorflow/tools/proto_splitter/testdata:test_message_proto_py", ], diff --git a/tensorflow/tools/proto_splitter/cc/BUILD b/tensorflow/tools/proto_splitter/cc/BUILD index 266d69479ff8a0..1188ed94533864 100644 --- a/tensorflow/tools/proto_splitter/cc/BUILD +++ b/tensorflow/tools/proto_splitter/cc/BUILD @@ -1,12 +1,13 @@ -# Description: -# Utilities for splitting and joining large protos > 2GB. -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "if_oss", "tf_cc_test", ) +# Description: +# Utilities for splitting and joining large protos > 2GB. 
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ @@ -43,11 +44,16 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/tools/proto_splitter:chunk_proto_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:protobuf", + "@riegeli//riegeli/bytes:cord_writer", "@riegeli//riegeli/bytes:fd_writer", + "@riegeli//riegeli/bytes:string_writer", "@riegeli//riegeli/records:record_writer", ] + if_oss([ "//tensorflow/tools/proto_splitter:protos_impl", @@ -85,11 +91,15 @@ tf_cc_test( "//tensorflow/tools/proto_splitter:chunk_proto_cc", "//tensorflow/tools/proto_splitter/testdata:test_message_proto_cc", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status_matchers", + "@riegeli//riegeli/bytes:cord_reader", "@riegeli//riegeli/bytes:fd_reader", + "@riegeli//riegeli/bytes:string_reader", "@riegeli//riegeli/records:record_reader", ] + if_oss([ "//tensorflow/tools/proto_splitter:protos_impl", @@ -197,6 +207,7 @@ cc_library( ":composable_splitter", ":max_size", ":size_splitter", + ":split", ":util", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/status", diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc index 8a9ee3091a1366..b02c09c6fa8d62 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc @@ -1,5 +1,7 @@ #include "tensorflow/tools/proto_splitter/cc/composable_splitter_base.h" +#include + /* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,25 +16,39 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include #include #include #include #include +#include #include #include #include #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "riegeli/bytes/cord_writer.h" // from @riegeli #include "riegeli/bytes/fd_writer.h" // from @riegeli +#include "riegeli/bytes/string_writer.h" // from @riegeli #include "riegeli/records/record_writer.h" // from @riegeli #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/tools/proto_splitter/cc/max_size.h" +#include "tensorflow/tools/proto_splitter/cc/split.h" #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tensorflow/tools/proto_splitter/chunk.pb.h" +#include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +#define IS_OSS true namespace tensorflow { namespace tools::proto_splitter { @@ -86,27 +102,67 @@ ComposableSplitterBase::Split() { return std::make_pair(&chunks_, &chunked_message_); } -absl::Status ComposableSplitterBase::Write(std::string file_prefix) { +template +static absl::Status WriteToRecordWriter( + riegeli::RecordWriter& writer, const std::vector& chunks, + ChunkedMessage& chunked_message, + const ::proto_splitter::VersionDef& version) { + // Export Riegeli / chunked file. + ChunkMetadata metadata; + *metadata.mutable_message() = chunked_message; + *metadata.mutable_version() = version; + auto* metadata_chunks = metadata.mutable_chunks(); + + for (const auto& chunk : chunks) { + auto* chunk_metadata = metadata_chunks->Add(); + if (std::holds_alternative>( + chunk)) { + const auto& msg_chunk = + std::get>(chunk); + LOG(INFO) << "Writing chunk of size " << msg_chunk->ByteSizeLong(); + writer.WriteRecord(*msg_chunk); + chunk_metadata->set_size(msg_chunk->ByteSizeLong()); + chunk_metadata->set_type(::proto_splitter::ChunkInfo::MESSAGE); + } else if (std::holds_alternative(chunk)) { + auto* msg_chunk = std::get(chunk); + writer.WriteRecord(*msg_chunk); + chunk_metadata->set_size(msg_chunk->ByteSizeLong()); + chunk_metadata->set_type(::proto_splitter::ChunkInfo::MESSAGE); + } else { + const auto& str_chunk = std::get(chunk); + writer.WriteRecord(str_chunk); + chunk_metadata->set_size(str_chunk.size()); + chunk_metadata->set_type(::proto_splitter::ChunkInfo::BYTES); + } + chunk_metadata->set_offset(writer.LastPos().get().numeric()); + } + writer.WriteRecord(metadata); + return absl::OkStatus(); +} + +absl::Status ComposableSplitterBase::CheckIfWriteImplemented() { if (parent_splitter_ != nullptr) { return absl::UnimplementedError( "The `Write` function behavior for children ComposableSplitter has not " - "been defined. Please call the parent ComposableSplitter's `Write` " - "instead."); - } - auto split_status = Split(); - if (!split_status.ok()) { - return split_status.status(); + "been defined. 
Please call `parent_splitter.Write()` instead."); } + return absl::OkStatus(); +} - auto chunks = split_status.value().first; - auto chunked_message = split_status.value().second; +absl::Status ComposableSplitterBase::Write(std::string file_prefix) { + TF_RETURN_IF_ERROR(CheckIfWriteImplemented()); + + auto split_results = Split(); + if (!split_results.ok()) return split_results.status(); + auto& chunks = *split_results.value().first; + auto& chunked_message = *split_results.value().second; tsl::Env* env = tsl::Env::Default(); TF_RETURN_IF_ERROR(env->RecursivelyCreateDir( std::string{tensorflow::io::Dirname(file_prefix)})); std::string output_path; - if (chunked_message->chunked_fields().empty()) { + if (chunked_message.chunked_fields().empty()) { // Export regular pb. output_path = absl::StrCat(file_prefix, ".pb"); TF_RETURN_IF_ERROR( @@ -114,43 +170,77 @@ absl::Status ComposableSplitterBase::Write(std::string file_prefix) { } else { // Export Riegeli / chunked file. output_path = absl::StrCat(file_prefix, ".cpb"); - riegeli::RecordWriter writer((riegeli::FdWriter(output_path))); - - ChunkMetadata metadata; - metadata.mutable_message()->MergeFrom(*chunked_message); - metadata.mutable_version()->MergeFrom(Version()); - auto metadata_chunks = metadata.mutable_chunks(); - - for (auto chunk : *chunks) { - auto chunk_metadata = metadata_chunks->Add(); - if (std::holds_alternative>( - chunk)) { - auto msg_chunk = - std::get>(chunk); - writer.WriteRecord(*msg_chunk); - chunk_metadata->set_size(msg_chunk->ByteSizeLong()); - chunk_metadata->set_type(::proto_splitter::ChunkInfo::MESSAGE); - } else if (std::holds_alternative(chunk)) { - auto msg_chunk = std::get(chunk); - writer.WriteRecord(*msg_chunk); - chunk_metadata->set_size(msg_chunk->ByteSizeLong()); - chunk_metadata->set_type(::proto_splitter::ChunkInfo::MESSAGE); - } else { - auto str_chunk = std::get(chunk); - writer.WriteRecord(str_chunk); - chunk_metadata->set_size(str_chunk.size()); - chunk_metadata->set_type(::proto_splitter::ChunkInfo::BYTES); - } - chunk_metadata->set_offset(writer.LastPos().get().numeric()); - } - - writer.WriteRecord(metadata); + using WriterType = riegeli::FdWriter<>; + riegeli::RecordWriter writer((WriterType(output_path))); + if (!writer.is_open()) return writer.status(); + TF_RETURN_IF_ERROR(WriteToRecordWriter( + writer, chunks, chunked_message, Version())); if (!writer.Close()) return writer.status(); } LOG(INFO) << "Splitter output written to " << output_path; return absl::OkStatus(); } +absl::StatusOr> +ComposableSplitterBase::WriteToString() { + TF_RETURN_IF_ERROR(CheckIfWriteImplemented()); + + auto split_results = Split(); + if (!split_results.ok()) return split_results.status(); + auto& chunks = *split_results.value().first; + auto& chunked_message = *split_results.value().second; + + std::string output; + if (chunked_message.chunked_fields().empty()) { + // Export regular pb. + if (!message_->SerializeToString(&output)) + return absl::InvalidArgumentError("Serialization to string failed"); + LOG(INFO) << "Splitter output written to string"; + return std::make_tuple(output, false); + } else { + // Export Riegeli / chunked file. 
+ using WriterType = riegeli::StringWriter<>; + riegeli::RecordWriter writer((WriterType(&output))); + if (!writer.is_open()) return writer.status(); + TF_RETURN_IF_ERROR(WriteToRecordWriter( + writer, chunks, chunked_message, Version())); + if (!writer.Close()) return writer.status(); + LOG(INFO) << "Splitter output written to string"; + return std::make_tuple(output, true); + } +} + +#if !IS_OSS +absl::StatusOr> +ComposableSplitterBase::WriteToCord() { + TF_RETURN_IF_ERROR(CheckIfWriteImplemented()); + + auto split_results = Split(); + if (!split_results.ok()) return split_results.status(); + auto& chunks = *split_results.value().first; + auto& chunked_message = *split_results.value().second; + + absl::Cord output; + if (chunked_message.chunked_fields().empty()) { + // Export regular pb. + if (!message_->SerializeToCord(&output)) + return absl::InvalidArgumentError("Serialization to absl::Cord failed"); + LOG(INFO) << "Splitter output written to absl::Cord"; + return std::make_tuple(output, false); + } else { + // Export Riegeli / chunked file. + using WriterType = riegeli::CordWriter<>; + riegeli::RecordWriter writer((WriterType(&output))); + if (!writer.is_open()) return writer.status(); + TF_RETURN_IF_ERROR(WriteToRecordWriter( + writer, chunks, chunked_message, Version())); + if (!writer.Close()) return writer.status(); + LOG(INFO) << "Splitter output written to absl::Cord"; + return std::make_tuple(output, true); + } +} +#endif + absl::Status ComposableSplitterBase::SetMessageAsBaseChunk() { if (!chunks_.empty()) { return absl::FailedPreconditionError( diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h index 478638b43fb989..a37a3c61ca0a02 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.h @@ -15,18 +15,23 @@ limitations under the License. #ifndef TENSORFLOW_TOOLS_PROTO_SPLITTER_CC_COMPOSABLE_SPLITTER_BASE_H_ #define TENSORFLOW_TOOLS_PROTO_SPLITTER_CC_COMPOSABLE_SPLITTER_BASE_H_ +#include #include #include +#include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "tensorflow/tools/proto_splitter/cc/split.h" #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tensorflow/tools/proto_splitter/chunk.pb.h" #include "tsl/platform/protobuf.h" +#define IS_OSS true + namespace tensorflow { namespace tools::proto_splitter { @@ -62,6 +67,12 @@ class ComposableSplitterBase : public Splitter { // attach a `.pb` or `.cpb` (chunked pb) suffix depending on whether the // proto is split. absl::Status Write(std::string file_prefix) override; + // The bool field record whether it's saved as a chunked protobuf (true) or + // regular protobuf (false). + absl::StatusOr> WriteToString(); +#if !IS_OSS + absl::StatusOr> WriteToCord(); +#endif VersionDef Version() override; @@ -93,6 +104,7 @@ class ComposableSplitterBase : public Splitter { // the chunks were always added to the end of the list. However, this is not // always the case the indices must be updated. 
absl::Status FixChunks(); + absl::Status CheckIfWriteImplemented(); bool built_; tsl::protobuf::Message* message_; diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc index 85eeab4f5a2dad..8efdf36caee628 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc @@ -16,12 +16,18 @@ limitations under the License. #include #include +#include +#include +#include #include #include #include #include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "riegeli/bytes/cord_reader.h" // from @riegeli #include "riegeli/bytes/fd_reader.h" // from @riegeli +#include "riegeli/bytes/string_reader.h" // from @riegeli #include "riegeli/records/record_reader.h" // from @riegeli #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/platform/env.h" @@ -33,10 +39,11 @@ limitations under the License. #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" -#include "tsl/platform/protobuf.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" +#define IS_OSS true + namespace tensorflow { namespace tools::proto_splitter { namespace { @@ -120,23 +127,9 @@ TEST(RepeatedStringSplitterTest, TestSplitChunks) { EXPECT_EQ(chunked_message2, chunked_message); } -TEST(RepeatedStringSplitterTest, TestWrite) { - std::vector strings = {"piece-1", "piece-2", "piece-3"}; - auto message = SetUpRepeatedString(strings); - RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); - - std::string output_prefix = tensorflow::io::GetTempFilename(""); - TF_ASSERT_OK(splitter.Write(output_prefix)); - std::string expected_file = absl::StrCat(output_prefix, ".cpb"); - - TF_ASSERT_OK_AND_ASSIGN(auto exists, - internal::FileExists(Env::Default(), expected_file)); - EXPECT_TRUE(exists); - - // Look for the last chunk, which should contain a ChunkMetadata proto. - riegeli::RecordReader> reader( - (riegeli::FdReader(expected_file))); - +template +static void CheckChunks(riegeli::RecordReader& reader, + std::vector& strings) { ChunkMetadata chunk_metadata; reader.Seek(reader.Size().value()); reader.SeekBack(); @@ -169,6 +162,60 @@ TEST(RepeatedStringSplitterTest, TestWrite) { })pb")); } +TEST(RepeatedStringSplitterTest, TestWrite) { + std::vector strings = {"piece-1", "piece-2", "piece-3"}; + auto message = SetUpRepeatedString(strings); + RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); + + std::string output_prefix = tensorflow::io::GetTempFilename(""); + TF_ASSERT_OK(splitter.Write(output_prefix)); + std::string expected_file = absl::StrCat(output_prefix, ".cpb"); + + TF_ASSERT_OK_AND_ASSIGN(auto exists, + internal::FileExists(Env::Default(), expected_file)); + EXPECT_TRUE(exists); + + // Look for the last chunk, which should contain a ChunkMetadata proto. 
+ riegeli::RecordReader> file_reader( + (riegeli::FdReader(expected_file))); + + CheckChunks(file_reader, strings); +} + +TEST(RepeatedStringSplitterTest, TestWriteToString) { + std::vector strings = {"piece-1", "piece-2", "piece-3"}; + auto message = SetUpRepeatedString(strings); + RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); + auto string_output_results = splitter.WriteToString(); + TF_EXPECT_OK(string_output_results.status()); + std::string string_output = std::get<0>(string_output_results.value()); + bool is_chunked = std::get<1>(string_output_results.value()); + EXPECT_TRUE(is_chunked); + // Look for the last chunk, which should contain a ChunkMetadata proto. + riegeli::RecordReader> string_reader( + std::forward_as_tuple(string_output)); + + CheckChunks(string_reader, strings); +} + +#if !IS_OSS +TEST(RepeatedStringSplitterTest, TestWriteToCord) { + std::vector strings = {"piece-1", "piece-2", "piece-3"}; + auto message = SetUpRepeatedString(strings); + RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); + auto cord_output_results = splitter.WriteToCord(); + TF_EXPECT_OK(cord_output_results.status()); + absl::Cord cord_output = std::get<0>(cord_output_results.value()); + bool is_chunked = std::get<1>(cord_output_results.value()); + EXPECT_TRUE(is_chunked); + // Look for the last chunk, which should contain a ChunkMetadata proto. + riegeli::RecordReader> cord_reader( + std::forward_as_tuple(&cord_output)); + + CheckChunks(cord_reader, strings); +} +#endif + TEST(RepeatedStringSplitterTest, TestNoSplit) { RepeatedString message; // No strings RepeatedStringSplitter splitter = RepeatedStringSplitter(&message); diff --git a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc index 036c90fde04a94..bbb2587a2d3c39 100644 --- a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc +++ b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/tools/proto_splitter/cc/graph_def_splitter.h" +#include #include #include #include @@ -179,7 +180,12 @@ TEST(GraphDefSplitterTest, TestLotsNodes) { const std::string graph_def_path = io::JoinPath(testing::TensorFlowSrcRoot(), "tools/proto_splitter/testdata", "split-lots-nodes.pb"); - int64_t max_size = 500; + + // split-lots-nodes.pb has 15 nodes that are 95 or 96 bytes each. The max size + // is set to "exactly" the size of 5 nodes, but with the extra encoding bytes, + // only 4 nodes should fit in each chunk. Thus, there should be exactly 4 + // chunks created for all 15 nodes. 
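As an illustrative aside, not part of this change: the comment above can be checked with a small self-contained sketch (plain Python, not TF code) of the same greedy grouping the repeated-field splitter performs later in this diff, using 15 nodes of roughly 96 bytes, the `96 * 5` max size set just below, and about 5 extra bytes per node for field/key encoding (the `kExtraBytes` constant introduced further down):

    def greedy_chunks(node_sizes, max_size, extra_bytes=5):
        # Start a new chunk whenever the next node would push the running
        # total (payload plus per-node overhead) past max_size.
        chunks, current, total = [], [], 0
        for size in node_sizes:
            if total + size > max_size and current:
                chunks.append(current)
                current, total = [], 0
            current.append(size)
            total += size + extra_bytes
        if current:
            chunks.append(current)
        return chunks

    per_chunk = [len(c) for c in greedy_chunks([96] * 15, max_size=96 * 5)]
    print(per_chunk)  # [4, 4, 4, 3] -> 4 chunks in total, matching the expected chunk count above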
+ int64_t max_size = 96 * 5; DebugSetMaxSize(max_size); TF_EXPECT_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), @@ -196,7 +202,9 @@ TEST(GraphDefSplitterTest, TestLotsNodes) { *chunked_message, EqualsProto(R"pb(chunk_index: 0 chunked_fields { message { chunk_index: 1 } } - chunked_fields { message { chunk_index: 2 } })pb")); + chunked_fields { message { chunk_index: 2 } } + chunked_fields { message { chunk_index: 3 } } + chunked_fields { message { chunk_index: 4 } })pb")); auto chunks = x.first; EXPECT_CHUNK_SIZES(chunks, max_size); diff --git a/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc b/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc index e836556d569974..01601c7e22a1fc 100644 --- a/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc +++ b/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h" +#include #include +#include #include #include "absl/status/status.h" @@ -23,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/tools/proto_splitter/cc/max_size.h" +#include "tensorflow/tools/proto_splitter/cc/split.h" #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" @@ -31,6 +34,10 @@ limitations under the License. namespace tensorflow { namespace tools::proto_splitter { +// Additional bytes added to each node to account for the extra info needed to +// encode the field key (realistically 3 but making it 5 for some wiggle room). +constexpr int kExtraBytes = 5; + template absl::StatusOr> RepeatedFieldSplitters::Create( @@ -65,13 +72,8 @@ absl::StatusOr RepeatedFieldSplitters< // List of indices at which to split the repeated field. For example, [3, 5] // means that the field list is split into: [:3], [3:5], [5:] - std::vector repeated_msg_split = {}; - // Should be the same length as the list above. Contains new protos to hold - // the elements that are split from the original proto. - // From the [3, 5] example above, the messages in this list contain nodes - // [3:5] and [5:] - std::vector> repeated_new_msg; - // Track the total size of the current node split. + std::vector repeated_msg_split = {0}; + // Track the total byte size of the current node split. uint64_t total_size = 0; // Linearly iterate through all nodes. 
It may be possible to optimize this @@ -99,17 +101,12 @@ absl::StatusOr RepeatedFieldSplitters< } if (total_size + node_size > max_size) { repeated_msg_split.push_back(i); - auto new_chunk = std::make_shared(); - repeated_new_msg.push_back(new_chunk); - std::vector empty_fields = {}; - auto x = std::make_unique(new_chunk); - TF_RETURN_IF_ERROR(AddChunk(std::move(x), &empty_fields)); total_size = 0; } - total_size += node_size; + total_size += node_size + kExtraBytes; } - if (!repeated_msg_split.empty()) { + if (repeated_msg_split.size() > 1) { auto repeated_nodes_ptrs = ret.parent->GetReflection() ->template MutableRepeatedPtrField(ret.parent, @@ -127,7 +124,11 @@ absl::StatusOr RepeatedFieldSplitters< for (int i = 1; i < repeated_msg_split.size(); ++i) { start = repeated_msg_split[i - 1]; int end = repeated_msg_split[i]; - std::shared_ptr new_msg = repeated_new_msg[i - 1]; + + auto new_msg = std::make_shared(); + std::vector empty_fields; + auto x = std::make_unique(new_msg); + TF_RETURN_IF_ERROR(AddChunk(std::move(x), &empty_fields)); // Move nodes into new_msg. TF_ASSIGN_OR_RETURN(auto new_ret, diff --git a/tensorflow/tools/proto_splitter/python/BUILD b/tensorflow/tools/proto_splitter/python/BUILD index 18cd31b66bb8d8..b18bd64b0b8cd8 100644 --- a/tensorflow/tools/proto_splitter/python/BUILD +++ b/tensorflow/tools/proto_splitter/python/BUILD @@ -1,8 +1,7 @@ -load("//tensorflow:strict.default.bzl", "py_strict_test") - # Description: # Python library for splitting and joining large protos. load("//tensorflow:pytype.default.bzl", "pytype_strict_library") +load("//tensorflow:strict.default.bzl", "py_strict_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -13,6 +12,11 @@ package( pytype_strict_library( name = "saved_model", srcs = ["saved_model.py"], + # NOTE(yibaimeng): To be removed when everything is migrated to `pywrap_saved_model.Save`. 
+ visibility = [ + "//tensorflow:internal", + "//waymo/ml/deploy/tensorflow:__pkg__", + ], deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/tools/proto_splitter:constants", @@ -31,6 +35,7 @@ py_strict_test( deps = [ ":saved_model", ":test_util", + #internal proto upb dep "//tensorflow/core:protos_all_py", "//tensorflow/python/platform:client_testlib", "//tensorflow/tools/proto_splitter:constants", @@ -52,6 +57,7 @@ py_strict_test( name = "test_util_test", srcs = ["test_util_test.py"], deps = [ + #internal proto upb dep "//tensorflow/python/framework:dtypes", "//tensorflow/python/platform:client_testlib", "//tensorflow/tools/proto_splitter/python:test_util", diff --git a/tensorflow/tools/test/BUILD b/tensorflow/tools/test/BUILD index 43d46c8b5c1e79..d4e12622377a68 100644 --- a/tensorflow/tools/test/BUILD +++ b/tensorflow/tools/test/BUILD @@ -21,12 +21,10 @@ exports_files([ py_strict_library( name = "system_info_lib", - srcs = [ - "gpu_info_lib.py", - "system_info_lib.py", - ], + srcs = ["system_info_lib.py"], srcs_version = "PY3", deps = [ + ":gpu_info_lib", "//tensorflow:tensorflow_py", "//tensorflow/core:protos_all_py", "//tensorflow/python/client:device_lib", @@ -36,6 +34,19 @@ py_strict_library( ], ) +py_strict_library( + name = "gpu_info_lib", + srcs = ["gpu_info_lib.py"], + srcs_version = "PY3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/framework:errors", + "//tensorflow/python/platform:gfile", + "@six_archive//:six", + ], +) + py_strict_binary( name = "system_info", srcs = ["system_info.py"], @@ -54,6 +65,7 @@ py_strict_library( ], srcs_version = "PY3", deps = [ + ":gpu_info_lib", ":system_info_lib", "//tensorflow/core:protos_all_py", "//tensorflow/python/platform:gfile", diff --git a/tensorflow/tools/test/performance.bzl b/tensorflow/tools/test/performance.bzl index 4f4201e62e73a4..f918da44589729 100644 --- a/tensorflow/tools/test/performance.bzl +++ b/tensorflow/tools/test/performance.bzl @@ -1,4 +1,12 @@ -load("//tensorflow:tensorflow.default.bzl", "tf_py_strict_test") +""" +Benchmark-related macros. +""" + +load( + "//tensorflow:tensorflow.default.bzl", + "cuda_py_strict_test", + "tf_py_strict_test", +) # Create a benchmark test target of a TensorFlow C++ test (tf_cc_*_test) def tf_cc_logged_benchmark( @@ -50,6 +58,33 @@ def tf_cc_logged_benchmark( **kwargs ) +def add_benchmark_tag_to_kwargs(kwargs): + """Adds the `benchmark-test` tag to the kwargs, if not already present. + + Notes: + For benchmarks which are not technically tests, but whose class methods + can still be discovered, and run as such via `bazel run`. + Args: + kwargs: kwargs to be passed to a test wrapper/rule further down. + Returns: + kwargs: kwargs with the tags including the `benchmark-test` tags. 
+ """ + benchmark_tag = "benchmark-test" + if "tags" in kwargs and kwargs["tags"] != None: + if benchmark_tag not in kwargs["tags"]: + kwargs["tags"].append(benchmark_tag) + else: + kwargs["tags"] = [benchmark_tag] + return kwargs + +def tf_py_benchmark_test(**kwargs): + kwargs = add_benchmark_tag_to_kwargs(kwargs) + tf_py_strict_test(**kwargs) + +def cuda_py_benchmark_test(**kwargs): + kwargs = add_benchmark_tag_to_kwargs(kwargs) + cuda_py_strict_test(**kwargs) + # Create a benchmark test target of a TensorFlow python test (*py_tests) def tf_py_logged_benchmark( name = None, diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt b/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt index 62e73c996b1829..4a899fb3504e11 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt @@ -42,8 +42,8 @@ scipy ~= 1.7.2; python_version < '3.11' scipy ~= 1.9.2; python_version == '3.11' # Earliest version for Python 3.11 scipy ~= 1.11.3; python_version >= '3.12' # Earliest version for Python 3.12 # Required for TFLite import from JAX tests -jax ~= 0.3.25; python_version <= '3.11' -jaxlib ~= 0.3.25; python_version <= '3.11' # Earliest version for Python 3.11 +jax ~= 0.4.1; python_version <= '3.11' +jaxlib ~= 0.4.1; python_version <= '3.11' # Earliest version for Python 3.11 # Needs to be addressed. Unblocked 2.4 branchcut cl/338377048 PyYAML ~= 6.0 # For uploading diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu.bazelrc b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu.bazelrc index c3a792a539c607..0bfee88d16c710 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu.bazelrc +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu.bazelrc @@ -23,7 +23,7 @@ build --config=release_cpu_linux test:nonpip_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py39,-no_oss_py310 test:nonpip_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py39,-no_oss_py310 test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium -test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # For building libtensorflow archives test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test @@ -40,4 +40,4 @@ build:rbe --config=rbe_linux_cpu test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:pycpp --config=pycpp_filters -- //tensorflow/... 
-//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu_gcc.bazelrc b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu_gcc.bazelrc index 14b75645a85fab..1f21969496f1e9 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu_gcc.bazelrc +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu_gcc.bazelrc @@ -45,7 +45,7 @@ test --test_summary=short test:nonpip_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py39,-no_oss_py310 test:nonpip_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py39,-no_oss_py310 test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium -test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # For building libtensorflow archives test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test @@ -82,4 +82,4 @@ build:rbe --project_id="tensorflow-testing" test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu.bazelrc b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu.bazelrc index e85df2f297ec0a..8bfadb03c734bd 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu.bazelrc +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu.bazelrc @@ -19,10 +19,10 @@ build --config=release_gpu_linux # Pass --config=nonpip to run the same suite of tests. If you want to run just # one test for investigation, you don't need --config=nonpip; just run the # bazel test invocation as normal. 
-test:nonpip_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py39,-no_oss_py310 -test:nonpip_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py39,-no_oss_py310 +test:nonpip_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py39,-no_oss_py310 test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium -test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # For building libtensorflow archives test:libtensorflow_test -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test @@ -39,4 +39,4 @@ build:rbe --config=rbe_linux_cuda test:pycpp_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:pycpp_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:pycpp_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:pycpp --config=pycpp_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/rename_and_verify_wheels.sh b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/rename_and_verify_wheels.sh index dd7c1524ba9bec..1ba11e07f53a10 100755 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/rename_and_verify_wheels.sh +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/rename_and_verify_wheels.sh @@ -21,11 +21,7 @@ set -euxo pipefail for wheel in /tf/pkg/*.whl; do echo "Checking and renaming $wheel..." 
- if [[ "$wheel" =~ .*_tpu.* ]]; then - time python3 -m auditwheel repair --plat manylinux_2_27_x86_64 "$wheel" --wheel-dir /tf/pkg 2>&1 | tee check.txt - else - time python3 -m auditwheel repair --plat manylinux2014_x86_64 "$wheel" --wheel-dir /tf/pkg 2>&1 | tee check.txt - fi + time python3 -m auditwheel repair --plat manylinux2014_x86_64 "$wheel" --wheel-dir /tf/pkg 2>&1 | tee check.txt # We don't need the original wheel if it was renamed new_wheel=$(grep --extended-regexp --only-matching '/tf/pkg/\S+.whl' check.txt) diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/wheel_verification.bats b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/wheel_verification.bats index 19662eb904bfb5..17b689dbedd6dd 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/wheel_verification.bats +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/wheel_verification.bats @@ -26,13 +26,9 @@ teardown_file() { rm -rf /tf/venv } -@test "Wheel is manylinux2014 (manylinux_2_17) compliant (TPU wheel is manylinux_2_27 compliant)" { +@test "Wheel is manylinux2014 (manylinux_2_17) compliant" { python3 -m auditwheel show "$TF_WHEEL" > audit.txt - if [[ "$TF_WHEEL" =~ .*_tpu.* ]]; then - grep --quiet 'This constrains the platform tag to "manylinux_2_27_x86_64"' audit.txt - else - grep --quiet 'This constrains the platform tag to "manylinux_2_17_x86_64"' audit.txt - fi + grep --quiet 'This constrains the platform tag to "manylinux_2_17_x86_64"' audit.txt } @test "Wheel conforms to upstream size limitations" { @@ -58,10 +54,12 @@ teardown_file() { # Note: this runs before the tests further down the file, so TF is installed in # the venv and the venv is active when those tests run. The venv gets cleaned # up in teardown_file() above. +# LibTPU is necessary if building a tpu package, and it is installed via +# "-f ". See tensorflow/setup.py. 
@test "Wheel is installable" { python3 -m venv /tf/venv source /tf/venv/bin/activate - python3 -m pip install "$TF_WHEEL" + python3 -m pip install "$TF_WHEEL" -f https://storage.googleapis.com/libtpu-releases/index.html } @test "TensorFlow is importable" { diff --git a/tensorflow/tools/toolchains/cross_compile/cc/BUILD b/tensorflow/tools/toolchains/cross_compile/cc/BUILD new file mode 100644 index 00000000000000..7db2527259d026 --- /dev/null +++ b/tensorflow/tools/toolchains/cross_compile/cc/BUILD @@ -0,0 +1,188 @@ +"""Toolchain configs for cross-compiling TensorFlow""" + +load("@bazel_tools//tools/cpp:unix_cc_toolchain_config.bzl", "cc_toolchain_config") + +package(default_visibility = ["//visibility:public"]) + +licenses(["restricted"]) + +cc_toolchain_suite( + name = "cross_compile_toolchain_suite", + toolchains = { + "aarch64": ":linux_aarch64_toolchain", + "k8": ":linux_x86_toolchain", + }, +) + +filegroup(name = "empty") + +cc_toolchain( + name = "linux_x86_toolchain", + all_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":linux_x86_toolchain_config", + toolchain_identifier = "linux_x86_toolchain", +) + +cc_toolchain_config( + name = "linux_x86_toolchain_config", + abi_libc_version = "local", + abi_version = "local", + builtin_sysroot = "/dt9", + compile_flags = [ + "--target=x86_64-unknown-linux-gnu", + "-fstack-protector", + "-Wall", + "-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-mavx", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "k8", + cxx_builtin_include_directories = [ + "/dt9/", + "/usr/lib/llvm-17/include/", + "/usr/lib/llvm-17/lib/clang/17/include", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=x86_64-unknown-linux-gnu", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "-Wl,--undefined-version", + ], + link_libs = [ + "-lstdc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,--gc-sections"], + supports_start_end_lib = True, + target_libc = "", + target_system_name = "x86_64-unknown-linux-gnu", + tool_paths = { + "gcc": "/usr/lib/llvm-17/bin/clang", + "ld": "/usr/lib/llvm-17/bin/ld.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-ar", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + "strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "linux_x86_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) + +cc_toolchain( + name = "linux_aarch64_toolchain", + all_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":linux_aarch64_toolchain_config", + toolchain_identifier = "linux_aarch64_toolchain", +) + +cc_toolchain_config( + name = 
"linux_aarch64_toolchain_config", + abi_libc_version = "local", + abi_version = "local", + builtin_sysroot = "/dt10/", + compile_flags = [ + "--target=aarch64-unknown-linux-gnu", + "-fstack-protector", + "-Wall", + "-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-mtune=generic", + "-march=armv8-a", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "aarch64", + cxx_builtin_include_directories = [ + "/dt10/", + "/usr/lib/llvm-17/include/", + "/usr/lib/llvm-17/lib/clang/17/include", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=aarch64-unknown-linux-gnu", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "-Wl,--undefined-version", + ], + link_libs = [ + "-lstdc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,--gc-sections"], + supports_start_end_lib = True, + target_libc = "", + target_system_name = "aarch64-unknown-linux-gnu", + tool_paths = { + "gcc": "/usr/lib/llvm-17/bin/clang", + "ld": "/usr/lib/llvm-17/bin/ld.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-ar", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + "strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "linux_aarch64_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) diff --git a/tensorflow/tools/toolchains/cross_compile/config/BUILD b/tensorflow/tools/toolchains/cross_compile/config/BUILD new file mode 100644 index 00000000000000..b6a504ba1449d6 --- /dev/null +++ b/tensorflow/tools/toolchains/cross_compile/config/BUILD @@ -0,0 +1,23 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["restricted"]) + +platform( + name = "linux_x86_64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], + exec_properties = { + "container-image": "docker://gcr.io/tensorflow-testing/ml-devinfra-linux-aarch64-cross-compile@sha256:11c5ac3b9b4e01cfa82b39b90826a9bfc5b806ccc92cd3d272e6bf861de43be1", + "OSFamily": "Linux", + }, +) + +platform( + name = "linux_aarch64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:aarch64", + ], +) diff --git a/tensorflow/tools/toolchains/remote_config/configs.bzl b/tensorflow/tools/toolchains/remote_config/configs.bzl index e8fc081f0af511..a7cbf50e47eea3 100644 --- a/tensorflow/tools/toolchains/remote_config/configs.bzl +++ b/tensorflow/tools/toolchains/remote_config/configs.bzl @@ -200,6 +200,28 @@ def initialize_rbe_configs(): python_install_path = "/usr/local", ) + tensorflow_rbe_config( + name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", + compiler = "/usr/lib/llvm-17/bin/clang", + cuda_version = "12.3", + cudnn_version = "8.9", + os = "ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + sysroot = "/dt9", + python_install_path = "/usr/local", + ) + + tensorflow_rbe_config( + name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", + compiler = "/dt9/usr/bin/gcc", + 
compiler_prefix = "/usr/bin", + cuda_version = "12.3", + cudnn_version = "8.9", + os = "ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + python_install_path = "/usr/local", + ) + tensorflow_rbe_win_config( name = "windows_py37", python_bin_path = "C:/Python37/python.exe", diff --git a/tensorflow/tools/toolchains/remote_config/containers.bzl b/tensorflow/tools/toolchains/remote_config/containers.bzl index bfb4634e810328..cd346c2816def1 100644 --- a/tensorflow/tools/toolchains/remote_config/containers.bzl +++ b/tensorflow/tools/toolchains/remote_config/containers.bzl @@ -5,8 +5,9 @@ container_digests = { # TF now uses only this container "cuda11.2-cudnn8.1-ubuntu20.04-manylinux2014-multipython": "sha256:48612bd85709cd014711d0b0f87e0806f3567d06d2e81c6e860516b87498b821", # JAX manylinux2014 configs. - "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:ab39410baf2fc1d31d50540acec7640d7f4814fa694e2421b696b6f0a058d645", - "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:b699d6ae235ac601dc3e62391ac7c4606cb10331f8141983858c1580f5e74ddb", + "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:b112c0c77d4172fc025420938f13ea83f3ad480c01778e743a201e5e3f4710e1", + "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:9fefda035b4a12b24cd5bae56c7dbb9527a5fd06a41ced0a22ac86fe5ed26428", + "cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:afe68c3448734cb07b16005fd9ed47d19533eb8bf5acd92863735ce24766b93b", # ROCM, probably not all of them still in use "rocm-ubuntu18.04-manylinux2010-multipython": "sha256:6e953a09b145df338bcb03e9e36f99b291140c29b72d0a048fb6c5905ccad5eb", "rocm-ubuntu20.04-manylinux2014-multipython": "sha256:906faec7765fe5dd067f2b092b5d5f220c1fedde725fb42c83d031b4d6f32204", @@ -98,6 +99,13 @@ containers = { "digest": container_digests["cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython"], }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython. + "cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": { + "registry": "gcr.io", + "repository": "tensorflow-testing/nosla-cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython", + "digest": container_digests["cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython"], + }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.rocm-ubuntu18.04-manylinux2010-multipython. 
"rocm-ubuntu18.04-manylinux2010-multipython": { "registry": "gcr.io", diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 65074788800b78..631771d1e09f34 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -150,9 +150,9 @@ def _tf_repositories(): # LINT.IfChange tf_http_archive( name = "XNNPACK", - sha256 = "88e0158aff1e1498e34dfcaf08d948a73a3246a04fe96e548da71f6b9245a009", - strip_prefix = "XNNPACK-c7e7cde37615a81a529c326aa278bfab4cd6fe5a", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/c7e7cde37615a81a529c326aa278bfab4cd6fe5a.zip"), + sha256 = "ca829b6486d7dcc0a63eae9d5d5be21dcb542e6601af4cada17b9d5f7d5fafb7", + strip_prefix = "XNNPACK-0cbbe74a16e6ca11acf8484ccac85f620336dea4", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/0cbbe74a16e6ca11acf8484ccac85f620336dea4.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) @@ -172,9 +172,9 @@ def _tf_repositories(): tf_http_archive( name = "cpuinfo", - strip_prefix = "cpuinfo-959002f82d7962a473d8bf301845f2af720e0aa4", - sha256 = "a0f53ccfb477c57753c595df02bf79ed67bf092fd9a5c61ec5b8992b81bc1e65", - urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8bf301845f2af720e0aa4.zip"), + strip_prefix = "cpuinfo-ef634603954d88d2643d5809011288b890ac126e", + sha256 = "e07512a11e1c71687359a133f49d60583d7465b737fe5dbe11f461c9aaa72a2b", + urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/ef634603954d88d2643d5809011288b890ac126e.zip"), ) tf_http_archive( @@ -186,6 +186,14 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v0.9.zip"), ) + tf_http_archive( + name = "cutlass_archive", + build_file = "//third_party:cutlass.BUILD", + sha256 = "ea1b7f96919460a5d80b09c1b246652539a8605600b2be4cccc02c254bccbe50", + strip_prefix = "cutlass-5783d6dbd0c34032371cce2bd999fc76007520d7", + urls = tf_mirror_urls("https://github.com/chsigg/cutlass/archive/5783d6dbd0c34032371cce2bd999fc76007520d7.tar.gz"), + ) + tf_http_archive( name = "mkl_dnn_v1", build_file = "//third_party/mkl_dnn:mkldnn_v1.BUILD", @@ -589,6 +597,16 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/google/pprof/archive/83db2b799d1f74c40857232cb5eb4c60379fe6c2.tar.gz"), ) + # The CUDA 11 toolkit ships with CUB. We should be able to delete this rule + # once TF drops support for CUDA 10. + tf_http_archive( + name = "cub_archive", + build_file = "//third_party:cub.BUILD", + sha256 = "162514b3cc264ac89d91898b58450190b8192e2af1142cf8ccac2d59aa160dda", + strip_prefix = "cub-1.9.9", + urls = tf_mirror_urls("https://github.com/NVlabs/cub/archive/1.9.9.zip"), + ) + tf_http_archive( name = "nvtx_archive", build_file = "//third_party:nvtx.BUILD", diff --git a/third_party/cutlass.BUILD b/third_party/cutlass.BUILD new file mode 100644 index 00000000000000..923d2f044c395a --- /dev/null +++ b/third_party/cutlass.BUILD @@ -0,0 +1,24 @@ +# Description: +# CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance +# matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. 
+ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # MIT + +exports_files(["LICENSE.txt"]) + +filegroup( + name = "cutlass_header_files", + srcs = glob([ + "include/**", + ]), +) + +cc_library( + name = "cutlass", + hdrs = [":cutlass_header_files"], + strip_include_prefix = "/include", +) diff --git a/third_party/flatbuffers/workspace.bzl b/third_party/flatbuffers/workspace.bzl index 1aa9b2ff2d00ba..a0b943d7a9487b 100644 --- a/third_party/flatbuffers/workspace.bzl +++ b/third_party/flatbuffers/workspace.bzl @@ -2,12 +2,20 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") +# _FLATBUFFERS_GIT_COMMIT / _FLATBUFFERS_SHA256 were added due to an urgent Flatbuffers change that +# was needed for Flatbuffers/TfLite to be compatible with Android API level >= 23. +# They can be removed in the next flatbuffers official release / update. +_FLATBUFFERS_GIT_COMMIT = "7d6d99c6befa635780a4e944d37ebfd58e68a108" + +# curl -L https://github.com/google/flatbuffers/archive/<_FLATBUFFERS_GIT_COMMIT>.tar.gz | shasum -a 256 +_FLATBUFFERS_SHA256 = "d27761f6b2fb1017ec00ed317a7b98cb7aed86b81d90528b498fb17ec13579a1" + def repo(): tf_http_archive( name = "flatbuffers", - strip_prefix = "flatbuffers-23.5.26", - sha256 = "1cce06b17cddd896b6d73cc047e36a254fb8df4d7ea18a46acf16c4c0cd3f3f3", - urls = tf_mirror_urls("https://github.com/google/flatbuffers/archive/v23.5.26.tar.gz"), + strip_prefix = "flatbuffers-%s" % _FLATBUFFERS_GIT_COMMIT, + sha256 = _FLATBUFFERS_SHA256, + urls = tf_mirror_urls("https://github.com/google/flatbuffers/archive/%s.tar.gz" % _FLATBUFFERS_GIT_COMMIT), build_file = "//third_party/flatbuffers:flatbuffers.BUILD", system_build_file = "//third_party/flatbuffers:BUILD.system", link_files = { diff --git a/third_party/gemmlowp/workspace.bzl b/third_party/gemmlowp/workspace.bzl index b98035569852e2..884f707719a623 100644 --- a/third_party/gemmlowp/workspace.bzl +++ b/third_party/gemmlowp/workspace.bzl @@ -7,8 +7,8 @@ def repo(): # Attention: tools parse and update these lines.
# LINT.IfChange - GEMMLOWP_COMMIT = "e844ffd17118c1e17d94e1ba4354c075a4577b88" - GEMMLOWP_SHA256 = "522b7a82d920ebd0c4408a5365866a40b81d1c0d60b2369011d315cca03c6476" + GEMMLOWP_COMMIT = "16e8662c34917be0065110bfcd9cc27d30f52fdf" + GEMMLOWP_SHA256 = "7dc418717c8456473fac4ff2288b71057e3dcb72894524c734a4362cdb51fa8b" # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/gemmlowp.cmake) tf_http_archive( diff --git a/third_party/gif_fix_image_counter.patch b/third_party/gif_fix_image_counter.patch index 1d72f75d6e80f4..2184e18af1b435 100644 --- a/third_party/gif_fix_image_counter.patch +++ b/third_party/gif_fix_image_counter.patch @@ -1,5 +1,5 @@ diff --git a/dgif_lib.c b/dgif_lib.c -index 82fc097..c6700a9 100644 +index 82fc097..214a0e7 100644 --- a/dgif_lib.c +++ b/dgif_lib.c @@ -810,7 +810,8 @@ DGifSetupDecompress(GifFileType *GifFile) @@ -12,7 +12,7 @@ index 82fc097..c6700a9 100644 } BitsPerPixel = CodeSize; -@@ -1118,6 +1119,28 @@ DGifBufferedInput(GifFileType *GifFile, GifByteType *Buf, GifByteType *NextByte) +@@ -1118,6 +1119,31 @@ DGifBufferedInput(GifFileType *GifFile, GifByteType *Buf, GifByteType *NextByte) return GIF_OK; } @@ -29,6 +29,9 @@ index 82fc097..c6700a9 100644 + if (GifFile->SavedImages[GifFile->ImageCount].RasterBits != NULL) { + free(GifFile->SavedImages[GifFile->ImageCount].RasterBits); + } ++ if (GifFile->SavedImages[GifFile->ImageCount].ImageDesc.ColorMap != NULL) { ++ GifFreeMapObject(GifFile->SavedImages[GifFile->ImageCount].ImageDesc.ColorMap); ++ } + + // Realloc array according to the new image counter. + SavedImage *correct_saved_images = (SavedImage *)reallocarray( @@ -41,7 +44,7 @@ index 82fc097..c6700a9 100644 /****************************************************************************** This routine reads an entire GIF into core, hanging all its state info off the GifFileType pointer. 
Call DGifOpenFileName() or DGifOpenFileHandle() -@@ -1148,17 +1171,20 @@ DGifSlurp(GifFileType *GifFile) +@@ -1148,17 +1174,20 @@ DGifSlurp(GifFileType *GifFile) /* Allocate memory for the image */ if (sp->ImageDesc.Width <= 0 || sp->ImageDesc.Height <= 0 || sp->ImageDesc.Width > (INT_MAX / sp->ImageDesc.Height)) { @@ -62,7 +65,7 @@ index 82fc097..c6700a9 100644 return GIF_ERROR; } -@@ -1177,13 +1203,17 @@ DGifSlurp(GifFileType *GifFile) +@@ -1177,13 +1206,17 @@ DGifSlurp(GifFileType *GifFile) j += InterlacedJumps[i]) { if (DGifGetLine(GifFile, sp->RasterBits+j*sp->ImageDesc.Width, diff --git a/third_party/gloo/BUILD b/third_party/gloo/BUILD new file mode 100644 index 00000000000000..3c413807167aeb --- /dev/null +++ b/third_party/gloo/BUILD @@ -0,0 +1 @@ +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/gloo/gloo.BUILD b/third_party/gloo/gloo.BUILD new file mode 100644 index 00000000000000..68ba4e3610da70 --- /dev/null +++ b/third_party/gloo/gloo.BUILD @@ -0,0 +1,97 @@ +# Description: +# Gloo is a collective communications library + +load("//third_party/bazel_skylib/rules:expand_template.bzl", "expand_template") + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +substitions = { + "@GLOO_VERSION_MAJOR@": "9999", + "@GLOO_VERSION_MINOR@": "0", + "@GLOO_VERSION_PATCH@": "0", + "#cmakedefine01 GLOO_USE_CUDA": "#define GLOO_USE_CUDA 0", + "#cmakedefine01 GLOO_USE_NCCL": "#define GLOO_USE_NCCL 0", + "#cmakedefine01 GLOO_USE_ROCM": "#define GLOO_USE_ROCM 0", + "#cmakedefine01 GLOO_USE_RCCL": "#define GLOO_USE_RCCL 0", + "#cmakedefine01 GLOO_USE_REDIS": "#define GLOO_USE_REDIS 0", + "#cmakedefine01 GLOO_USE_IBVERBS": "#define GLOO_USE_IBVERBS 0", + "#cmakedefine01 GLOO_USE_MPI": "#define GLOO_USE_MPI 0", + "#cmakedefine01 GLOO_USE_LIBUV": "#define GLOO_USE_LIBUV 0", + "#cmakedefine01 GLOO_HAVE_TRANSPORT_TCP": "#define GLOO_HAVE_TRANSPORT_TCP 1", + "#cmakedefine01 GLOO_HAVE_TRANSPORT_TCP_TLS": "#define GLOO_HAVE_TRANSPORT_TCP_TLS 0", + "#cmakedefine01 GLOO_HAVE_TRANSPORT_IBVERBS": "#define GLOO_HAVE_TRANSPORT_IBVERBS 0", + "#cmakedefine01 GLOO_HAVE_TRANSPORT_UV": "#define GLOO_HAVE_TRANSPORT_UV 0", + "#cmakedefine01 GLOO_USE_AVX": "#define GLOO_USE_AVX __AVX__", +} + +expand_template( + name = "config", + out = "gloo/config.h", + substitutions = substitions, + template = "gloo/config.h.in", +) + +cc_library( + name = "gloo", + srcs = glob( + [ + "gloo/*.cc", + "gloo/common/*.cc", + "gloo/transport/*.cc", + ], + exclude = [ + "gloo/common/linux.cc", + "gloo/common/win.cc", + "gloo/cuda*.cc", + ], + ) + [ + "gloo/rendezvous/context.cc", + "gloo/rendezvous/file_store.cc", + "gloo/rendezvous/hash_store.cc", + "gloo/rendezvous/prefix_store.cc", + "gloo/rendezvous/store.cc", + ] + select({ + "@local_tsl//tsl:macos": [], + "@local_tsl//tsl:windows": [], + "//conditions:default": [ + "gloo/common/linux.cc", + ], + }), + copts = [ + "-fexceptions", + "-Wno-unused-variable", + ], + includes = ["."], + textual_hdrs = glob( + [ + "gloo/*.h", + "gloo/common/*.h", + "gloo/transport/*.h", + ], + exclude = [ + "gloo/cuda*.h", + "gloo/common/win.h", + ], + ) + [ + "gloo/config.h", + "gloo/rendezvous/context.h", + "gloo/rendezvous/file_store.h", + "gloo/rendezvous/hash_store.h", + "gloo/rendezvous/prefix_store.h", + "gloo/rendezvous/store.h", + ], +) + +cc_library( + name = "transport_tcp", + srcs = glob(["gloo/transport/tcp/*.cc"]), + hdrs = glob(["gloo/transport/tcp/*.h"]), + copts = 
["-fexceptions"], + deps = [":gloo"], +) diff --git a/third_party/gloo/workspace.bzl b/third_party/gloo/workspace.bzl new file mode 100644 index 00000000000000..ede168395acdc5 --- /dev/null +++ b/third_party/gloo/workspace.bzl @@ -0,0 +1,17 @@ +"""Provides the repository macro to import Gloo.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + """Imports Gloo.""" + + GLOO_COMMIT = "5354032ea08eadd7fc4456477f7f7c6308818509" + GLOO_SHA256 = "5759a06e6c8863c58e8ceadeb56f7c701fec89b2559ba33a103a447207bf69c7" + + tf_http_archive( + name = "gloo", + sha256 = GLOO_SHA256, + strip_prefix = "gloo-{commit}".format(commit = GLOO_COMMIT), + urls = tf_mirror_urls("https://github.com/facebookincubator/gloo/archive/{commit}.tar.gz".format(commit = GLOO_COMMIT)), + build_file = "//third_party/gloo:gloo.BUILD", + ) diff --git a/third_party/gpus/check_cuda_libs.py b/third_party/gpus/check_cuda_libs.py index b7d98ef2581157..afd6380b0ac203 100644 --- a/third_party/gpus/check_cuda_libs.py +++ b/third_party/gpus/check_cuda_libs.py @@ -23,6 +23,7 @@ """ import os import os.path +import platform import subprocess import sys @@ -38,6 +39,10 @@ class ConfigError(Exception): pass +def _is_windows(): + return platform.system() == "Windows" + + def check_cuda_lib(path, check_soname=True): """Tests if a library exists on disk and whether its soname matches the filename. @@ -52,7 +57,7 @@ def check_cuda_lib(path, check_soname=True): if not os.path.isfile(path): raise ConfigError("No library found under: " + path) objdump = which("objdump") - if check_soname and objdump is not None: + if check_soname and objdump is not None and not _is_windows(): # Decode is necessary as in py3 the return type changed from str to bytes output = subprocess.check_output([objdump, "-p", path]).decode("utf-8") output = [line for line in output.splitlines() if "SONAME" in line] diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl index 81e54ad431fccf..0da1d7b58f4bb0 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl @@ -45,10 +45,11 @@ import pipes # Template values set by cuda_autoconf. CPU_COMPILER = ('%{cpu_compiler}') -GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}') +HOST_COMPILER_PATH = ('%{host_compiler_path}') NVCC_PATH = '%{nvcc_path}' -PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH) +PREFIX_DIR = os.path.dirname(HOST_COMPILER_PATH) +USE_CLANG_COMPILER = '%{use_clang_compiler}' NVCC_VERSION = '%{cuda_version}' def Log(s): @@ -253,13 +254,23 @@ def InvokeNvcc(argv, log=False): # Force C++17 dialect (note, everything in just one string!) nvccopts += ' --std c++17 ' nvccopts += fatbin_options + # The option `-allow-unsupported-compiler` is required for the combination of + # NVCC+clang compilers. + # The following message appears if this option is not provided: + # unsupported clang version! clang version must be less than 16 and greater + # than 3.2 . The nvcc flag '-allow-unsupported-compiler' can be used + # to override this version check; however, using an unsupported host compiler + # may cause compilation failure or incorrect run time execution. + # Use at your own risk. 
+ if USE_CLANG_COMPILER: + nvccopts += ' -allow-unsupported-compiler --expt-extended-lambda --expt-relaxed-constexpr ' if depfiles: # Generate the dependency file depfile = depfiles[0] cmd = (NVCC_PATH + ' ' + nvccopts + ' --compiler-options "' + host_compiler_options + '"' + - ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH + + ' --compiler-bindir=' + HOST_COMPILER_PATH + ' -I .' + ' -x cu ' + opt + includes + ' ' + srcs + ' -M -o ' + depfile) if log: Log(cmd) @@ -269,7 +280,7 @@ def InvokeNvcc(argv, log=False): cmd = (NVCC_PATH + ' ' + nvccopts + ' --compiler-options "' + host_compiler_options + ' -fPIC"' + - ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH + + ' --compiler-bindir=' + HOST_COMPILER_PATH + ' -I .' + ' -x cu ' + opt + includes + ' -c ' + srcs + out) diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index 8fb22313010a45..77ec948af32c6e 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -86,8 +86,8 @@ def GetHostCompilerOptions(argv): opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, [])) if args.g: opts += ' -g' + ' -g'.join(sum(args.g, [])) - #if args.fno_canonical_system_headers: - # opts += ' -fno-canonical-system-headers' + if args.fno_canonical_system_headers: + opts += ' -no-canonical-prefixes' if args.sysroot: opts += ' --sysroot ' + args.sysroot[0] diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl new file mode 100644 index 00000000000000..c46e09484fdfad --- /dev/null +++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl @@ -0,0 +1,256 @@ +#!/usr/bin/env python +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Crosstool wrapper for compiling CUDA programs with nvcc on Windows. + +DESCRIPTION: + This script is the Windows version of //third_party/gpus/crosstool/crosstool_wrapper_is_not_gcc +""" + +from argparse import ArgumentParser +import os +import subprocess +import re +import sys +import tempfile + +# Template values set by cuda_autoconf. +CPU_COMPILER = ('%{cpu_compiler}') +GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}') + +NVCC_PATH = '%{nvcc_path}' +NVCC_VERSION = '%{cuda_version}' +NVCC_TEMP_DIR = "%{nvcc_tmp_dir}" + +def Log(s): + print('gpus/crosstool: {0}'.format(s)) + + +def GetOptionValue(argv, option): + """Extract the list of values for option from options. + + Args: + option: The option whose value to extract. + + Returns: + 1. A list of values, either directly following the option, + (eg., /opt val1 val2) or values collected from multiple occurrences of + the option (eg., /opt val1 /opt val2). + 2. The leftover options. 
+ """ + + parser = ArgumentParser(prefix_chars='-/') + parser.add_argument(option, nargs='*', action='append') + option = option.lstrip('-/').replace('-', '_') + args, leftover = parser.parse_known_args(argv) + if args and vars(args)[option]: + return (sum(vars(args)[option], []), leftover) + return ([], leftover) + +def _update_options(nvcc_options): + if NVCC_VERSION in ("7.0",): + return nvcc_options + + update_options = { "relaxed-constexpr" : "expt-relaxed-constexpr" } + return [ update_options[opt] if opt in update_options else opt + for opt in nvcc_options ] + +def GetNvccOptions(argv): + """Collect the -nvcc_options values from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + + Returns: + 1. The string that can be passed directly to nvcc. + 2. The leftover options. + """ + + parser = ArgumentParser() + parser.add_argument('-nvcc_options', nargs='*', action='append') + + args, leftover = parser.parse_known_args(argv) + + if args.nvcc_options: + options = _update_options(sum(args.nvcc_options, [])) + return (['--' + a for a in options], leftover) + return ([], leftover) + + +def InvokeNvcc(argv, log=False): + """Call nvcc with arguments assembled from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + log: True if logging is requested. + + Returns: + The return value of calling os.system('nvcc ' + args) + """ + + src_files = [f for f in argv if + re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] + if len(src_files) == 0: + raise Error('No source files found for cuda compilation.') + + out_file = [ f for f in argv if f.startswith('/Fo') ] + if len(out_file) != 1: + raise Error('Please specify exactly one output file for cuda compilation.') + out = ['-o', out_file[0][len('/Fo'):]] + + nvcc_compiler_options, argv = GetNvccOptions(argv) + + opt_option, argv = GetOptionValue(argv, '/O') + opt = ['-g'] + if (len(opt_option) > 0 and opt_option[0] != 'd'): + opt = ['-O2'] + + include_options, argv = GetOptionValue(argv, '/I') + includes = ["-I " + include for include in include_options] + + defines, argv = GetOptionValue(argv, '/D') + defines = [ + '-D' + define + for define in defines + if 'BAZEL_CURRENT_REPOSITORY' not in define + ] + + undefines, argv = GetOptionValue(argv, '/U') + undefines = ['-U' + define for define in undefines] + + fatbin_options, argv = GetOptionValue(argv, '-Xcuda-fatbinary') + fatbin_options = ['--fatbin-options=' + option for option in fatbin_options] + + # The rest of the unrecognized options should be passed to host compiler + host_compiler_options = [option for option in argv if option not in (src_files + out_file)] + + m_options = ["-m64"] + + nvccopts = ['-D_FORCE_INLINES'] + compute_capabilities, argv = GetOptionValue(argv, "--cuda-gpu-arch") + for capability in compute_capabilities: + capability = capability[len('sm_'):] + nvccopts += [ + r'-gencode=arch=compute_%s,"code=sm_%s"' % (capability, capability) + ] + compute_capabilities, argv = GetOptionValue(argv, '--cuda-include-ptx') + for capability in compute_capabilities: + capability = capability[len('sm_'):] + nvccopts += [ + r'-gencode=arch=compute_%s,"code=compute_%s"' % (capability, capability) + ] + _, argv = GetOptionValue(argv, '--no-cuda-include-ptx') + + # nvcc doesn't respect the INCLUDE and LIB env vars from MSVC, + # so we explicitly specify the system include paths and library search paths.
+ if 'INCLUDE' in os.environ: + nvccopts += [('--system-include="%s"' % p) for p in os.environ['INCLUDE'].split(";")] + if 'LIB' in os.environ: + nvccopts += [('--library-path="%s"' % p) for p in os.environ['LIB'].split(";")] + + nvccopts += nvcc_compiler_options + nvccopts += undefines + nvccopts += defines + nvccopts += m_options + nvccopts += fatbin_options + nvccopts += ['--compiler-options=' + ",".join(host_compiler_options)] + nvccopts += ['-x', 'cu'] + opt + includes + out + ['-c'] + src_files + # Specify a unique temp directory for nvcc to generate intermediate files, + # then Bazel can ignore files under NVCC_TEMP_DIR during dependency check + # http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver + # Different actions are sharing NVCC_TEMP_DIR, so we cannot remove it if the directory already exists. + if os.path.isfile(NVCC_TEMP_DIR): + os.remove(NVCC_TEMP_DIR) + if not os.path.exists(NVCC_TEMP_DIR): + os.makedirs(NVCC_TEMP_DIR) + # Provide a unique dir for each compiling action to avoid conflicts. + tempdir = tempfile.mkdtemp(dir = NVCC_TEMP_DIR) + nvccopts += ['--keep', '--keep-dir', tempdir] + # Force C++17 dialect (note, everything in just one string!) + nvccopts += ['--std c++17'] + if log: + Log([NVCC_PATH] + nvccopts) + + # Store command line options in a file to avoid hitting the character limit. + optsfile = tempfile.NamedTemporaryFile(mode='w', dir=tempdir, delete=False) + optsfile.write("\n".join(nvccopts)) + optsfile.close() + + proc = subprocess.Popen([NVCC_PATH, "--options-file", optsfile.name], + stdout=sys.stdout, + stderr=sys.stderr, + env=os.environ.copy(), + shell=True) + proc.wait() + return proc.returncode + +def ExpandParamsFileForArgv(): + new_argv = [] + for arg in sys.argv: + if arg.startswith("@"): + with open(arg.strip("@")) as f: + new_argv.extend([l.strip() for l in f.readlines()]) + else: + new_argv.append(arg) + + sys.argv = new_argv + +def ProcessFlagForCommandFile(flag): + if flag.startswith("/D") or flag.startswith("-D"): + # We need to re-escape /DFOO="BAR" as /DFOO=\"BAR\", so that we get + # `#define FOO "BAR"` after expansion as a string literal define + if flag.endswith('"') and not flag.endswith('\\"'): + flag = '\\"'.join(flag.split('"', 1)) + flag = '\\"'.join(flag.rsplit('"', 1)) + return flag + return flag + +def main(): + ExpandParamsFileForArgv() + parser = ArgumentParser() + parser.add_argument('-x', nargs=1) + parser.add_argument('--cuda_log', action='store_true') + args, leftover = parser.parse_known_args(sys.argv[1:]) + + if args.x and args.x[0] == 'cuda': + if args.cuda_log: Log('-x cuda') + if args.cuda_log: Log('using nvcc') + return InvokeNvcc(leftover, log=args.cuda_log) + + # Strip our flags before passing through to the CPU compiler for files which + # are not -x cuda. We can't just pass 'leftover' because it also strips -x. + # We not only want to pass -x to the CPU compiler, but also keep it in its + # relative location in the argv list (the compiler is actually sensitive to + # this). + cpu_compiler_flags = [flag for flag in sys.argv[1:] + if not flag.startswith(('--cuda_log')) + and not flag.startswith(('-nvcc_options'))] + output = [flag for flag in cpu_compiler_flags if flag.startswith("/Fo")] + + # Store command line options in a file to avoid hitting the character limit. 
+ if len(output) == 1: + commandfile_path = output[0][3:] + ".msvc_params" + commandfile = open(commandfile_path, "w") + cpu_compiler_flags = [ProcessFlagForCommandFile(flag) for flag in cpu_compiler_flags] + commandfile.write("\n".join(cpu_compiler_flags)) + commandfile.close() + return subprocess.call([CPU_COMPILER, "@" + commandfile_path]) + else: + return subprocess.call([CPU_COMPILER] + cpu_compiler_flags) + return subprocess.call([CPU_COMPILER] + cpu_compiler_flags) + +if __name__ == '__main__': + sys.exit(main()) diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index 700e040a88eeca..90a18b90de048c 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -61,23 +61,23 @@ cuda_header_library( cc_library( name = "cudart_static", - srcs = ["cuda/lib/libcudart_static.a"], + srcs = ["cuda/lib/%{cudart_static_lib}"], linkopts = [ "-ldl", - "-lrt", "-lpthread", + %{cudart_static_linkopt} ], ) cc_library( name = "cuda_driver", - srcs = ["cuda/lib/libcuda.so"], + srcs = ["cuda/lib/%{cuda_driver_lib}"], ) cc_library( name = "cudart", - srcs = glob(["cuda/lib/libcudart.so.*"]), - data = glob(["cuda/lib/libcudart.so.*"]), + srcs = ["cuda/lib/%{cudart_lib}"], + data = ["cuda/lib/%{cudart_lib}"], linkstatic = 1, ) @@ -128,30 +128,30 @@ cuda_header_library( cc_library( name = "cublas", - srcs = glob(["cuda/lib/libcublas.so.*"]), - data = glob(["cuda/lib/libcublas.so.*"]), + srcs = ["cuda/lib/%{cublas_lib}"], + data = ["cuda/lib/%{cublas_lib}"], linkstatic = 1, ) cc_library( name = "cublasLt", - srcs = glob(["cuda/lib/libcublasLt.so.*"]), - data = glob(["cuda/lib/libcublasLt.so.*"]), + srcs = ["cuda/lib/%{cublasLt_lib}"], + data = ["cuda/lib/%{cublasLt_lib}"], linkstatic = 1, ) cc_library( name = "cusolver", - srcs = glob(["cuda/lib/libcusolver.so.*"]), - data = glob(["cuda/lib/libcusolver.so.*"]), + srcs = ["cuda/lib/%{cusolver_lib}"], + data = ["cuda/lib/%{cusolver_lib}"], linkopts = ["-lgomp"], linkstatic = 1, ) cc_library( name = "cudnn", - srcs = glob(["cuda/lib/libcudnn.so.*"]), - data = glob(["cuda/lib/libcudnn.so.*"]), + srcs = ["cuda/lib/%{cudnn_lib}"], + data = ["cuda/lib/%{cudnn_lib}"], linkstatic = 1, ) @@ -165,15 +165,15 @@ cc_library( cc_library( name = "cufft", - srcs = glob(["cuda/lib/libcufft.so.*"]), - data = glob(["cuda/lib/libcufft.so.*"]), + srcs = ["cuda/lib/%{cufft_lib}"], + data = ["cuda/lib/%{cufft_lib}"], linkstatic = 1, ) cc_library( name = "curand", - srcs = glob(["cuda/lib/libcurand.so.*"]), - data = glob(["cuda/lib/libcurand.so.*"]), + srcs = ["cuda/lib/%{curand_lib}"], + data = ["cuda/lib/%{curand_lib}"], linkstatic = 1, ) @@ -192,7 +192,7 @@ cc_library( alias( name = "cub_headers", - actual = ":cuda_headers", + actual = "%{cub_actual}", ) cuda_header_library( @@ -213,13 +213,13 @@ cuda_header_library( cc_library( name = "cupti_dsos", - data = glob(["cuda/lib/libcupti.so.*"]), + data = ["cuda/lib/%{cupti_lib}"], ) cc_library( name = "cusparse", - srcs = glob(["cuda/lib/libcusparse.so.*"]), - data = glob(["cuda/lib/libcusparse.so.*"]), + srcs = ["cuda/lib/%{cusparse_lib}"], + data = ["cuda/lib/%{cusparse_lib}"], linkopts = ["-lgomp"], linkstatic = 1, ) diff --git a/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/gpus/cuda/BUILD.windows.tpl new file mode 100644 index 00000000000000..dee0e898d9ae7a --- /dev/null +++ b/third_party/gpus/cuda/BUILD.windows.tpl @@ -0,0 +1,238 @@ +load(":build_defs.bzl", "cuda_header_library") +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") 
+load("@bazel_skylib//lib:selects.bzl", "selects") + +licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +# Config setting whether TensorFlow is built with CUDA support using clang. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_clang. +selects.config_setting_group( + name = "using_clang", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_clang", + ], +) + +# Config setting whether TensorFlow is built with CUDA support using nvcc. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_nvcc. +selects.config_setting_group( + name = "using_nvcc", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_nvcc", + ], +) + +# Equivalent to using_clang && -c opt. +selects.config_setting_group( + name = "using_clang_opt", + match_all = [ + ":using_clang", + ":_opt", + ], +) + +config_setting( + name = "_opt", + values = {"compilation_mode": "opt"}, +) + +# Provides CUDA headers for '#include "third_party/gpus/cuda/include/cuda.h"' +# All clients including TensorFlow should use these directives. +cuda_header_library( + name = "cuda_headers", + hdrs = [ + "cuda/cuda_config.h", + ":cuda-include", + ], + include_prefix = "third_party/gpus", + includes = [ + ".", # required to include cuda/cuda/cuda_config.h as cuda/config.h + "cuda/include", + ], +) + +cc_import( + name = "cudart_static", + # /WHOLEARCHIVE:cudart_static.lib will cause a + # "Internal error during CImplib::EmitThunk" error. + # Treat this library as interface library to avoid being whole archived when + # linking a DLL that depends on this. + # TODO(pcloudy): Remove this rule after b/111278841 is resolved. 
+ interface_library = "cuda/lib/%{cudart_static_lib}", + system_provided = 1, +) + +cc_import( + name = "cuda_driver", + interface_library = "cuda/lib/%{cuda_driver_lib}", + system_provided = 1, +) + +cc_import( + name = "cudart", + interface_library = "cuda/lib/%{cudart_lib}", + system_provided = 1, +) + +cuda_header_library( + name = "cublas_headers", + hdrs = [":cublas-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cublas/include"], + strip_include_prefix = "cublas/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "cusolver_headers", + hdrs = [":cusolver-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cusolver/include"], + strip_include_prefix = "cusolver/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "cufft_headers", + hdrs = [":cufft-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cufft/include"], + strip_include_prefix = "cufft/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "cusparse_headers", + hdrs = [":cusparse-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cusparse/include"], + strip_include_prefix = "cusparse/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "curand_headers", + hdrs = [":curand-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["curand/include"], + strip_include_prefix = "curand/include", + deps = [":cuda_headers"], +) + +cc_import( + name = "cublas", + interface_library = "cuda/lib/%{cublas_lib}", + system_provided = 1, +) + +cc_import( + name = "cublasLt", + interface_library = "cuda/lib/%{cublasLt_lib}", + system_provided = 1, +) + +cc_import( + name = "cusolver", + interface_library = "cuda/lib/%{cusolver_lib}", + system_provided = 1, +) + +cc_import( + name = "cudnn", + interface_library = "cuda/lib/%{cudnn_lib}", + system_provided = 1, +) + +cc_library( + name = "cudnn_header", + hdrs = [":cudnn-include"], + include_prefix = "third_party/gpus/cudnn", + strip_include_prefix = "cudnn/include", + deps = [":cuda_headers"], +) + +cc_import( + name = "cufft", + interface_library = "cuda/lib/%{cufft_lib}", + system_provided = 1, +) + +cc_import( + name = "curand", + interface_library = "cuda/lib/%{curand_lib}", + system_provided = 1, +) + +cc_library( + name = "cuda", + deps = [ + ":cublas", + ":cublasLt", + ":cuda_headers", + ":cudart", + ":cudnn", + ":cufft", + ":curand", + ], +) + +alias( + name = "cub_headers", + actual = "%{cub_actual}", +) + +cuda_header_library( + name = "cupti_headers", + hdrs = [":cuda-extras"], + include_prefix = "third_party/gpus", + includes = ["cuda/extras/CUPTI/include/"], + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "nvml_headers", + hdrs = [":nvml"], + include_prefix = "third_party/gpus", + includes = ["cuda/nvml/include/"], + deps = [":cuda_headers"], +) + +cc_import( + name = "cupti_dsos", + interface_library = "cuda/lib/%{cupti_lib}", + system_provided = 1, +) + +cc_import( + name = "cusparse", + interface_library = "cuda/lib/%{cusparse_lib}", + system_provided = 1, +) + +cc_library( + name = "libdevice_root", + data = [":cuda-nvvm"], +) + +bzl_library( + name = "build_defs_bzl", + srcs = ["build_defs.bzl"], + deps = [ + "@bazel_skylib//lib:selects", + ], +) + +py_library( + name = "cuda_config_py", + srcs = ["cuda/cuda_config.py"], +) + +%{copy_rules} diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 
3f2b67632e1a67..e73e41a0c383a2 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -4,7 +4,8 @@ * `TF_NEED_CUDA`: Whether to enable building with CUDA. * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path - * `TF_CUDA_CLANG`: Whether to use clang as a cuda compiler. + * `TF_CUDA_CLANG`: Whether to use clang for C++ and Cuda compilation. + * `TF_NVCC_CLANG`: Whether to use clang for C++ and NVCC for Cuda compilation. * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for both host and device code compilation if TF_CUDA_CLANG is 1. * `TF_SYSROOT`: The sysroot to use when compiling. @@ -26,14 +27,27 @@ """ load("//third_party/clang_toolchain:download_clang.bzl", "download_clang") +load( + "@bazel_tools//tools/cpp:lib_cc_configure.bzl", + "escape_string", + "get_env_var", +) +load( + "@bazel_tools//tools/cpp:windows_cc_configure.bzl", + "find_msvc_tool", + "find_vc_path", + "setup_vc_env_vars", +) load( "//third_party/remote_config:common.bzl", "config_repo_label", "err_out", "execute", "get_bash_bin", + "get_cpu_value", "get_host_environ", "get_python_bin", + "is_windows", "raw_exec", "read_dir", "realpath", @@ -82,7 +96,16 @@ def verify_build_defines(params): "host_compiler_warnings", "linker_bin_path", "compiler_deps", + "msvc_cl_path", + "msvc_env_include", + "msvc_env_lib", + "msvc_env_path", + "msvc_env_tmp", + "msvc_lib_path", + "msvc_link_path", + "msvc_ml_path", "unfiltered_compile_flags", + "win_compiler_deps", ]: if ("%{" + param + "}") not in params: missing.append(param) @@ -96,13 +119,104 @@ def verify_build_defines(params): ".", ) +def _get_nvcc_tmp_dir_for_windows(repository_ctx): + """Return the Windows tmp directory for nvcc to generate intermediate source files.""" + escaped_tmp_dir = escape_string( + get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace( + "\\", + "\\\\", + ), + ) + return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir" + +def _get_msvc_compiler(repository_ctx): + vc_path = find_vc_path(repository_ctx) + return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/") + +def _get_win_cuda_defines(repository_ctx): + """Return CROSSTOOL defines for Windows""" + + # If we are not on Windows, return fake values for Windows specific fields. + # This ensures the CROSSTOOL file parser is happy. + if not is_windows(repository_ctx): + return { + "%{msvc_env_tmp}": "msvc_not_used", + "%{msvc_env_path}": "msvc_not_used", + "%{msvc_env_include}": "msvc_not_used", + "%{msvc_env_lib}": "msvc_not_used", + "%{msvc_cl_path}": "msvc_not_used", + "%{msvc_ml_path}": "msvc_not_used", + "%{msvc_link_path}": "msvc_not_used", + "%{msvc_lib_path}": "msvc_not_used", + } + + vc_path = find_vc_path(repository_ctx) + if not vc_path: + auto_configure_fail( + "Visual C++ build tools not found on your machine."
+ + "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using", + ) + return {} + + env = setup_vc_env_vars(repository_ctx, vc_path) + escaped_paths = escape_string(env["PATH"]) + escaped_include_paths = escape_string(env["INCLUDE"]) + escaped_lib_paths = escape_string(env["LIB"]) + escaped_tmp_dir = escape_string( + get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace( + "\\", + "\\\\", + ), + ) + + msvc_cl_path = "windows/msvc_wrapper_for_nvcc.bat" + msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace( + "\\", + "/", + ) + msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace( + "\\", + "/", + ) + msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace( + "\\", + "/", + ) + + # nvcc will generate some temporary source files under %{nvcc_tmp_dir} + # The generated files are guaranteed to have unique name, so they can share + # the same tmp directory + escaped_cxx_include_directories = [ + _get_nvcc_tmp_dir_for_windows(repository_ctx), + "C:\\\\botcode\\\\w", + ] + for path in escaped_include_paths.split(";"): + if path: + escaped_cxx_include_directories.append(path) + + return { + "%{msvc_env_tmp}": escaped_tmp_dir, + "%{msvc_env_path}": escaped_paths, + "%{msvc_env_include}": escaped_include_paths, + "%{msvc_env_lib}": escaped_lib_paths, + "%{msvc_cl_path}": msvc_cl_path, + "%{msvc_ml_path}": msvc_ml_path, + "%{msvc_link_path}": msvc_link_path, + "%{msvc_lib_path}": msvc_lib_path, + "%{cxx_builtin_include_directories}": to_list_of_strings( + escaped_cxx_include_directories, + ), + } + # TODO(dzc): Once these functions have been factored out of Bazel's # cc_configure.bzl, load them from @bazel_tools instead. # BEGIN cc_configure common functions. -def find_cc(repository_ctx): +def find_cc(repository_ctx, use_cuda_clang): """Find the C++ compiler.""" + if is_windows(repository_ctx): + return _get_msvc_compiler(repository_ctx) - if _use_cuda_clang(repository_ctx): + if use_cuda_clang: target_cc_name = "clang" cc_path_envvar = _CLANG_CUDA_COMPILER_PATH if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG): @@ -251,9 +365,10 @@ def _cuda_include_path(repository_ctx, cuda_config): Returns: A list of the gcc host compiler include directories. """ - nvcc_path = repository_ctx.path( - "%s/bin/nvcc" % cuda_config.cuda_toolkit_path, - ) + nvcc_path = repository_ctx.path("%s/bin/nvcc%s" % ( + cuda_config.cuda_toolkit_path, + ".exe" if cuda_config.cpu_value == "Windows" else "", + )) # The expected exit code of this command is non-zero. Bazel remote execution # only caches commands with zero exit code. So force a zero exit code. @@ -314,6 +429,10 @@ def matches_version(environ_version, detected_version): return False return True +_NVCC_VERSION_PREFIX = "Cuda compilation tools, release " + +_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR" + def compute_capabilities(repository_ctx): """Returns a list of strings representing cuda compute capabilities. @@ -356,11 +475,12 @@ def compute_capabilities(repository_ctx): return capabilities -def lib_name(base_name, version = None, static = False): +def lib_name(base_name, cpu_value, version = None, static = False): """Constructs the platform-specific name of a library. Args: base_name: The name of the library, such as "cudart" + cpu_value: The name of the host operating system. version: The version of the library. static: True the library is static or False if it is a shared object. 
@@ -368,20 +488,29 @@ def lib_name(base_name, version = None, static = False): The platform-specific name of the library. """ version = "" if not version else "." + version - if static: - return "lib%s.a" % base_name - return "lib%s.so%s" % (base_name, version) + if cpu_value in ("Linux", "FreeBSD"): + if static: + return "lib%s.a" % base_name + return "lib%s.so%s" % (base_name, version) + elif cpu_value == "Windows": + return "%s.lib" % base_name + elif cpu_value == "Darwin": + if static: + return "lib%s.a" % base_name + return "lib%s%s.dylib" % (base_name, version) + else: + auto_configure_fail("Invalid cpu_value: %s" % cpu_value) -def _lib_path(lib, basedir, version, static): - file_name = lib_name(lib, version, static) +def _lib_path(lib, cpu_value, basedir, version, static): + file_name = lib_name(lib, cpu_value, version, static) return "%s/%s" % (basedir, file_name) def _should_check_soname(version, static): return version and not static -def _check_cuda_lib_params(lib, basedir, version, static = False): +def _check_cuda_lib_params(lib, cpu_value, basedir, version, static = False): return ( - _lib_path(lib, basedir, version, static), + _lib_path(lib, cpu_value, basedir, version, static), _should_check_soname(version, static), ) @@ -401,6 +530,8 @@ def _check_cuda_libs(repository_ctx, script_path, libs): all_paths = [path for path, _ in libs] checked_paths = execute(repository_ctx, [python_bin, "-c", cmd]).stdout.splitlines() + # Filter out empty lines from splitting on '\r\n' on Windows + checked_paths = [path for path in checked_paths if len(path) > 0] if all_paths != checked_paths: auto_configure_fail("Error with installed CUDA libs. Expected '%s'. Actual '%s'." % (all_paths, checked_paths)) @@ -418,62 +549,86 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config): Returns: Map of library names to structs of filename and path. 
""" + cpu_value = cuda_config.cpu_value + stub_dir = "" if is_windows(repository_ctx) else "/stubs" + check_cuda_libs_params = { "cuda": _check_cuda_lib_params( "cuda", - cuda_config.config["cuda_library_dir"] + "/stubs", + cpu_value, + cuda_config.config["cuda_library_dir"] + stub_dir, version = None, + static = False, ), "cudart": _check_cuda_lib_params( "cudart", + cpu_value, cuda_config.config["cuda_library_dir"], cuda_config.cudart_version, + static = False, ), "cudart_static": _check_cuda_lib_params( "cudart_static", + cpu_value, cuda_config.config["cuda_library_dir"], cuda_config.cudart_version, static = True, ), "cublas": _check_cuda_lib_params( "cublas", + cpu_value, cuda_config.config["cublas_library_dir"], cuda_config.cublas_version, + static = False, ), "cublasLt": _check_cuda_lib_params( "cublasLt", + cpu_value, cuda_config.config["cublas_library_dir"], cuda_config.cublas_version, + static = False, ), "cusolver": _check_cuda_lib_params( "cusolver", + cpu_value, cuda_config.config["cusolver_library_dir"], cuda_config.cusolver_version, + static = False, ), "curand": _check_cuda_lib_params( "curand", + cpu_value, cuda_config.config["curand_library_dir"], cuda_config.curand_version, + static = False, ), "cufft": _check_cuda_lib_params( "cufft", + cpu_value, cuda_config.config["cufft_library_dir"], cuda_config.cufft_version, + static = False, ), "cudnn": _check_cuda_lib_params( "cudnn", + cpu_value, cuda_config.config["cudnn_library_dir"], cuda_config.cudnn_version, + static = False, ), "cupti": _check_cuda_lib_params( "cupti", + cpu_value, cuda_config.config["cupti_library_dir"], cuda_config.cupti_version, + static = False, ), "cusparse": _check_cuda_lib_params( "cusparse", + cpu_value, cuda_config.config["cusparse_library_dir"], cuda_config.cusparse_version, + static = False, ), } @@ -483,6 +638,10 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config): paths = {filename: v[0] for (filename, v) in check_cuda_libs_params.items()} return paths +def _cudart_static_linkopt(cpu_value): + """Returns additional platform-specific linkopts for cudart.""" + return "" if cpu_value == "Darwin" else "\"-lrt\"," + # TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl, # and nccl_configure.bzl. def find_cuda_config(repository_ctx, cuda_libraries): @@ -509,34 +668,37 @@ def _get_cuda_config(repository_ctx): cudart_version: The CUDA runtime version on the system. cudnn_version: The version of cuDNN on the system. compute_capabilities: A list of the system's CUDA compute capabilities. + cpu_value: The name of the host operating system. """ config = find_cuda_config(repository_ctx, ["cuda", "cudnn"]) + cpu_value = get_cpu_value(repository_ctx) toolkit_path = config["cuda_toolkit_path"] + is_windows = cpu_value == "Windows" cuda_version = config["cuda_version"].split(".") cuda_major = cuda_version[0] cuda_minor = cuda_version[1] - cuda_version = "%s.%s" % (cuda_major, cuda_minor) - cudnn_version = "%s" % config["cudnn_version"] + cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor) + cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"] if int(cuda_major) >= 11: # The libcudart soname in CUDA 11.x is versioned as 11.0 for backward compatability. 
if int(cuda_major) == 11: - cudart_version = "11.0" + cudart_version = "64_110" if is_windows else "11.0" cupti_version = cuda_version else: - cudart_version = "%s" % cuda_major + cudart_version = ("64_%s" if is_windows else "%s") % cuda_major cupti_version = cudart_version - cublas_version = "%s" % config["cublas_version"].split(".")[0] - cusolver_version = "%s" % config["cusolver_version"].split(".")[0] - curand_version = "%s" % config["curand_version"].split(".")[0] - cufft_version = "%s" % config["cufft_version"].split(".")[0] - cusparse_version = "%s" % config["cusparse_version"].split(".")[0] + cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0] + cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0] + curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0] + cufft_version = ("64_%s" if is_windows else "%s") % config["cufft_version"].split(".")[0] + cusparse_version = ("64_%s" if is_windows else "%s") % config["cusparse_version"].split(".")[0] elif (int(cuda_major), int(cuda_minor)) >= (10, 1): # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc. # It changed from 'x.y' to just 'x' in CUDA 10.1. - cuda_lib_version = "%s" % cuda_major + cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major cudart_version = cuda_version cupti_version = cuda_version cublas_version = cuda_lib_version @@ -566,6 +728,7 @@ def _get_cuda_config(repository_ctx): cusparse_version = cusparse_version, cudnn_version = cudnn_version, compute_capabilities = compute_capabilities(repository_ctx), + cpu_value = cpu_value, config = config, ) @@ -611,6 +774,8 @@ error_gpu_disabled() """ def _create_dummy_repository(repository_ctx): + cpu_value = get_cpu_value(repository_ctx) + # Set up BUILD file for cuda/. 
_tpl( repository_ctx, @@ -625,6 +790,23 @@ def _create_dummy_repository(repository_ctx): repository_ctx, "cuda:BUILD", { + "%{cuda_driver_lib}": lib_name("cuda", cpu_value), + "%{cudart_static_lib}": lib_name( + "cudart_static", + cpu_value, + static = True, + ), + "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), + "%{cudart_lib}": lib_name("cudart", cpu_value), + "%{cublas_lib}": lib_name("cublas", cpu_value), + "%{cublasLt_lib}": lib_name("cublasLt", cpu_value), + "%{cusolver_lib}": lib_name("cusolver", cpu_value), + "%{cudnn_lib}": lib_name("cudnn", cpu_value), + "%{cufft_lib}": lib_name("cufft", cpu_value), + "%{curand_lib}": lib_name("curand", cpu_value), + "%{cupti_lib}": lib_name("cupti", cpu_value), + "%{cusparse_lib}": lib_name("cusparse", cpu_value), + "%{cub_actual}": ":cuda_headers", "%{copy_rules}": """ filegroup(name="cuda-include") filegroup(name="cublas-include") @@ -643,9 +825,20 @@ filegroup(name="cudnn-include") repository_ctx.file("cuda/cuda/include/cublas.h") repository_ctx.file("cuda/cuda/include/cudnn.h") repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h") - repository_ctx.file("cuda/cuda/lib/libcuda.so") - repository_ctx.file("cuda/cuda/lib/libcudart_static.a") repository_ctx.file("cuda/cuda/nvml/include/nvml.h") + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cuda", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudart", cpu_value)) + repository_ctx.file( + "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value), + ) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cufft", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cupti", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusparse", cpu_value)) # Set up cuda_config.h, which is used by # tensorflow/compiler/xla/stream_executor/dso_loader.cc. @@ -709,7 +902,7 @@ def make_copy_files_rule(repository_ctx, name, srcs, outs): cmd = \"""%s \""", )""" % (name, "\n".join(outs), " && \\\n".join(cmds)) -def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir): +def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir, exceptions = None): """Returns a rule to recursively copy a directory. If exceptions is not None, it must be a list of files or directories in 'src_dir'; these will be excluded from copying. @@ -717,25 +910,39 @@ def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir): src_dir = _norm_path(src_dir) out_dir = _norm_path(out_dir) outs = read_dir(repository_ctx, src_dir) + post_cmd = "" + if exceptions != None: + outs = [x for x in outs if not any([ + x.startswith(src_dir + "/" + y) + for y in exceptions + ])] outs = [(' "%s",' % out.replace(src_dir, out_dir)) for out in outs] # '@D' already contains the relative path for a single file, see # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)" + if exceptions != None: + for x in exceptions: + post_cmd += " ; rm -fR " + out_dir + "/" + x return """genrule( name = "%s", outs = [ %s ], - cmd = \"""cp -rLf "%s/." 
"%s/" \""", -)""" % (name, "\n".join(outs), src_dir, out_dir) + cmd = \"""cp -rLf "%s/." "%s/" %s\""", +)""" % (name, "\n".join(outs), src_dir, out_dir, post_cmd) def _flag_enabled(repository_ctx, flag_name): return get_host_environ(repository_ctx, flag_name) == "1" def _use_cuda_clang(repository_ctx): + # Returns the flag if we need to use clang both for C++ and Cuda. return _flag_enabled(repository_ctx, "TF_CUDA_CLANG") +def _use_nvcc_and_clang(repository_ctx): + # Returns the flag if we need to use clang for C++ and NVCC for Cuda. + return _flag_enabled(repository_ctx, "TF_NVCC_CLANG") + def _tf_sysroot(repository_ctx): return get_host_environ(repository_ctx, _TF_SYSROOT, "") @@ -752,6 +959,22 @@ def _compute_cuda_extra_copts(repository_ctx, compute_capabilities): def _tpl_path(repository_ctx, filename): return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % filename)) +def _basename(repository_ctx, path_str): + """Returns the basename of a path of type string. + + This method is different from path.basename in that it also works if + the host platform is different from the execution platform + i.e. linux -> windows. + """ + + num_chars = len(path_str) + is_win = is_windows(repository_ctx) + for i in range(num_chars): + r_i = num_chars - 1 - i + if (is_win and path_str[r_i] == "\\") or path_str[r_i] == "/": + return path_str[r_i + 1:] + return path_str + def _create_local_cuda_repository(repository_ctx): """Creates the repository containing files set up to build with CUDA.""" @@ -760,14 +983,15 @@ def _create_local_cuda_repository(repository_ctx): # can easily lead to a O(n^2) runtime in the number of labels. # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778 tpl_paths = {filename: _tpl_path(repository_ctx, filename) for filename in [ - "cuda:BUILD", "cuda:build_defs.bzl", "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc", + "crosstool:windows/msvc_wrapper_for_nvcc.py", "crosstool:BUILD", "crosstool:cc_toolchain_config.bzl", "cuda:cuda_config.h", "cuda:cuda_config.py", ]} + tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD") cuda_config = _get_cuda_config(repository_ctx) @@ -879,7 +1103,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_lib_outs = [] for path in cuda_libs.values(): cuda_lib_srcs.append(path) - cuda_lib_outs.append("cuda/lib/" + path.rpartition("/")[-1]) + cuda_lib_outs.append("cuda/lib/" + _basename(repository_ctx, path)) copy_rules.append(make_copy_files_rule( repository_ctx, name = "cuda-lib", @@ -888,7 +1112,11 @@ def _create_local_cuda_repository(repository_ctx): )) # copy files mentioned in third_party/nccl/build_defs.bzl.tpl - bin_files = ["crt/link.stub", "bin2c", "fatbinary", "nvlink", "nvprune"] + file_ext = ".exe" if is_windows(repository_ctx) else "" + bin_files = ( + ["crt/link.stub"] + + [f + file_ext for f in ["bin2c", "fatbinary", "nvlink", "nvprune"]] + ) copy_rules.append(make_copy_files_rule( repository_ctx, name = "cuda-bin", @@ -896,7 +1124,7 @@ def _create_local_cuda_repository(repository_ctx): outs = ["cuda/bin/" + f for f in bin_files], )) - # Select the headers based on the cuDNN version. + # Select the headers based on the cuDNN version (strip '64_' for Windows). 
cudnn_headers = ["cudnn.h"] if cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "8": cudnn_headers += [ @@ -937,15 +1165,33 @@ def _create_local_cuda_repository(repository_ctx): }, ) + cub_actual = "@cub_archive//:cub" + if int(cuda_config.cuda_version_major) >= 11: + cub_actual = ":cuda_headers" + repository_ctx.template( "cuda/BUILD", tpl_paths["cuda:BUILD"], { + "%{cuda_driver_lib}": _basename(repository_ctx, cuda_libs["cuda"]), + "%{cudart_static_lib}": _basename(repository_ctx, cuda_libs["cudart_static"]), + "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value), + "%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]), + "%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]), + "%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]), + "%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]), + "%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]), + "%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]), + "%{curand_lib}": _basename(repository_ctx, cuda_libs["curand"]), + "%{cupti_lib}": _basename(repository_ctx, cuda_libs["cupti"]), + "%{cusparse_lib}": _basename(repository_ctx, cuda_libs["cusparse"]), + "%{cub_actual}": cub_actual, "%{copy_rules}": "\n".join(copy_rules), }, ) is_cuda_clang = _use_cuda_clang(repository_ctx) + is_nvcc_and_clang = _use_nvcc_and_clang(repository_ctx) tf_sysroot = _tf_sysroot(repository_ctx) should_download_clang = is_cuda_clang and _flag_enabled( @@ -956,7 +1202,7 @@ def _create_local_cuda_repository(repository_ctx): download_clang(repository_ctx, "crosstool/extra_tools") # Set up crosstool/ - cc = find_cc(repository_ctx) + cc = find_cc(repository_ctx, is_cuda_clang) cc_fullpath = cc if not should_download_clang else "crosstool/" + cc host_compiler_includes = get_cxx_inc_directories( @@ -993,7 +1239,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" cuda_defines["%{unfiltered_compile_flags}"] = "" - if is_cuda_clang: + if is_cuda_clang and not is_nvcc_and_clang: cuda_defines["%{host_compiler_path}"] = str(cc) cuda_defines["%{host_compiler_warnings}"] = """ # Some parts of the codebase set -Werror and hit this warning, so @@ -1002,10 +1248,12 @@ def _create_local_cuda_repository(repository_ctx): """ cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(host_compiler_includes) cuda_defines["%{compiler_deps}"] = ":empty" + cuda_defines["%{win_compiler_deps}"] = ":empty" repository_ctx.file( "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", "", ) + repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "") else: cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" cuda_defines["%{host_compiler_warnings}"] = "" @@ -1025,22 +1273,40 @@ def _create_local_cuda_repository(repository_ctx): # .d file - given that includes that are prefixed with "../" multiple # time quickly grow longer than the root of the tree, this can lead to # bazel's header check failing. 
- cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\"" + if not is_cuda_clang: + cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\"" - nvcc_path = "%s/nvcc" % cuda_config.config["cuda_binary_dir"] + file_ext = ".exe" if is_windows(repository_ctx) else "" + nvcc_path = "%s/nvcc%s" % (cuda_config.config["cuda_binary_dir"], file_ext) cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc" + cuda_defines["%{win_compiler_deps}"] = ":windows_msvc_wrapper_files" wrapper_defines = { "%{cpu_compiler}": str(cc), "%{cuda_version}": cuda_config.cuda_version, "%{nvcc_path}": nvcc_path, - "%{gcc_host_compiler_path}": str(cc), + "%{host_compiler_path}": str(cc), + "%{use_clang_compiler}": str(is_nvcc_and_clang), + "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx), } repository_ctx.template( "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"], wrapper_defines, ) + repository_ctx.file( + "crosstool/windows/msvc_wrapper_for_nvcc.bat", + content = "@echo OFF\n{} -B external/local_config_cuda/crosstool/windows/msvc_wrapper_for_nvcc.py %*".format( + get_python_bin(repository_ctx), + ), + ) + repository_ctx.template( + "crosstool/windows/msvc_wrapper_for_nvcc.py", + tpl_paths["crosstool:windows/msvc_wrapper_for_nvcc.py"], + wrapper_defines, + ) + + cuda_defines.update(_get_win_cuda_defines(repository_ctx)) verify_build_defines(cuda_defines) @@ -1171,12 +1437,28 @@ def _cuda_autoconf_impl(repository_ctx): repository_ctx.symlink(build_file, "BUILD") +# For @bazel_tools//tools/cpp:windows_cc_configure.bzl +_MSVC_ENVVARS = [ + "BAZEL_VC", + "BAZEL_VC_FULL_VERSION", + "BAZEL_VS", + "BAZEL_WINSDK_FULL_VERSION", + "VS90COMNTOOLS", + "VS100COMNTOOLS", + "VS110COMNTOOLS", + "VS120COMNTOOLS", + "VS140COMNTOOLS", + "VS150COMNTOOLS", + "VS160COMNTOOLS", +] + _ENVIRONS = [ _GCC_HOST_COMPILER_PATH, _GCC_HOST_COMPILER_PREFIX, _CLANG_CUDA_COMPILER_PATH, "TF_NEED_CUDA", "TF_CUDA_CLANG", + "TF_NVCC_CLANG", _TF_DOWNLOAD_CLANG, _CUDA_TOOLKIT_PATH, _CUDNN_INSTALL_PATH, @@ -1188,7 +1470,7 @@ _ENVIRONS = [ "TMP", "TMPDIR", "TF_CUDA_PATHS", -] +] + _MSVC_ENVVARS remote_cuda_configure = repository_rule( implementation = _create_local_cuda_repository, diff --git a/third_party/gpus/find_cuda_config.py b/third_party/gpus/find_cuda_config.py index 78292c7b40237a..b88694af5c014d 100644 --- a/third_party/gpus/find_cuda_config.py +++ b/third_party/gpus/find_cuda_config.py @@ -29,6 +29,8 @@ If TF_CUDA_PATHS is not specified, a OS specific default is used: Linux: /usr/local/cuda, /usr, and paths from 'ldconfig -p'. + Windows: CUDA_PATH environment variable, or + C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\* For backwards compatibility, some libraries also use alternative base directories from other environment variables if they are specified. List of @@ -54,6 +56,7 @@ import io import os import glob +import platform import re import subprocess import sys @@ -70,6 +73,18 @@ class ConfigError(Exception): pass +def _is_linux(): + return platform.system() == "Linux" + + +def _is_windows(): + return platform.system() == "Windows" + + +def _is_macos(): + return platform.system() == "Darwin" + + def _matches_version(actual_version, required_version): """Checks whether some version meets the requirements. 
@@ -119,6 +134,8 @@ def _cartesian_product(first, second): def _get_ld_config_paths(): """Returns all directories from 'ldconfig -p'.""" + if not _is_linux(): + return [] ldconfig_path = which("ldconfig") or "/sbin/ldconfig" output = subprocess.check_output([ldconfig_path, "-p"]) pattern = re.compile(".* => (.*)") @@ -139,6 +156,13 @@ def _get_default_cuda_paths(cuda_version): elif not "." in cuda_version: cuda_version = cuda_version + ".*" + if _is_windows(): + return [ + os.environ.get( + "CUDA_PATH", + "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v%s\\" % + cuda_version) + ] return ["/usr/local/cuda-%s" % cuda_version, "/usr/local/cuda", "/usr", "/usr/local/cudnn"] + _get_ld_config_paths() @@ -188,8 +212,14 @@ def _find_file(base_paths, relative_paths, filepattern): def _find_library(base_paths, library_name, required_version): """Returns first valid path to the requested library.""" - filepattern = ".".join(["lib" + library_name, "so"] + - required_version.split(".")[:1]) + "*" + if _is_windows(): + filepattern = library_name + ".lib" + elif _is_macos(): + filepattern = "%s*.dylib" % (".".join(["lib" + library_name] + + required_version.split(".")[:1])) + else: + filepattern = ".".join(["lib" + library_name, "so"] + + required_version.split(".")[:1]) + "*" return _find_file(base_paths, _library_paths(), filepattern) @@ -238,7 +268,7 @@ def get_nvcc_version(path): return match.group(1) return None - nvcc_name = "nvcc" + nvcc_name = "nvcc.exe" if _is_windows() else "nvcc" nvcc_path, nvcc_version = _find_versioned_file(base_paths, [ "", "bin", @@ -528,6 +558,14 @@ def _get_legacy_path(env_name, default=[]): return _list_from_env(env_name, default) +def _normalize_path(path): + """Returns normalized path, with forward slashes on Windows.""" + path = os.path.realpath(path) + if _is_windows(): + path = path.replace("\\", "/") + return path + + def find_cuda_config(): """Returns a dictionary of CUDA library and header file paths.""" libraries = [argv.lower() for argv in sys.argv[1:]] @@ -596,7 +634,7 @@ def find_cuda_config(): for k, v in result.items(): if k.endswith("_dir") or k.endswith("_path"): - result[k] = os.path.realpath(v) + result[k] = _normalize_path(v) return result diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 90464b07264101..520c9bce6c5265 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -198,6 +198,8 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/15.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/16.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/18/include") # Support hcc based off clang 10.0.0 (for ROCm 3.3) inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/") @@ -345,14 +347,14 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_ libs_paths = [ (name, _rocm_lib_paths(repository_ctx, name, path)) for name, path in [ - ("amdhip64", rocm_config.rocm_toolkit_path + "/hip"), + ("amdhip64", rocm_config.rocm_toolkit_path), ("rocblas", rocm_config.rocm_toolkit_path), (hipfft_or_rocfft, rocm_config.rocm_toolkit_path), ("hiprand", rocm_config.rocm_toolkit_path), ("MIOpen", miopen_path), ("rccl", rccl_path), ("hipsparse", rocm_config.rocm_toolkit_path), - ("roctracer64", 
rocm_config.rocm_toolkit_path + "/roctracer"), + ("roctracer64", rocm_config.rocm_toolkit_path), ("rocsolver", rocm_config.rocm_toolkit_path), ] ] @@ -694,7 +696,7 @@ def _create_local_rocm_repository(repository_ctx): rocm_defines["%{unfiltered_compile_flags}"] = to_list_of_strings([ "-DTENSORFLOW_USE_ROCM=1", - "-D__HIP_PLATFORM_HCC__", + "-D__HIP_PLATFORM_AMD__", "-DEIGEN_USE_HIP", ]) @@ -729,7 +731,7 @@ def _create_local_rocm_repository(repository_ctx): "%{hipcc_env}": _hipcc_env(repository_ctx), "%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", "%{rocr_runtime_library}": "hsa-runtime64", - "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/hip/lib", + "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", "%{hip_runtime_library}": "amdhip64", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), diff --git a/third_party/highwayhash/highwayhash.BUILD b/third_party/highwayhash/highwayhash.BUILD index 76f0c962ef8b8a..c24c987a276acd 100644 --- a/third_party/highwayhash/highwayhash.BUILD +++ b/third_party/highwayhash/highwayhash.BUILD @@ -286,6 +286,7 @@ cc_library( ":hh_portable", ":hh_types", ] + select({ + ":cpu_ppc": [":hh_vsx"], ":cpu_aarch64": [":hh_neon"], "//conditions:default": [ ":hh_avx2", diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 3597c8870d19ff..509398da979e83 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,2483 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/llvm/include/llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h b/llvm/include/llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h ---- a/llvm/include/llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h -+++ b/llvm/include/llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h -@@ -1,42 +0,0 @@ --//===- MergeFunctionsIgnoringConst.h - Merge Functions ----------*- C++ -*-===// --// --// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. --// See https://llvm.org/LICENSE.txt for license information. --// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception --// --//===----------------------------------------------------------------------===// --// --// This pass transforms simple global variables that never have their address --// taken. If obviously true, it marks read/write globals as constant, deletes --// variables only stored to, etc. --// --//===----------------------------------------------------------------------===// -- --#ifndef LLVM_TRANSFORMS_IPO_MERGEFUNCTIONSIGNORINGCONST_H --#define LLVM_TRANSFORMS_IPO_MERGEFUNCTIONSIGNORINGCONST_H -- --#include "llvm/IR/PassManager.h" -- --namespace llvm { -- --class Module; -- --/// Merge functions that differ by constants. 
--class MergeFuncIgnoringConstPass -- : public PassInfoMixin { -- bool PtrAuthEnabled = false; -- unsigned PtrAuthKey = 0; -- std::string MergeFuncSuffix = ".Tm"; -- --public: -- MergeFuncIgnoringConstPass() {} -- MergeFuncIgnoringConstPass(bool PtrAuthEnabled, unsigned PtrAuthKey, -- std::string Suffix) -- : PtrAuthEnabled(PtrAuthEnabled), PtrAuthKey(PtrAuthKey), -- MergeFuncSuffix(Suffix) {} -- PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); --}; -- --} // end namespace llvm -- --#endif // LLVM_TRANSFORMS_IPO_MERGEFUNCTIONSIGNORINGCONST_H -diff -ruN --strip-trailing-cr a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h ---- a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h -+++ b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h -@@ -379,7 +379,6 @@ - /// But, we are still not able to compare operands of PHI nodes, since those - /// could be operands from further BBs we didn't scan yet. - /// So it's impossible to use dominance properties in general. --protected: - mutable DenseMap sn_mapL, sn_mapR; - - // The global state we will use -diff -ruN --strip-trailing-cr a/llvm/include/llvm/Transforms/Utils/FunctionComparatorIgnoringConst.h b/llvm/include/llvm/Transforms/Utils/FunctionComparatorIgnoringConst.h ---- a/llvm/include/llvm/Transforms/Utils/FunctionComparatorIgnoringConst.h -+++ b/llvm/include/llvm/Transforms/Utils/FunctionComparatorIgnoringConst.h -@@ -1,58 +0,0 @@ --//===- FunctionComparatorIgnoringConst.h - Function Comparator --*- C++ -*-===// --// --// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. --// See https://llvm.org/LICENSE.txt for license information. --// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception --// --//===----------------------------------------------------------------------===// --// --// This file defines the FunctionComparatorIgnoringConst class which is used by --// the MergeFuncIgnoringConst pass for comparing functions. --// --//===----------------------------------------------------------------------===// -- --#ifndef LLVM_TRANSFORMS_UTILS_FUNCTIONCOMPARATORIGNORINGCONST_H --#define LLVM_TRANSFORMS_UTILS_FUNCTIONCOMPARATORIGNORINGCONST_H -- --#include "llvm/ADT/DenseMap.h" --#include "llvm/ADT/StringRef.h" --#include "llvm/IR/Attributes.h" --#include "llvm/IR/Instructions.h" --#include "llvm/IR/Operator.h" --#include "llvm/IR/ValueMap.h" --#include "llvm/Support/AtomicOrdering.h" --#include "llvm/Support/Casting.h" --#include "llvm/Transforms/Utils/FunctionComparator.h" --#include -- --namespace llvm { -- --/// FunctionComparatorIgnoringConst - Compares two functions to determine --/// whether or not they match when certain constants are ignored. --class FunctionComparatorIgnoringConst : public FunctionComparator { --public: -- FunctionComparatorIgnoringConst(const Function *F1, const Function *F2, -- GlobalNumberState *GN) -- : FunctionComparator(F1, F2, GN) {} -- -- int cmpOperandsIgnoringConsts(const Instruction *L, const Instruction *R, -- unsigned opIdx); -- -- int cmpBasicBlocksIgnoringConsts( -- const BasicBlock *BBL, const BasicBlock *BBR, -- const std::set> *InstOpndIndex = nullptr); -- -- int compareIgnoringConsts( -- const std::set> *InstOpndIndex = nullptr); -- -- int compareConstants(const Constant *L, const Constant *R) const { -- return cmpConstants(L, R); -- } -- --private: -- /// Scratch index for instruction in order during cmpOperandsIgnoringConsts. 
-- int Index = 0; --}; -- --} // end namespace llvm --#endif // LLVM_TRANSFORMS_UTILS_FUNCTIONCOMPARATORIGNORINGCONST_H -diff -ruN --strip-trailing-cr a/llvm/include/llvm/Transforms/Utils/MergeFunctionsIgnoringConst.h b/llvm/include/llvm/Transforms/Utils/MergeFunctionsIgnoringConst.h ---- a/llvm/include/llvm/Transforms/Utils/MergeFunctionsIgnoringConst.h -+++ b/llvm/include/llvm/Transforms/Utils/MergeFunctionsIgnoringConst.h -@@ -1,29 +0,0 @@ --//===- MergeFunctionsIgnoringConst.h - Merge Functions ---------*- C++ -*-===// --// --// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. --// See https://llvm.org/LICENSE.txt for license information. --// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception --// --//===----------------------------------------------------------------------===// --// --// This file defines helpers used in the MergeFunctionsIgnoringConst. --// --//===----------------------------------------------------------------------===// -- --#ifndef LLVM_TRANSFORMS_UTILS_MERGEFUNCTIONSIGNORINGCONST_H --#define LLVM_TRANSFORMS_UTILS_MERGEFUNCTIONSIGNORINGCONST_H -- --#include "llvm/IR/IRBuilder.h" --#include "llvm/IR/Instructions.h" --#include "llvm/IR/Operator.h" -- --using namespace llvm; -- --bool isEligibleInstrunctionForConstantSharing(const Instruction *I); -- --bool isEligibleOperandForConstantSharing(const Instruction *I, unsigned OpIdx); -- --bool isEligibleFunction(Function *F); -- --Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy); --#endif // LLVM_TRANSFORMS_UTILS_MERGEFUNCTIONSIGNORINGCONST_H -diff -ruN --strip-trailing-cr a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp ---- a/llvm/lib/Passes/PassBuilder.cpp -+++ b/llvm/lib/Passes/PassBuilder.cpp -@@ -123,7 +123,6 @@ - #include "llvm/Transforms/IPO/LowerTypeTests.h" - #include "llvm/Transforms/IPO/MemProfContextDisambiguation.h" - #include "llvm/Transforms/IPO/MergeFunctions.h" --#include "llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h" - #include "llvm/Transforms/IPO/ModuleInliner.h" - #include "llvm/Transforms/IPO/OpenMPOpt.h" - #include "llvm/Transforms/IPO/PartialInlining.h" -diff -ruN --strip-trailing-cr a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp ---- a/llvm/lib/Passes/PassBuilderPipelines.cpp -+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp -@@ -60,7 +60,6 @@ - #include "llvm/Transforms/IPO/LowerTypeTests.h" - #include "llvm/Transforms/IPO/MemProfContextDisambiguation.h" - #include "llvm/Transforms/IPO/MergeFunctions.h" --#include "llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h" - #include "llvm/Transforms/IPO/ModuleInliner.h" - #include "llvm/Transforms/IPO/OpenMPOpt.h" - #include "llvm/Transforms/IPO/PartialInlining.h" -@@ -177,10 +176,6 @@ - "enable-merge-functions", cl::init(false), cl::Hidden, - cl::desc("Enable function merging as part of the optimization pipeline")); - --static cl::opt EnableMergeFuncIgnoringConst( -- "enable-merge-func-ignoring-const", cl::init(false), cl::Hidden, -- cl::desc("Enable function merger that ignores constants")); -- - static cl::opt EnablePostPGOLoopRotation( - "enable-post-pgo-loop-rotation", cl::init(true), cl::Hidden, - cl::desc("Run the loop rotation transformation after PGO instrumentation")); -@@ -1638,9 +1633,6 @@ - MPM.addPass(buildModuleOptimizationPipeline( - Level, ThinOrFullLTOPhase::ThinLTOPostLink)); - -- if (EnableMergeFuncIgnoringConst) -- MPM.addPass(MergeFuncIgnoringConstPass()); -- - // Emit annotation remarks. 
- addAnnotationRemarksPass(MPM); - -@@ -1966,9 +1958,6 @@ - - invokeFullLinkTimeOptimizationLastEPCallbacks(MPM, Level); - -- if (EnableMergeFuncIgnoringConst) -- MPM.addPass(MergeFuncIgnoringConstPass()); -- - // Emit annotation remarks. - addAnnotationRemarksPass(MPM); - -diff -ruN --strip-trailing-cr a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def ---- a/llvm/lib/Passes/PassRegistry.def -+++ b/llvm/lib/Passes/PassRegistry.def -@@ -87,7 +87,6 @@ - MODULE_PASS("lowertypetests", LowerTypeTestsPass()) - MODULE_PASS("metarenamer", MetaRenamerPass()) - MODULE_PASS("mergefunc", MergeFunctionsPass()) --MODULE_PASS("mergefunc-ignoring-const", MergeFuncIgnoringConstPass()) - MODULE_PASS("name-anon-globals", NameAnonGlobalPass()) - MODULE_PASS("no-op-module", NoOpModulePass()) - MODULE_PASS("objc-arc-apelim", ObjCARCAPElimPass()) -diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/IPO/CMakeLists.txt b/llvm/lib/Transforms/IPO/CMakeLists.txt ---- a/llvm/lib/Transforms/IPO/CMakeLists.txt -+++ b/llvm/lib/Transforms/IPO/CMakeLists.txt -@@ -30,7 +30,6 @@ - LowerTypeTests.cpp - MemProfContextDisambiguation.cpp - MergeFunctions.cpp -- MergeFunctionsIgnoringConst.cpp - ModuleInliner.cpp - OpenMPOpt.cpp - PartialInlining.cpp -diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/IPO/MergeFunctionsIgnoringConst.cpp b/llvm/lib/Transforms/IPO/MergeFunctionsIgnoringConst.cpp ---- a/llvm/lib/Transforms/IPO/MergeFunctionsIgnoringConst.cpp -+++ b/llvm/lib/Transforms/IPO/MergeFunctionsIgnoringConst.cpp -@@ -1,1399 +0,0 @@ --//===--- MergeFunctionsIgnoringConst.cpp - Merge functions ----------------===// --// --// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. --// See https://llvm.org/LICENSE.txt for license information. --// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception --// --//===----------------------------------------------------------------------===// --// --// This pass looks for similar functions that are mergeable and folds them. --// The implementation is similar to LLVM's MergeFunctions pass. Instead of --// merging identical functions, it merges functions which only differ by a few --// constants in certain instructions. --// This is copied from Swift's implementation. --// --// This pass should run after LLVM's MergeFunctions pass, because it works best --// if there are no _identical_ functions in the module. --// Note: it would also work for identical functions but could produce more --// code overhead than the LLVM pass. 
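(For orientation only, not part of the patch: a minimal source-level sketch, in C++ rather than LLVM IR and with hypothetical names, of the transformation the pass deleted in this hunk performed. Two functions that differ only in a constant operand are folded into one body that takes the differing constant as an extra parameter, and the originals become thunks forwarding that constant.)

#include <cstdio>

static int g1 = 1;
static int g2 = 2;

// Before merging (conceptually):
//   int f1(int x) { return x + g1; }
//   int f2(int x) { return x + g2; }

// After merging: one shared body, parameterized on the differing operand...
static int f_merged(int x, const int *g) { return x + *g; }

// ...plus thunks that supply each original constant.
int f1(int x) { return f_merged(x, &g1); }
int f2(int x) { return f_merged(x, &g2); }

int main() {
  std::printf("%d %d\n", f1(10), f2(10)); // prints "11 12"
  return 0;
}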
--// --//===----------------------------------------------------------------------===// -- --#include "llvm/Transforms/IPO/MergeFunctionsIgnoringConst.h" --#include "llvm/ADT/DenseSet.h" --#include "llvm/ADT/FoldingSet.h" --#include "llvm/ADT/Hashing.h" --#include "llvm/ADT/STLExtras.h" --#include "llvm/ADT/SmallSet.h" --#include "llvm/ADT/StableHashing.h" --#include "llvm/ADT/Statistic.h" --#include "llvm/Analysis/ObjCARCUtil.h" --#include "llvm/IR/Attributes.h" --#include "llvm/IR/Constants.h" --#include "llvm/IR/DataLayout.h" --#include "llvm/IR/DebugInfoMetadata.h" --#include "llvm/IR/IRBuilder.h" --#include "llvm/IR/InlineAsm.h" --#include "llvm/IR/Instructions.h" --#include "llvm/IR/Module.h" --#include "llvm/IR/Operator.h" --#include "llvm/IR/StructuralHash.h" --#include "llvm/IR/ValueHandle.h" --#include "llvm/IR/ValueMap.h" --#include "llvm/InitializePasses.h" --#include "llvm/Pass.h" --#include "llvm/Support/CommandLine.h" --#include "llvm/Support/Debug.h" --#include "llvm/Support/ErrorHandling.h" --#include "llvm/Support/FileSystem.h" --#include "llvm/Support/Regex.h" --#include "llvm/Support/raw_ostream.h" --#include "llvm/Transforms/IPO.h" --#include "llvm/Transforms/Utils/FunctionComparatorIgnoringConst.h" --#include -- --using namespace llvm; -- --#define DEBUG_TYPE "mergefunc-ignoring-const" -- --STATISTIC(NumFunctionsMergedIgnoringConst, "Number of functions merged"); --STATISTIC(NumThunksWrittenIgnoringConst, "Number of thunks generated"); -- --static cl::opt EnableAggressiveMergeFunc( -- "enable-aggressive-mergefunc-ignoringconst", cl::init(false), cl::Hidden, -- cl::desc("Enable more aggressive function merger")); -- --static cl::opt NumFunctionsIgnoringConstForSanityCheck( -- "mergefunc-ignoringconst-sanity", -- cl::desc("How many functions in module could be used for " -- "MergeFunctionsIgnoringConst pass sanity check. " -- "'0' disables this check. Works only with '-debug' key."), -- cl::init(0), cl::Hidden); -- --static cl::opt IgnoringConstMergeThreshold( -- "mergefunc-ignoringconst-threshold", -- cl::desc("Functions larger than the threshold are considered for merging." 
-- "'0' disables function merging at all."), -- cl::init(15), cl::Hidden); -- --cl::opt UseLinkOnceODRLinkageMerging( -- "use-linkonceodr-linkage-merging", cl::init(false), cl::Hidden, -- cl::desc( -- "Use LinkeOnceODR linkage to deduplicate the identical merged function " -- "(default = off)")); -- --cl::opt NoInlineForMergedFunction( -- "no-inline-merged-function", cl::init(false), cl::Hidden, -- cl::desc("set noinline for merged function (default = off)")); -- --static cl::opt -- CastArrayType("merge-cast-array-type", cl::init(false), cl::Hidden, -- cl::desc("support for casting array type (default = off)")); -- --static cl::opt IgnoreMusttailFunction( -- "ignore-musttail-function", cl::init(false), cl::Hidden, -- cl::desc( -- "ignore functions containing callsites with musttail (default = off)")); -- --static cl::opt AlwaysCallThunk( -- "merge-always-call-thunk", cl::init(false), cl::Hidden, -- cl::desc( -- "do not replace callsites and always emit a thunk (default = off)")); -- --static cl::list MergeBlockRegexFilters( -- "merge-block-regex", cl::Optional, -- cl::desc("Block functions from merging if they match the given " -- "regular expression"), -- cl::ZeroOrMore); -- --static cl::list MergeAllowRegexFilters( -- "merge-allow-regex", cl::Optional, -- cl::desc("Allow functions from merging if they match the given " -- "regular expression"), -- cl::ZeroOrMore); -- --bool isEligibleInstrunctionForConstantSharing(const Instruction *I) { -- switch (I->getOpcode()) { -- case Instruction::Load: -- case Instruction::Store: -- case Instruction::Call: -- return true; -- default: { -- if (EnableAggressiveMergeFunc && I->getOpcode() == Instruction::Invoke) -- return true; -- return false; -- } -- } --} -- --/// Returns true if the \OpIdx operand of \p CI is the callee operand. --static bool isCalleeOperand(const CallBase *CI, unsigned OpIdx) { -- return &CI->getCalledOperandUse() == &CI->getOperandUse(OpIdx); --} -- --static bool canParameterizeCallOperand(const CallBase *CI, unsigned OpIdx) { -- if (CI->isInlineAsm()) -- return false; -- Function *Callee = CI->getCalledOperand() -- ? dyn_cast_or_null( -- CI->getCalledOperand()->stripPointerCasts()) -- : nullptr; -- if (Callee) { -- if (Callee->isIntrinsic()) -- return false; -- // objc_msgSend stubs must be called, and can't have their address taken. -- if (Callee->getName().startswith("objc_msgSend$")) -- return false; -- } -- if (isCalleeOperand(CI, OpIdx) && -- CI->getOperandBundle(LLVMContext::OB_ptrauth).has_value()) { -- // The operand is the callee and it has already been signed. Ignore this -- // because we cannot add another ptrauth bundle to the call instruction. -- return false; -- } -- return true; --} -- --bool isEligibleOperandForConstantSharing(const Instruction *I, unsigned OpIdx) { -- assert(OpIdx < I->getNumOperands() && "Invalid operand index"); -- -- if (!isEligibleInstrunctionForConstantSharing(I)) -- return false; -- -- auto Opnd = I->getOperand(OpIdx); -- if (!isa(Opnd)) -- return false; -- -- if (const auto *CI = dyn_cast(I)) -- return canParameterizeCallOperand(CI, OpIdx); -- -- return true; --} -- --namespace { -- --/// MergeFuncIgnoringConst finds functions which only differ by constants in --/// certain instructions, e.g. resulting from specialized functions of layout --/// compatible types. --/// Such functions are merged by replacing the differing constants by a --/// parameter. The original functions are replaced by thunks which call the --/// merged function with the specific argument constants. 
--/// --class MergeFuncIgnoringConstImpl { --public: -- MergeFuncIgnoringConstImpl(bool PtrAuthEnabled, unsigned PtrAuthKey, -- std::string Suffix) -- : FnTree(FunctionNodeCmp(&GlobalNumbers)), PtrAuthEnabled(PtrAuthEnabled), -- PtrAuthKey(PtrAuthKey), MergeFuncSuffix(Suffix) {} -- -- bool runImpl(Module &M); -- --private: -- struct FunctionEntry; -- -- /// Describes the set of functions which are considered as "equivalent" (i.e. -- /// only differing by some constants). -- struct EquivalenceClass { -- /// The single-linked list of all functions which are a member of this -- /// equivalence class. -- FunctionEntry *First; -- -- /// A very cheap hash, used to early exit if functions do not match. -- llvm::IRHash Hash; -- -- public: -- // Note the hash is recalculated potentially multiple times, but it is -- // cheap. -- EquivalenceClass(FunctionEntry *First) -- : First(First), Hash(StructuralHash(*First->F)) { -- assert(!First->Next); -- } -- }; -- -- /// The function comparison operator is provided here so that FunctionNodes do -- /// not need to become larger with another pointer. -- class FunctionNodeCmp { -- GlobalNumberState *GlobalNumbers; -- -- public: -- FunctionNodeCmp(GlobalNumberState *GN) : GlobalNumbers(GN) {} -- bool operator()(const EquivalenceClass &LHS, -- const EquivalenceClass &RHS) const { -- // Order first by hashes, then full function comparison. -- if (LHS.Hash != RHS.Hash) -- return LHS.Hash < RHS.Hash; -- FunctionComparatorIgnoringConst FCmp(LHS.First->F, RHS.First->F, -- GlobalNumbers); -- return FCmp.compareIgnoringConsts() == -1; -- } -- }; -- using FnTreeType = std::set; -- -- /// -- struct FunctionEntry { -- FunctionEntry(Function *F, FnTreeType::iterator I) -- : F(F), Next(nullptr), NumUnhandledCallees(0), TreeIter(I), -- IsMerged(false) {} -- -- /// Back-link to the function. -- AssertingVH F; -- -- /// The next function in its equivalence class. -- FunctionEntry *Next; -- -- /// The number of not-yet merged callees. Used to process the merging in -- /// bottom-up call order. -- /// This is only valid in the first entry of an equivalence class. The -- /// counts of all functions in an equivalence class are accumulated in the -- /// first entry. -- int NumUnhandledCallees; -- -- /// The iterator of the function's equivalence class in the FnTree. -- /// It's FnTree.end() if the function is not in an equivalence class. -- FnTreeType::iterator TreeIter; -- -- /// True if this function is already a thunk, calling the merged function. -- bool IsMerged; -- }; -- -- /// Describes an operator of a specific instruction. -- struct OpLocation { -- Instruction *I; -- unsigned OpIndex; -- }; -- -- /// Information for a function. Used during merging. -- struct FunctionInfo { -- -- FunctionInfo(Function *F) -- : F(F), CurrentInst(nullptr), NumParamsNeeded(0) {} -- -- void init() { -- CurrentInst = &*F->begin()->begin(); -- NumParamsNeeded = 0; -- } -- -- /// Advances the current instruction to the next instruction. -- void nextInst() { -- assert(CurrentInst); -- if (CurrentInst->isTerminator()) { -- auto BlockIter = std::next(CurrentInst->getParent()->getIterator()); -- if (BlockIter == F->end()) { -- CurrentInst = nullptr; -- return; -- } -- CurrentInst = &*BlockIter->begin(); -- return; -- } -- CurrentInst = &*std::next(CurrentInst->getIterator()); -- } -- -- /// Returns true if the operand \p OpIdx of the current instruction is the -- /// callee of a call, which needs to be signed if passed as a parameter. 
-- bool needsPointerSigning(unsigned OpIdx) const { -- if (auto *CI = dyn_cast(CurrentInst)) -- return isCalleeOperand(CI, OpIdx); -- return false; -- } -- -- Function *F; -- -- /// The current instruction while iterating over all instructions. -- Instruction *CurrentInst; -- -- /// Roughly the number of parameters needed if this function would be -- /// merged with the first function of the equivalence class. -- int NumParamsNeeded; -- }; -- -- using FunctionInfos = SmallVector; -- -- /// Describes a parameter which we create to parameterize the merged function. -- struct ParamInfo { -- /// The value of the parameter for all the functions in the equivalence -- /// class. -- SmallVector Values; -- -- /// All uses of the parameter in the merged function. -- SmallVector Uses; -- -- /// The Discriminator for pointer signing. -- /// Only not null if needsPointerSigning is true. -- ConstantInt *Discriminator = nullptr; -- -- /// True if the value is a callee function, which needs to be signed if -- /// passed as a parameter. -- bool NeedsPointerSigning = false; -- -- /// Checks if this parameter can be used to describe an operand in all -- /// functions of the equivalence class. Returns true if all values match -- /// the specific instruction operands in all functions. -- bool matches(const FunctionInfos &FInfos, unsigned OpIdx, -- bool PtrAuthEnabled) const { -- unsigned NumFuncs = FInfos.size(); -- assert(Values.size() == NumFuncs); -- if (PtrAuthEnabled && -- NeedsPointerSigning != FInfos[0].needsPointerSigning(OpIdx)) { -- return false; -- } -- for (unsigned Idx = 0; Idx < NumFuncs; ++Idx) { -- const FunctionInfo &FI = FInfos[Idx]; -- Constant *C = cast(FI.CurrentInst->getOperand(OpIdx)); -- if (Values[Idx] != C) -- return false; -- } -- return true; -- } -- -- /// Computes the Discriminator for pointer signing. -- void computeDiscriminator(LLVMContext &Context) { -- assert(NeedsPointerSigning); -- assert(!Discriminator); -- -- /// Get a hash from the concatenated function names. -- /// The hash is deterministic, because the order of values depends on the -- /// order of functions in the module, which is itself deterministic. -- /// Note that the hash is not part of the ABI, because it's purly used -- /// for pointer authentication between a module-private caller-callee -- /// pair. -- std::string concatenatedCalleeNames; -- for (Constant *value : Values) { -- if (auto *GO = dyn_cast(value)) -- concatenatedCalleeNames += GO->getName(); -- } -- uint64_t rawHash = stable_hash_combine_string(concatenatedCalleeNames); -- IntegerType *discrTy = Type::getInt64Ty(Context); -- Discriminator = ConstantInt::get(discrTy, (rawHash % 0xFFFF) + 1); -- } -- }; -- -- using ParamInfos = SmallVector; -- -- Module *CurrentModule = nullptr; -- -- GlobalNumberState GlobalNumbers; -- -- /// A work queue of functions that may have been modified and should be -- /// analyzed again. -- std::vector Deferred; -- -- /// The set of all distinct functions. Use the insert() and remove() methods -- /// to modify it. The map allows efficient lookup and deferring of Functions. -- FnTreeType FnTree; -- -- ValueMap FuncEntries; -- -- // Maps a function-pointer / Discriminator pair to a corresponding global in -- // the llvm.ptrauth section. -- // This map is used as a cache to not create ptrauth globals twice. -- DenseMap, Constant *> PtrAuthGlobals; -- -- /// True if the architecture has pointer authentication enabled. -- bool PtrAuthEnabled = false; -- -- /// The key for pointer authentication. 
-- unsigned PtrAuthKey = 0; -- -- std::string MergeFuncSuffix = ".Tm"; -- -- FunctionEntry *getEntry(Function *F) const { return FuncEntries.lookup(F); } -- -- bool isInEquivalenceClass(FunctionEntry *FE) const { -- if (FE->TreeIter != FnTree.end()) { -- return true; -- } -- assert(!FE->Next); -- assert(FE->NumUnhandledCallees == 0); -- return false; -- } -- -- /// Checks the rules of order relation introduced among functions set. -- /// Returns true, if sanity check has been passed, and false if failed. -- bool doSanityCheck(std::vector &Worklist); -- -- /// Updates the NumUnhandledCallees of all user functions of the equivalence -- /// class containing \p FE by \p Delta. -- void updateUnhandledCalleeCount(FunctionEntry *FE, int Delta); -- -- bool tryMergeEquivalenceClass(FunctionEntry *FirstInClass); -- -- FunctionInfo removeFuncWithMostParams(FunctionInfos &FInfos); -- -- bool deriveParams(ParamInfos &Params, FunctionInfos &FInfos, -- unsigned maxParams); -- -- bool numOperandsDiffer(FunctionInfos &FInfos); -- -- bool constsDiffer(const FunctionInfos &FInfos, unsigned OpIdx); -- -- bool tryMapToParameter(FunctionInfos &FInfos, unsigned OpIdx, -- ParamInfos &Params, unsigned maxParams); -- -- void replaceCallWithAddedPtrAuth(CallInst *origCall, Value *newCallee, -- ConstantInt *Discriminator); -- -- void mergeWithParams(const FunctionInfos &FInfos, ParamInfos &Params); -- static void dumpMergeInfo(const FunctionInfos &FInfos, unsigned); -- -- void removeEquivalenceClassFromTree(FunctionEntry *FE); -- -- void writeThunk(Function *ToFunc, Function *Thunk, const ParamInfos &Params, -- unsigned FuncIdx); -- -- bool isPtrAuthEnabled() const { -- // TODO: fix pointer authentication -- return PtrAuthEnabled; -- } -- -- ConstantInt *getPtrAuthKey() { -- // TODO: fix pointer authentication -- return ConstantInt::get(Type::getInt32Ty(CurrentModule->getContext()), -- PtrAuthKey); -- } -- -- /// Returns the value of function \p FuncIdx, and signes it if required. -- Constant *getSignedValue(const ParamInfo &PI, unsigned FuncIdx) { -- Constant *value = PI.Values[FuncIdx]; -- if (!PI.NeedsPointerSigning) -- return value; -- -- auto lookupKey = std::make_pair(value, PI.Discriminator); -- Constant *&ptrAuthGlobal = PtrAuthGlobals[lookupKey]; -- if (!ptrAuthGlobal) { -- // TODO: fix pointer authentication -- } -- return ptrAuthGlobal; -- } -- -- /// Replace all direct calls of Old with calls of New. Will bitcast New if -- /// necessary to make types match. -- bool replaceDirectCallers(Function *Old, Function *New, -- const ParamInfos &Params, unsigned FuncIdx); --}; -- --} // end anonymous namespace -- --bool MergeFuncIgnoringConstImpl::doSanityCheck( -- std::vector &Worklist) { -- if (const unsigned Max = NumFunctionsIgnoringConstForSanityCheck) { -- unsigned TripleNumber = 0; -- bool Valid = true; -- -- dbgs() << "MERGEFUNC-SANITY: Started for first " << Max << " functions.\n"; -- -- unsigned i = 0; -- for (std::vector::iterator I = Worklist.begin(), -- E = Worklist.end(); -- I != E && i < Max; ++I, ++i) { -- unsigned j = i; -- for (std::vector::iterator J = I; J != E && j < Max; -- ++J, ++j) { -- Function *F1 = cast(*I); -- Function *F2 = cast(*J); -- int Res1 = FunctionComparatorIgnoringConst(F1, F2, &GlobalNumbers) -- .compareIgnoringConsts(); -- int Res2 = FunctionComparatorIgnoringConst(F2, F1, &GlobalNumbers) -- .compareIgnoringConsts(); -- -- // If F1 <= F2, then F2 >= F1, otherwise report failure. 
-- if (Res1 != -Res2) { -- dbgs() << "MERGEFUNC-SANITY: Non-symmetric; triple: " << TripleNumber -- << "\n"; -- LLVM_DEBUG(F1->dump()); -- LLVM_DEBUG(F2->dump()); -- Valid = false; -- } -- -- if (Res1 == 0) -- continue; -- -- unsigned k = j; -- for (std::vector::iterator K = J; K != E && k < Max; -- ++k, ++K, ++TripleNumber) { -- if (K == J) -- continue; -- -- Function *F3 = cast(*K); -- int Res3 = FunctionComparatorIgnoringConst(F1, F3, &GlobalNumbers) -- .compareIgnoringConsts(); -- int Res4 = FunctionComparatorIgnoringConst(F2, F3, &GlobalNumbers) -- .compareIgnoringConsts(); -- -- bool Transitive = true; -- -- if (Res1 != 0 && Res1 == Res4) { -- // F1 > F2, F2 > F3 => F1 > F3 -- Transitive = Res3 == Res1; -- } else if (Res3 != 0 && Res3 == -Res4) { -- // F1 > F3, F3 > F2 => F1 > F2 -- Transitive = Res3 == Res1; -- } else if (Res4 != 0 && -Res3 == Res4) { -- // F2 > F3, F3 > F1 => F2 > F1 -- Transitive = Res4 == -Res1; -- } -- -- if (!Transitive) { -- dbgs() << "MERGEFUNC-SANITY: Non-transitive; triple: " -- << TripleNumber << "\n"; -- dbgs() << "Res1, Res3, Res4: " << Res1 << ", " << Res3 << ", " -- << Res4 << "\n"; -- LLVM_DEBUG(F1->dump()); -- LLVM_DEBUG(F2->dump()); -- LLVM_DEBUG(F3->dump()); -- Valid = false; -- } -- } -- } -- } -- -- dbgs() << "MERGEFUNC-SANITY: " << (Valid ? "Passed." : "Failed.") << "\n"; -- return Valid; -- } -- return true; --} -- --/// Returns true if functions containing calls to \p F may be merged together. --static bool mayMergeCallsToFunction(Function &F) { -- StringRef Name = F.getName(); -- -- // Calls to dtrace probes must generate unique patchpoints. -- if (Name.startswith("__dtrace")) -- return false; -- -- return true; --} -- --/// Returns the benefit, which is approximately the size of the function. --/// Return 0, if the function should not be merged. --static unsigned getBenefit(Function *F) { -- unsigned Benefit = 0; -- -- // We don't want to merge very small functions, because the overhead of -- // adding creating thunks and/or adding parameters to the call sites -- // outweighs the benefit. -- for (BasicBlock &BB : *F) { -- for (Instruction &I : BB) { -- if (CallBase *CB = dyn_cast(&I)) { -- Function *Callee = CB->getCalledFunction(); -- if (Callee && !mayMergeCallsToFunction(*Callee)) -- return 0; -- if (!Callee || !Callee->isIntrinsic()) { -- Benefit += 5; -- continue; -- } -- } -- Benefit += 1; -- } -- } -- return Benefit; --} -- --/// Returns true if function \p F is eligible for merging. --bool isEligibleFunction(Function *F) { -- if (F->isDeclaration()) -- return false; -- -- if (F->hasFnAttribute(llvm::Attribute::NoMerge)) -- return false; -- -- if (F->hasAvailableExternallyLinkage()) { -- return false; -- } -- -- if (F->getFunctionType()->isVarArg()) { -- return false; -- } -- -- // Check against blocklist. 
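(Illustrative note, not part of the patch: a rough worked example of the getBenefit() size heuristic above, assuming the default -mergefunc-ignoringconst-threshold of 15. Calls to non-intrinsic callees count 5, every other instruction counts 1, and functions scoring below the threshold are not considered for merging.)

#include <cstdio>

int main() {
  unsigned NonIntrinsicCalls = 2; // e.g. two calls to ordinary functions: 2 * 5
  unsigned OtherInsts = 6;        // adds, loads, the return, ...: 6 * 1
  unsigned Benefit = NonIntrinsicCalls * 5 + OtherInsts; // 10 + 6 = 16
  unsigned Threshold = 15;        // default IgnoringConstMergeThreshold
  std::printf("benefit=%u eligible=%s\n", Benefit,
              Benefit >= Threshold ? "yes" : "no"); // benefit=16 eligible=yes
  return 0;
}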
-- if (!MergeBlockRegexFilters.empty()) { -- StringRef FuncName = F->getName(); -- for (const auto &tRegex : MergeBlockRegexFilters) -- if (Regex(tRegex).match(FuncName)) { -- return false; -- } -- } -- // Check against allowlist -- if (!MergeAllowRegexFilters.empty()) { -- StringRef FuncName = F->getName(); -- bool found = false; -- for (const auto &tRegex : MergeAllowRegexFilters) -- if (Regex(tRegex).match(FuncName)) { -- found = true; -- break; -- } -- if (!found) -- return false; -- } -- -- if (F->getCallingConv() == CallingConv::SwiftTail) -- return false; -- -- // if function contains callsites with musttail, if we merge -- // it, the merged function will have the musttail callsite, but -- // the number of parameters can change, thus the parameter count -- // of the callsite will mismatch with the function itself. -- if (IgnoreMusttailFunction) { -- for (const BasicBlock &BB : *F) { -- for (const Instruction &I : BB) { -- const auto *CB = dyn_cast(&I); -- if (CB && CB->isMustTailCall()) -- return false; -- } -- } -- } -- -- unsigned Benefit = getBenefit(F); -- if (Benefit < IgnoringConstMergeThreshold) { -- return false; -- } -- -- return true; --} -- --bool MergeFuncIgnoringConstImpl::runImpl(Module &M) { -- if (IgnoringConstMergeThreshold == 0) -- return false; -- -- CurrentModule = &M; -- -- // TODO: fix pointer authentication -- -- bool Changed = false; -- -- // All functions in the module, ordered by hash. Functions with a unique -- // hash value are easily eliminated. -- std::vector> HashedFuncs; -- -- for (Function &Func : M) { -- if (isEligibleFunction(&Func)) { -- HashedFuncs.push_back({StructuralHash(Func), &Func}); -- } -- } -- -- std::stable_sort(HashedFuncs.begin(), HashedFuncs.end(), -- [](const std::pair &a, -- const std::pair &b) { -- return a.first < b.first; -- }); -- -- std::vector FuncEntryStorage; -- FuncEntryStorage.reserve(HashedFuncs.size()); -- -- auto S = HashedFuncs.begin(); -- for (auto I = HashedFuncs.begin(), IE = HashedFuncs.end(); I != IE; ++I) { -- -- Function *F = I->second; -- FuncEntryStorage.push_back(FunctionEntry(F, FnTree.end())); -- FunctionEntry &FE = FuncEntryStorage.back(); -- FuncEntries[F] = &FE; -- -- // If the hash value matches the previous value or the next one, we must -- // consider merging it. Otherwise it is dropped and never considered again. -- if ((I != S && std::prev(I)->first == I->first) || -- (std::next(I) != IE && std::next(I)->first == I->first)) { -- Deferred.push_back(WeakTrackingVH(F)); -- } -- } -- -- do { -- std::vector Worklist; -- Deferred.swap(Worklist); -- -- LLVM_DEBUG(dbgs() << "======\nbuild tree: worklist-size=" << Worklist.size() -- << '\n'); -- LLVM_DEBUG(doSanityCheck(Worklist)); -- -- SmallVector FuncsToMerge; -- -- // Insert all candidates into the Worklist. -- for (WeakTrackingVH &I : Worklist) { -- if (!I) -- continue; -- Function *F = cast(I); -- FunctionEntry *FE = getEntry(F); -- assert(!isInEquivalenceClass(FE)); -- -- std::pair Result = FnTree.insert(FE); -- -- FE->TreeIter = Result.first; -- const EquivalenceClass &Eq = *Result.first; -- -- if (Result.second) { -- assert(Eq.First == FE); -- LLVM_DEBUG(dbgs() << " new in tree: " << F->getName() << '\n'); -- } else { -- assert(Eq.First != FE); -- LLVM_DEBUG(dbgs() << " add to existing: " << F->getName() << '\n'); -- // Add the function to the existing equivalence class. -- FE->Next = Eq.First->Next; -- Eq.First->Next = FE; -- // Schedule for merging if the function's equivalence class reaches the -- // size of 2. 
-- if (!FE->Next) -- FuncsToMerge.push_back(Eq.First); -- } -- } -- LLVM_DEBUG(dbgs() << "merge functions: tree-size=" << FnTree.size() -- << '\n'); -- -- // Figure out the leaf functions. We want to do the merging in bottom-up -- // call order. This ensures that we don't parameterize on callee function -- // names if we don't have to (because the callee may be merged). -- // Note that "leaf functions" refer to the sub-call-graph of functions which -- // are in the FnTree. -- for (FunctionEntry *ToMerge : FuncsToMerge) { -- assert(isInEquivalenceClass(ToMerge)); -- updateUnhandledCalleeCount(ToMerge, 1); -- } -- -- // Check if there are any leaf functions at all. -- bool LeafFound = false; -- for (FunctionEntry *ToMerge : FuncsToMerge) { -- if (ToMerge->NumUnhandledCallees == 0) -- LeafFound = true; -- } -- for (FunctionEntry *ToMerge : FuncsToMerge) { -- if (isInEquivalenceClass(ToMerge)) { -- // Only merge leaf functions (or all functions if all functions are in -- // a call cycle). -- if (ToMerge->NumUnhandledCallees == 0 || !LeafFound) { -- updateUnhandledCalleeCount(ToMerge, -1); -- Changed |= tryMergeEquivalenceClass(ToMerge); -- } else { -- // Non-leaf functions (i.e. functions in a call cycle) may become -- // leaf functions in the next iteration. -- removeEquivalenceClassFromTree(ToMerge); -- } -- } -- } -- } while (!Deferred.empty()); -- -- FnTree.clear(); -- GlobalNumbers.clear(); -- FuncEntries.clear(); -- PtrAuthGlobals.clear(); -- -- return Changed; --} -- --void MergeFuncIgnoringConstImpl::updateUnhandledCalleeCount(FunctionEntry *FE, -- int Delta) { -- // Iterate over all functions of FE's equivalence class. -- do { -- for (Use &U : FE->F->uses()) { -- if (auto *I = dyn_cast(U.getUser())) { -- FunctionEntry *CallerFE = getEntry(I->getFunction()); -- if (CallerFE && CallerFE->TreeIter != FnTree.end()) { -- // Accumulate the count in the first entry of the equivalence class. -- FunctionEntry *Head = CallerFE->TreeIter->First; -- Head->NumUnhandledCallees += Delta; -- } -- } -- } -- FE = FE->Next; -- } while (FE); --} -- --bool MergeFuncIgnoringConstImpl::tryMergeEquivalenceClass( -- FunctionEntry *FirstInClass) { -- // Build the FInfos vector from all functions in the equivalence class. -- FunctionInfos FInfos; -- FunctionEntry *FE = FirstInClass; -- do { -- FInfos.push_back(FunctionInfo(FE->F)); -- FE->IsMerged = true; -- FE = FE->Next; -- } while (FE); -- assert(FInfos.size() >= 2); -- -- // Merged or not: in any case we remove the equivalence class from the FnTree. -- removeEquivalenceClassFromTree(FirstInClass); -- -- // Contains functions which differ too much from the first function (i.e. -- // would need too many parameters). -- FunctionInfos Removed; -- -- bool Changed = false; -- int Try = 0; -- -- unsigned Benefit = getBenefit(FirstInClass->F); -- -- // The bigger the function, the more parameters are allowed. -- unsigned maxParams = std::max(4u, Benefit / 100); -- -- // We need multiple tries if there are some functions in FInfos which differ -- // too much from the first function in FInfos. But we limit the number of -- // tries to a small number, because this is quadratic. -- while (FInfos.size() >= 2 && Try++ < 4) { -- ParamInfos Params; -- bool Merged = deriveParams(Params, FInfos, maxParams); -- if (Merged) { -- mergeWithParams(FInfos, Params); -- Changed = true; -- } else { -- // We ran out of parameters. Remove the function from the set which -- // differs most from the first function. 
-- Removed.push_back(removeFuncWithMostParams(FInfos)); -- } -- if (Merged || FInfos.size() < 2) { -- // Try again with the functions which were removed from the original set. -- FInfos.swap(Removed); -- Removed.clear(); -- } -- } -- return Changed; --} -- --/// Remove the function from \p FInfos which needs the most parameters. Add the --/// removed function to --MergeFuncIgnoringConstImpl::FunctionInfo --MergeFuncIgnoringConstImpl::removeFuncWithMostParams(FunctionInfos &FInfos) { -- FunctionInfos::iterator MaxIter = FInfos.end(); -- for (auto Iter = FInfos.begin(), End = FInfos.end(); Iter != End; ++Iter) { -- if (MaxIter == FInfos.end() || -- Iter->NumParamsNeeded > MaxIter->NumParamsNeeded) { -- MaxIter = Iter; -- } -- } -- FunctionInfo Removed = *MaxIter; -- FInfos.erase(MaxIter); -- return Removed; --} -- --/// Finds the set of parameters which are required to merge the functions in --/// \p FInfos. --/// Returns true on success, i.e. the functions in \p FInfos can be merged with --/// the parameters returned in \p Params. --bool MergeFuncIgnoringConstImpl::deriveParams(ParamInfos &Params, -- FunctionInfos &FInfos, -- unsigned maxParams) { -- for (FunctionInfo &FI : FInfos) -- FI.init(); -- -- FunctionInfo &FirstFI = FInfos.front(); -- -- // Iterate over all instructions synchronously in all functions. -- do { -- if (isEligibleInstrunctionForConstantSharing(FirstFI.CurrentInst)) { -- -- // Here we handle a rare corner case which needs to be explained: -- // Usually the number of operands match, because otherwise the functions -- // in FInfos would not be in the same equivalence class. There is only one -- // exception to that: If the current instruction is a call to a function, -- // which was merged in the previous iteration (in -- // tryMergeEquivalenceClass) then the call could be replaced and has more -- // arguments than the original call. -- if (numOperandsDiffer(FInfos)) { -- assert(isa(FirstFI.CurrentInst) && -- "only calls are expected to differ in number of operands"); -- return false; -- } -- -- for (unsigned OpIdx = 0, NumOps = FirstFI.CurrentInst->getNumOperands(); -- OpIdx != NumOps; ++OpIdx) { -- -- if (constsDiffer(FInfos, OpIdx)) { -- // This instruction has operands which differ in at least some -- // functions. So we need to parameterize it. -- if (!tryMapToParameter(FInfos, OpIdx, Params, maxParams)) { -- // We ran out of parameters. -- return false; -- } -- } -- } -- } -- // Go to the next instruction in all functions. -- for (FunctionInfo &FI : FInfos) -- FI.nextInst(); -- } while (FirstFI.CurrentInst); -- -- return true; --} -- --/// Returns true if the number of operands of the current instruction differs. --bool MergeFuncIgnoringConstImpl::numOperandsDiffer(FunctionInfos &FInfos) { -- unsigned numOps = FInfos[0].CurrentInst->getNumOperands(); -- for (const FunctionInfo &FI : ArrayRef(FInfos).drop_front(1)) { -- if (FI.CurrentInst->getNumOperands() != numOps) -- return true; -- } -- return false; --} -- --/// Returns true if the \p OpIdx's constant operand in the current instruction --/// does differ in any of the functions in \p FInfos. 
--bool MergeFuncIgnoringConstImpl::constsDiffer(const FunctionInfos &FInfos, -- unsigned OpIdx) { -- Constant *CommonConst = nullptr; -- -- for (const FunctionInfo &FI : FInfos) { -- Value *Op = FI.CurrentInst->getOperand(OpIdx); -- if (auto *C = dyn_cast(Op)) { -- if (!CommonConst) { -- CommonConst = C; -- } else if (EnableAggressiveMergeFunc && -- isa(CommonConst) && -- isa(C)) { -- // if both are null pointer, and if they are different constants -- // due to type, still treat them as the same. -- } else if (C != CommonConst) { -- return true; -- } -- } -- } -- return false; --} -- --/// Create a new parameter for differing operands or try to reuse an existing --/// parameter. --/// Returns true if a parameter could be created or found without exceeding the --/// maximum number of parameters. --bool MergeFuncIgnoringConstImpl::tryMapToParameter(FunctionInfos &FInfos, -- unsigned OpIdx, -- ParamInfos &Params, -- unsigned maxParams) { -- ParamInfo *Matching = nullptr; -- // Try to find an existing parameter which exactly matches the differing -- // operands of the current instruction. -- for (ParamInfo &PI : Params) { -- if (PI.matches(FInfos, OpIdx, isPtrAuthEnabled())) { -- Matching = &PI; -- break; -- } -- } -- if (!Matching) { -- // We need a new parameter. -- // Check if we are within the limit. -- if (Params.size() >= maxParams) -- return false; -- -- Params.resize(Params.size() + 1); -- Matching = &Params.back(); -- // Store the constant values into the new parameter. -- Constant *FirstC = cast(FInfos[0].CurrentInst->getOperand(OpIdx)); -- for (FunctionInfo &FI : FInfos) { -- Constant *C = cast(FI.CurrentInst->getOperand(OpIdx)); -- Matching->Values.push_back(C); -- if (C != FirstC) -- FI.NumParamsNeeded += 1; -- } -- if (isPtrAuthEnabled()) -- Matching->NeedsPointerSigning = FInfos[0].needsPointerSigning(OpIdx); -- } -- /// Remember where the parameter is needed when we build our merged function. -- Matching->Uses.push_back({FInfos[0].CurrentInst, OpIdx}); -- return true; --} -- --/// Copy \p origCall with a \p newCalle and add a ptrauth bundle with \p --/// Discriminator. --void MergeFuncIgnoringConstImpl::replaceCallWithAddedPtrAuth( -- CallInst *origCall, Value *newCallee, ConstantInt *Discriminator) { -- SmallVector bundles; -- origCall->getOperandBundlesAsDefs(bundles); -- ConstantInt *key = getPtrAuthKey(); -- llvm::Value *bundleArgs[] = {key, Discriminator}; -- bundles.emplace_back("ptrauth", bundleArgs); -- -- SmallVector copiedArgs; -- for (Value *op : origCall->args()) { -- copiedArgs.push_back(op); -- } -- -- auto *newCall = -- CallInst::Create(origCall->getFunctionType(), newCallee, copiedArgs, -- bundles, origCall->getName(), origCall); -- newCall->setAttributes(origCall->getAttributes()); -- newCall->setTailCallKind(origCall->getTailCallKind()); -- newCall->setCallingConv(origCall->getCallingConv()); -- origCall->replaceAllUsesWith(newCall); -- origCall->eraseFromParent(); --} -- --void MergeFuncIgnoringConstImpl::dumpMergeInfo(const FunctionInfos &FInfos, -- unsigned paramSize) { -- std::set oHashes; -- std::vector funcLocs; -- Function *OrigFunc = nullptr; -- for (const auto &FInfo : FInfos) { -- OrigFunc = FInfo.F; -- -- llvm::IRHash origHash = StructuralHash(*OrigFunc); -- oHashes.insert(origHash); -- -- // Print debug location. 
-- std::string Result; -- raw_string_ostream DbgLocOS(Result); -- if (DISubprogram *DIS = OrigFunc->getSubprogram()) { -- DebugLoc FuncDbgLoc = -- DILocation::get(DIS->getContext(), DIS->getScopeLine(), 0, DIS); -- FuncDbgLoc.print(DbgLocOS); -- DbgLocOS.flush(); -- } -- std::string singleLine = -- "# functionLoc " + -- std::to_string(GlobalValue::getGUID(OrigFunc->getName())) + " " + -- Result + " " + std::string(OrigFunc->getName()) + "\n"; -- funcLocs.push_back(singleLine); -- } --} -- --/// Merge all functions in \p FInfos by creating thunks which call the single --/// merged function with additional parameters. --void MergeFuncIgnoringConstImpl::mergeWithParams(const FunctionInfos &FInfos, -- ParamInfos &Params) { -- // We reuse the body of the first function for the new merged function. -- Function *FirstF = FInfos.front().F; -- -- // Build the type for the merged function. This will be the type of the -- // original function (FirstF) but with the additional parameter which are -- // needed to parameterize the merged function. -- FunctionType *OrigTy = FirstF->getFunctionType(); -- SmallVector ParamTypes(OrigTy->param_begin(), OrigTy->param_end()); -- -- for (const ParamInfo &PI : Params) { -- ParamTypes.push_back(PI.Values[0]->getType()); -- } -- -- FunctionType *funcType = -- FunctionType::get(OrigTy->getReturnType(), ParamTypes, false); -- -- // Create the new function. -- Function *NewFunction = Function::Create(funcType, FirstF->getLinkage(), -- FirstF->getName() + MergeFuncSuffix); -- if (auto *SP = FirstF->getSubprogram()) -- NewFunction->setSubprogram(SP); -- NewFunction->copyAttributesFrom(FirstF); -- // NOTE: this function is not externally available, do ensure that we reset -- // the DLL storage -- NewFunction->setDLLStorageClass(GlobalValue::DefaultStorageClass); -- if (UseLinkOnceODRLinkageMerging) -- NewFunction->setLinkage(GlobalValue::LinkOnceODRLinkage); -- else -- NewFunction->setLinkage(GlobalValue::InternalLinkage); -- if (NoInlineForMergedFunction) -- NewFunction->addFnAttr(Attribute::NoInline); -- -- // Insert the new function after the last function in the equivalence class. -- FirstF->getParent()->getFunctionList().insert( -- std::next(FInfos[1].F->getIterator()), NewFunction); -- -- LLVM_DEBUG(dbgs() << " Merge into " << NewFunction->getName() << '\n'); -- -- // Move the body of FirstF into the NewFunction. -- NewFunction->splice(NewFunction->begin(), FirstF); -- -- auto NewArgIter = NewFunction->arg_begin(); -- for (Argument &OrigArg : FirstF->args()) { -- Argument &NewArg = *NewArgIter++; -- OrigArg.replaceAllUsesWith(&NewArg); -- } -- unsigned numOrigArgs = FirstF->arg_size(); -- -- SmallPtrSet SelfReferencingFunctions; -- -- // Replace all differing operands with a parameter. -- for (unsigned paramIdx = 0; paramIdx < Params.size(); ++paramIdx) { -- const ParamInfo &PI = Params[paramIdx]; -- Argument *NewArg = NewFunction->getArg(numOrigArgs + paramIdx); -- -- if (!PI.NeedsPointerSigning) { -- for (const OpLocation &OL : PI.Uses) { -- OL.I->setOperand(OL.OpIndex, NewArg); -- } -- } -- // Collect all functions which are referenced by any parameter. -- for (Value *V : PI.Values) { -- if (auto *F = dyn_cast(V)) -- SelfReferencingFunctions.insert(F); -- } -- } -- -- // Replace all differing operands, which need pointer signing, with a -- // parameter. -- // We need to do that after all other parameters, because here we replace -- // call instructions, which must be live in case it has another constant to -- // be replaced. 
-- for (unsigned paramIdx = 0; paramIdx < Params.size(); ++paramIdx) { -- ParamInfo &PI = Params[paramIdx]; -- if (PI.NeedsPointerSigning) { -- PI.computeDiscriminator(NewFunction->getContext()); -- for (const OpLocation &OL : PI.Uses) { -- auto *origCall = cast(OL.I); -- Argument *newCallee = NewFunction->getArg(numOrigArgs + paramIdx); -- replaceCallWithAddedPtrAuth(origCall, newCallee, PI.Discriminator); -- } -- } -- } -- -- for (unsigned FIdx = 0, NumFuncs = FInfos.size(); FIdx < NumFuncs; ++FIdx) { -- Function *OrigFunc = FInfos[FIdx].F; -- // Don't try to replace all callers of functions which are used as -- // parameters because we must not delete such functions. -- if (SelfReferencingFunctions.count(OrigFunc) == 0 && -- replaceDirectCallers(OrigFunc, NewFunction, Params, FIdx)) { -- // We could replace all uses (and the function is not externally visible), -- // so we can delete the original function. -- auto Iter = FuncEntries.find(OrigFunc); -- assert(Iter != FuncEntries.end()); -- assert(!isInEquivalenceClass(&*Iter->second)); -- Iter->second->F = nullptr; -- FuncEntries.erase(Iter); -- LLVM_DEBUG(dbgs() << " Erase " << OrigFunc->getName() << '\n'); -- OrigFunc->eraseFromParent(); -- } else { -- // Otherwise we need a thunk which calls the merged function. -- writeThunk(NewFunction, OrigFunc, Params, FIdx); -- } -- ++NumFunctionsMergedIgnoringConst; -- } --} -- --/// Remove all functions of \p FE's equivalence class from FnTree. Add them to --/// Deferred so that we'll look at them in the next round. --void MergeFuncIgnoringConstImpl::removeEquivalenceClassFromTree( -- FunctionEntry *FE) { -- if (!isInEquivalenceClass(FE)) -- return; -- -- FnTreeType::iterator Iter = FE->TreeIter; -- FunctionEntry *Unlink = Iter->First; -- Unlink->NumUnhandledCallees = 0; -- while (Unlink) { -- LLVM_DEBUG(dbgs() << " remove from tree: " << Unlink->F->getName() -- << '\n'); -- if (!Unlink->IsMerged) -- Deferred.emplace_back(Unlink->F); -- Unlink->TreeIter = FnTree.end(); -- assert(Unlink->NumUnhandledCallees == 0); -- FunctionEntry *NextEntry = Unlink->Next; -- Unlink->Next = nullptr; -- Unlink = NextEntry; -- } -- FnTree.erase(Iter); --} -- --// Helper for writeThunk, --// Selects proper bitcast operation, --// but a bit simpler then CastInst::getCastOpcode. 
--Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) { -- Type *SrcTy = V->getType(); -- if (SrcTy->isStructTy()) { -- assert(DestTy->isStructTy()); -- assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements()); -- Value *Result = UndefValue::get(DestTy); -- for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) { -- Value *Element = -- createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)), -- DestTy->getStructElementType(I)); -- -- Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I)); -- } -- return Result; -- } -- assert(!DestTy->isStructTy()); -- if (CastArrayType) { -- if (auto *SrcAT = dyn_cast(SrcTy)) { -- auto *DestAT = dyn_cast(DestTy); -- assert(DestAT); -- assert(SrcAT->getNumElements() == DestAT->getNumElements()); -- Value *Result = UndefValue::get(DestTy); -- for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) { -- Value *Element = -- createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)), -- DestAT->getElementType()); -- -- Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I)); -- } -- return Result; -- } -- assert(!DestTy->isArrayTy()); -- } -- if (SrcTy->isIntegerTy() && DestTy->isPointerTy()) -- return Builder.CreateIntToPtr(V, DestTy); -- else if (SrcTy->isPointerTy() && DestTy->isIntegerTy()) -- return Builder.CreatePtrToInt(V, DestTy); -- else -- return Builder.CreateBitCast(V, DestTy); --} -- --/// Replace \p Thunk with a simple tail call to \p ToFunc. Also add parameters --/// to the call to \p ToFunc, which are defined by the FuncIdx's value in --/// \p Params. --void MergeFuncIgnoringConstImpl::writeThunk(Function *ToFunc, Function *Thunk, -- const ParamInfos &Params, -- unsigned FuncIdx) { -- // Delete the existing content of Thunk. -- Thunk->dropAllReferences(); -- -- BasicBlock *BB = BasicBlock::Create(Thunk->getContext(), "", Thunk); -- IRBuilder<> Builder(BB); -- -- SmallVector Args; -- unsigned ParamIdx = 0; -- FunctionType *ToFuncTy = ToFunc->getFunctionType(); -- -- // Add arguments which are passed through Thunk. -- for (Argument &AI : Thunk->args()) { -- Args.push_back(createCast(Builder, &AI, ToFuncTy->getParamType(ParamIdx))); -- ++ParamIdx; -- } -- // Add new arguments defined by Params. -- for (const ParamInfo &PI : Params) { -- assert(ParamIdx < ToFuncTy->getNumParams()); -- Constant *param = getSignedValue(PI, FuncIdx); -- Args.push_back( -- createCast(Builder, param, ToFuncTy->getParamType(ParamIdx))); -- ++ParamIdx; -- } -- -- CallInst *CI = Builder.CreateCall(ToFunc, Args); -- bool isSwiftTailCall = ToFunc->getCallingConv() == CallingConv::SwiftTail && -- Thunk->getCallingConv() == CallingConv::SwiftTail; -- CI->setTailCallKind(isSwiftTailCall ? 
llvm::CallInst::TCK_MustTail -- : llvm::CallInst::TCK_Tail); -- CI->setCallingConv(ToFunc->getCallingConv()); -- CI->setAttributes(ToFunc->getAttributes()); -- if (Thunk->getReturnType()->isVoidTy()) { -- Builder.CreateRetVoid(); -- } else { -- Builder.CreateRet(createCast(Builder, CI, Thunk->getReturnType())); -- } -- -- LLVM_DEBUG(dbgs() << " writeThunk: " << Thunk->getName() << '\n'); -- ++NumThunksWrittenIgnoringConst; --} -- --static llvm::AttributeList --fixUpTypesInByValAndStructRetAttributes(llvm::FunctionType *fnType, -- llvm::AttributeList attrList) { -- auto &context = fnType->getContext(); -- if (!context.supportsTypedPointers()) -- return attrList; -- -- for (unsigned i = 0; i < fnType->getNumParams(); ++i) { -- auto paramTy = fnType->getParamType(i); -- auto attrListIndex = llvm::AttributeList::FirstArgIndex + i; -- if (attrList.hasParamAttr(i, llvm::Attribute::StructRet) && -- paramTy->getNonOpaquePointerElementType() != -- attrList.getParamStructRetType(i)) -- attrList = attrList.replaceAttributeTypeAtIndex( -- context, attrListIndex, llvm::Attribute::StructRet, -- paramTy->getNonOpaquePointerElementType()); -- if (attrList.hasParamAttr(i, llvm::Attribute::ByVal) && -- paramTy->getNonOpaquePointerElementType() != -- attrList.getParamByValType(i)) -- attrList = attrList.replaceAttributeTypeAtIndex( -- context, attrListIndex, llvm::Attribute::ByVal, -- paramTy->getNonOpaquePointerElementType()); -- } -- return attrList; --} -- --/// Replace direct callers of Old with New. Also add parameters to the call to --/// \p New, which are defined by the FuncIdx's value in \p Params. --bool MergeFuncIgnoringConstImpl::replaceDirectCallers(Function *Old, -- Function *New, -- const ParamInfos &Params, -- unsigned FuncIdx) { -- bool AllReplaced = true; -- -- SmallVector Callers; -- -- for (Use &U : Old->uses()) { -- auto *I = dyn_cast(U.getUser()); -- if (!I) { -- AllReplaced = false; -- continue; -- } -- FunctionEntry *FE = getEntry(I->getFunction()); -- if (FE) -- removeEquivalenceClassFromTree(FE); -- -- auto *CI = dyn_cast(I); -- if (!CI || CI->getCalledOperand() != Old) { -- AllReplaced = false; -- continue; -- } -- Callers.push_back(CI); -- } -- if (!AllReplaced) -- return false; -- -- // When AlwaysCallThunk is true, return false so a thunk will be emitted, also -- // do not replace callsites. -- if (AlwaysCallThunk) -- return false; -- -- for (CallInst *CI : Callers) { -- auto &Context = New->getContext(); -- auto NewPAL = New->getAttributes(); -- -- SmallVector OldParamTypes; -- SmallVector NewArgs; -- SmallVector NewArgAttrs; -- IRBuilder<> Builder(CI); -- -- FunctionType *NewFuncTy = New->getFunctionType(); -- (void)NewFuncTy; -- unsigned ParamIdx = 0; -- -- // Add the existing parameters. -- for (Value *OldArg : CI->args()) { -- NewArgAttrs.push_back(NewPAL.getParamAttrs(ParamIdx)); -- NewArgs.push_back(OldArg); -- OldParamTypes.push_back(OldArg->getType()); -- ++ParamIdx; -- } -- // Add the new parameters. 
-- for (const ParamInfo &PI : Params) { -- assert(ParamIdx < NewFuncTy->getNumParams()); -- Constant *ArgValue = getSignedValue(PI, FuncIdx); -- assert(ArgValue != Old && "should not try to replace all callers of self " -- "referencing functions"); -- NewArgs.push_back(ArgValue); -- OldParamTypes.push_back(ArgValue->getType()); -- ++ParamIdx; -- } -- -- auto *FType = FunctionType::get(Old->getFunctionType()->getReturnType(), -- OldParamTypes, false); -- auto *FPtrType = PointerType::get( -- FType, cast(New->getType())->getAddressSpace()); -- -- Value *Callee = ConstantExpr::getBitCast(New, FPtrType); -- CallInst *NewCI; -- if (objcarc::hasAttachedCallOpBundle(CI)) { -- Value *BundleArgs[] = {*objcarc::getAttachedARCFunction(CI)}; -- OperandBundleDef OB("clang.arc.attachedcall", BundleArgs); -- NewCI = Builder.CreateCall(FType, Callee, NewArgs, {OB}); -- } else { -- NewCI = Builder.CreateCall(FType, Callee, NewArgs); -- } -- NewCI->setCallingConv(CI->getCallingConv()); -- // Don't transfer attributes from the function to the callee. Function -- // attributes typically aren't relevant to the calling convention or ABI. -- auto newAttrList = AttributeList::get(Context, /*FnAttrs=*/AttributeSet(), -- NewPAL.getRetAttrs(), NewArgAttrs); -- newAttrList = fixUpTypesInByValAndStructRetAttributes(FType, newAttrList); -- NewCI->setAttributes(newAttrList); -- if (IgnoreMusttailFunction && CI->isMustTailCall()) { -- // replace a callsite with musttail. -- llvm::errs() << "callsite has musttail in newF " << New->getName() -- << "\n"; -- } -- NewCI->copyMetadata(*CI); -- CI->replaceAllUsesWith(NewCI); -- CI->eraseFromParent(); -- } -- assert(Old->use_empty() && "should have replaced all uses of old function"); -- return Old->hasLocalLinkage(); --} -- --PreservedAnalyses MergeFuncIgnoringConstPass::run(Module &M, -- ModuleAnalysisManager &MAM) { -- if (MergeFuncIgnoringConstImpl(PtrAuthEnabled, PtrAuthKey, MergeFuncSuffix) -- .runImpl(M)) -- return PreservedAnalyses::none(); -- return PreservedAnalyses::all(); --} -diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt ---- a/llvm/lib/Transforms/Utils/CMakeLists.txt -+++ b/llvm/lib/Transforms/Utils/CMakeLists.txt -@@ -27,7 +27,6 @@ - FixIrreducible.cpp - FlattenCFG.cpp - FunctionComparator.cpp -- FunctionComparatorIgnoringConst.cpp - FunctionImportUtils.cpp - GlobalStatus.cpp - GuardUtils.cpp -diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Utils/FunctionComparatorIgnoringConst.cpp b/llvm/lib/Transforms/Utils/FunctionComparatorIgnoringConst.cpp ---- a/llvm/lib/Transforms/Utils/FunctionComparatorIgnoringConst.cpp -+++ b/llvm/lib/Transforms/Utils/FunctionComparatorIgnoringConst.cpp -@@ -1,107 +0,0 @@ --//===--- FunctionComparatorIgnoringConst.cpp - Function Comparator --------===// --// --// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. --// See https://llvm.org/LICENSE.txt for license information. 
--// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception --// --//===----------------------------------------------------------------------===// --// --//===----------------------------------------------------------------------===// -- --#include "llvm/Transforms/Utils/FunctionComparatorIgnoringConst.h" --#include "llvm/IR/Instructions.h" --#include "llvm/Transforms/Utils/MergeFunctionsIgnoringConst.h" -- --using namespace llvm; -- --int FunctionComparatorIgnoringConst::cmpOperandsIgnoringConsts( -- const Instruction *L, const Instruction *R, unsigned opIdx) { -- Value *OpL = L->getOperand(opIdx); -- Value *OpR = R->getOperand(opIdx); -- -- int Res = cmpValues(OpL, OpR); -- if (Res == 0) -- return Res; -- -- if (!isa(OpL) || !isa(OpR)) -- return Res; -- -- if (!isEligibleOperandForConstantSharing(L, opIdx) || -- !isEligibleOperandForConstantSharing(R, opIdx)) -- return Res; -- -- if (cmpTypes(OpL->getType(), OpR->getType())) -- return Res; -- -- return 0; --} -- --// Test whether two basic blocks have equivalent behavior. --int FunctionComparatorIgnoringConst::cmpBasicBlocksIgnoringConsts( -- const BasicBlock *BBL, const BasicBlock *BBR, -- const std::set> *InstOpndIndex) { -- BasicBlock::const_iterator InstL = BBL->begin(), InstLE = BBL->end(); -- BasicBlock::const_iterator InstR = BBR->begin(), InstRE = BBR->end(); -- -- do { -- bool needToCmpOperands = true; -- if (int Res = cmpOperations(&*InstL, &*InstR, needToCmpOperands)) -- return Res; -- if (needToCmpOperands) { -- assert(InstL->getNumOperands() == InstR->getNumOperands()); -- -- for (unsigned i = 0, e = InstL->getNumOperands(); i != e; ++i) { -- // When a set for (instruction, operand) index pairs is given, we only -- // ignore constants located at such indices. Otherwise, we precisely -- // compare the operands. -- if (InstOpndIndex && !InstOpndIndex->count(std::make_pair(Index, i))) { -- Value *OpL = InstL->getOperand(i); -- Value *OpR = InstR->getOperand(i); -- if (int Res = cmpValues(OpL, OpR)) -- return Res; -- } -- if (int Res = cmpOperandsIgnoringConsts(&*InstL, &*InstR, i)) -- return Res; -- // cmpValues should ensure this is true. -- assert(cmpTypes(InstL->getOperand(i)->getType(), -- InstR->getOperand(i)->getType()) == 0); -- } -- } -- ++Index; -- ++InstL, ++InstR; -- } while (InstL != InstLE && InstR != InstRE); -- -- if (InstL != InstLE && InstR == InstRE) -- return 1; -- if (InstL == InstLE && InstR != InstRE) -- return -1; -- return 0; --} -- --// Test whether the two functions have equivalent behavior. 
--int FunctionComparatorIgnoringConst::compareIgnoringConsts( -- const std::set> *InstOpndIndex) { -- beginCompare(); -- Index = 0; -- -- if (int Res = compareSignature()) -- return Res; -- -- Function::const_iterator LIter = FnL->begin(), LEnd = FnL->end(); -- Function::const_iterator RIter = FnR->begin(), REnd = FnR->end(); -- -- do { -- const BasicBlock *BBL = &*LIter; -- const BasicBlock *BBR = &*RIter; -- -- if (int Res = cmpValues(BBL, BBR)) -- return Res; -- -- if (int Res = cmpBasicBlocksIgnoringConsts(BBL, BBR, InstOpndIndex)) -- return Res; -- -- ++LIter, ++RIter; -- } while (LIter != LEnd && RIter != REnd); -- -- return 0; --} -diff -ruN --strip-trailing-cr a/llvm/test/Transforms/MergeFuncIgnoringConst/merge_func.ll b/llvm/test/Transforms/MergeFuncIgnoringConst/merge_func.ll ---- a/llvm/test/Transforms/MergeFuncIgnoringConst/merge_func.ll -+++ b/llvm/test/Transforms/MergeFuncIgnoringConst/merge_func.ll -@@ -1,532 +0,0 @@ --; RUN: opt -S -mergefunc-ignoringconst-threshold=4 -passes=mergefunc-ignoring-const %s | FileCheck %s -- --@g1 = external global i32 --@g2 = external global i32 --@g3 = external global i32 --@g4 = external global i32 --@g5 = external global i32 -- --; Test the most trivial example. -- --; CHECK-LABEL: define i32 @simple_func1(i32 %x, i32 %y) --; CHECK: %1 = tail call i32 @simple_func1.Tm(i32 %x, i32 %y, ptr @g1) --; CHECK: ret i32 %1 --define i32 @simple_func1(i32 %x, i32 %y) { -- %sum = add i32 %x, %y -- %sum2 = add i32 %sum, %y -- %l = load i32, i32* @g1, align 4 -- %sum3 = add i32 %sum2, %y -- ret i32 %sum3 --} -- --; CHECK-LABEL: define i32 @simple_func2(i32 %x, i32 %y) --; CHECK: %1 = tail call i32 @simple_func1.Tm(i32 %x, i32 %y, ptr @g2) --; CHECK: ret i32 %1 --define i32 @simple_func2(i32 %x, i32 %y) { -- %sum = add i32 %x, %y -- %sum2 = add i32 %sum, %y -- %l = load i32, i32* @g2, align 4 -- %sum3 = add i32 %sum2, %y -- ret i32 %sum3 --} -- --; CHECK-LABEL: define internal i32 @simple_func1.Tm(i32 %0, i32 %1, ptr %2) --; CHECK: %l = load i32, ptr %2 --; CHECK: ret -- -- --; Merge 3 functions with 3 types of differing instructions: load, store and call. 
-- --; CHECK-LABEL: define i32 @func1_of_3(i32 %x) --; CHECK: %1 = tail call i32 @func1_of_3.Tm(i32 %x, ptr @g1, ptr @g1, ptr @callee1) --; CHECK: ret i32 %1 --define i32 @func1_of_3(i32 %x) { -- %l1 = load i32, i32* @g1, align 4 -- %sum = add i32 %x, %l1 -- %l2 = load i32, i32* @g1, align 4 -- %sum2 = add i32 %sum, %l2 -- store i32 %sum2, i32 *@g1, align 4 -- call void @callee1(i32 %sum2) -- %sum3 = add i32 %sum2, %l2 -- ret i32 %sum3 --} -- --; CHECK-LABEL: define i32 @func2_of_3(i32 %x) --; CHECK: %1 = tail call i32 @func1_of_3.Tm(i32 %x, ptr @g2, ptr @g2, ptr @callee2) --; CHECK: ret i32 %1 --define i32 @func2_of_3(i32 %x) { -- %l1 = load i32, i32* @g2, align 4 -- %sum = add i32 %x, %l1 -- %l2 = load i32, i32* @g2, align 4 -- %sum2 = add i32 %sum, %l2 -- store i32 %sum2, i32 *@g2, align 4 -- call void @callee2(i32 %sum2) -- %sum3 = add i32 %sum2, %l2 -- ret i32 %sum3 --} -- --; CHECK-LABEL: define i32 @func3_of_3(i32 %x) --; CHECK: %1 = tail call i32 @func1_of_3.Tm(i32 %x, ptr @g3, ptr @g1, ptr @callee3) --; CHECK: ret i32 %1 --define i32 @func3_of_3(i32 %x) { -- %l1 = load i32, i32* @g3, align 4 -- %sum = add i32 %x, %l1 -- %l2 = load i32, i32* @g1, align 4 -- %sum2 = add i32 %sum, %l2 -- store i32 %sum2, i32 *@g3, align 4 -- call void @callee3(i32 %sum2) -- %sum3 = add i32 %sum2, %l2 -- ret i32 %sum3 --} -- --; CHECK-LABEL: define internal i32 @func1_of_3.Tm(i32 %0, ptr %1, ptr %2, ptr %3) --; CHECK: %l1 = load i32, ptr %1 --; CHECK: %l2 = load i32, ptr %2 --; CHECK: store i32 %sum2, ptr %1 --; CHECK: call void %3(i32 %sum2) --; CHECK: ret -- --declare void @callee1(i32 %x) --declare void @callee2(i32 %x) --declare void @callee3(i32 %x) -- --; Preserve attributes -- --; CHECK-LABEL: define void @sret_func1(ptr sret(i32) %p, i32 %x, i32 %y) --; CHECK: tail call void @sret_func1.Tm(ptr sret(i32) %p, i32 %x, i32 %y, ptr @g1) --; CHECK: ret void --define void @sret_func1(i32* sret(i32) %p, i32 %x, i32 %y) { -- %sum = add i32 %x, %y -- %l = load i32, i32* @g1, align 4 -- %sum2 = add i32 %sum, %l -- store i32 %sum2, i32* %p -- ret void --} -- --; CHECK-LABEL: define void @sret_func2(ptr sret(i32) %p, i32 %x, i32 %y) --; CHECK: tail call void @sret_func1.Tm(ptr sret(i32) %p, i32 %x, i32 %y, ptr @g2) --; CHECK: ret void --define void @sret_func2(i32* sret(i32) %p, i32 %x, i32 %y) { -- %sum = add i32 %x, %y -- %l = load i32, i32* @g2, align 4 -- %sum2 = add i32 %sum, %l -- store i32 %sum2, i32* %p -- ret void --} -- --; CHECK-LABEL: define internal void @sret_func1.Tm(ptr sret(i32) %0, i32 %1, i32 %2, ptr %3) --; CHECK: %l = load i32, ptr %3, align 4 --; CHECK: store i32 %sum2, ptr %0 --; CHECK: ret -- -- --; Don't merge all functions, because we would generate too many parameters. --; Instead merge those functions which match best. 
-- --; CHECK-LABEL: define i32 @func1_merged_with3(i32 %x) --; CHECK: %1 = tail call i32 @func1_merged_with3.Tm(i32 %x, ptr @g1) --; CHECK: ret i32 %1 --define i32 @func1_merged_with3(i32 %x) { -- %l1 = load i32, i32* @g1, align 4 -- %sum = add i32 %x, %l1 -- %l2 = load i32, i32* @g2, align 4 -- %sum2 = add i32 %sum, %l2 -- %l3 = load i32, i32* @g3, align 4 -- %sum3 = add i32 %sum2, %l2 -- %l4 = load i32, i32* @g4, align 4 -- %sum4 = add i32 %sum3, %l2 -- %l5 = load i32, i32* @g5, align 4 -- %sum5 = add i32 %sum4, %l2 -- ret i32 %sum5 --} -- --; CHECK-LABEL: define i32 @func2_merged_with4(i32 %x) --; CHECK: %1 = tail call i32 @func2_merged_with4.Tm(i32 %x, ptr @g2) --; CHECK: ret i32 %1 --define i32 @func2_merged_with4(i32 %x) { -- %l1 = load i32, i32* @g2, align 4 -- %sum = add i32 %x, %l1 -- %l2 = load i32, i32* @g3, align 4 -- %sum2 = add i32 %sum, %l2 -- %l3 = load i32, i32* @g4, align 4 -- %sum3 = add i32 %sum2, %l2 -- %l4 = load i32, i32* @g5, align 4 -- %sum4 = add i32 %sum3, %l2 -- %l5 = load i32, i32* @g1, align 4 -- %sum5 = add i32 %sum4, %l2 -- ret i32 %sum5 --} -- --; CHECK-LABEL: define i32 @func3_merged_with1(i32 %x) --; CHECK: %1 = tail call i32 @func1_merged_with3.Tm(i32 %x, ptr @g2) --; CHECK: ret i32 %1 --define i32 @func3_merged_with1(i32 %x) { -- %l1 = load i32, i32* @g2, align 4 -- %sum = add i32 %x, %l1 -- %l2 = load i32, i32* @g2, align 4 -- %sum2 = add i32 %sum, %l2 -- %l3 = load i32, i32* @g3, align 4 -- %sum3 = add i32 %sum2, %l2 -- %l4 = load i32, i32* @g4, align 4 -- %sum4 = add i32 %sum3, %l2 -- %l5 = load i32, i32* @g5, align 4 -- %sum5 = add i32 %sum4, %l2 -- ret i32 %sum5 --} -- --; CHECK-LABEL: define internal i32 @func1_merged_with3.Tm(i32 %0, ptr %1) --; CHECK: load i32, ptr %1, align 4 --; CHECK: load i32, ptr @g2, align 4 --; CHECK: load i32, ptr @g3, align 4 --; CHECK: load i32, ptr @g4, align 4 --; CHECK: load i32, ptr @g5, align 4 --; CHECK: ret i32 -- --; CHECK-LABEL: define i32 @func4_merged_with2(i32 %x) { --; CHECK: %1 = tail call i32 @func2_merged_with4.Tm(i32 %x, ptr @g1) --; CHECK: ret i32 %1 --define i32 @func4_merged_with2(i32 %x) { -- %l1 = load i32, i32* @g1, align 4 -- %sum = add i32 %x, %l1 -- %l2 = load i32, i32* @g3, align 4 -- %sum2 = add i32 %sum, %l2 -- %l3 = load i32, i32* @g4, align 4 -- %sum3 = add i32 %sum2, %l2 -- %l4 = load i32, i32* @g5, align 4 -- %sum4 = add i32 %sum3, %l2 -- %l5 = load i32, i32* @g1, align 4 -- %sum5 = add i32 %sum4, %l2 -- ret i32 %sum5 --} -- -- --; The same example as above, but we cannot merge func2 with func4, because --; func4 calls func1 (which is merged with func2 in the first iteration). 
-- --declare i32 @get_int(i32 %x) -- --; CHECK-LABEL: define i32 @Function1_merged_with_3(i32 %x) --; CHECK: %1 = tail call i32 @Function1_merged_with_3.Tm(i32 %x, ptr @g1) --; CHECK: ret i32 %1 --define i32 @Function1_merged_with_3(i32 %x) { -- %l1 = load i32, i32* @g1, align 4 -- %sum = add i32 %x, %l1 -- %l2 = load i32, i32* @g2, align 4 -- %sum2 = add i32 %sum, %l2 -- %l3 = load i32, i32* @g3, align 4 -- %sum3 = add i32 %sum2, %l2 -- %l4 = load i32, i32* @g4, align 4 -- %sum4 = add i32 %sum3, %l2 -- %l5 = load i32, i32* @g5, align 4 -- %sum5 = add i32 %sum4, %l2 -- %c = call fastcc i32 @get_int(i32 %sum5) -- ret i32 %c --} -- --; CHECK-LABEL: define i32 @Function2_not_merged(i32 %x) --; CHECK: load --; CHECK: load --; CHECK: load --; CHECK: load --; CHECK: %c = call fastcc i32 @get_int --; CHECK: ret i32 %c --define i32 @Function2_not_merged(i32 %x) { -- %l1 = load i32, i32* @g2, align 4 -- %sum = add i32 %x, %l1 -- %l2 = load i32, i32* @g3, align 4 -- %sum2 = add i32 %sum, %l2 -- %l3 = load i32, i32* @g4, align 4 -- %sum3 = add i32 %sum2, %l2 -- %l4 = load i32, i32* @g5, align 4 -- %sum4 = add i32 %sum3, %l2 -- %l5 = load i32, i32* @g1, align 4 -- %sum5 = add i32 %sum4, %l2 -- %c = call fastcc i32 @get_int(i32 %sum5) -- ret i32 %c --} -- --; CHECK-LABEL: define i32 @Function3_merged_with_1(i32 %x) --; CHECK: %1 = tail call i32 @Function1_merged_with_3.Tm(i32 %x, ptr @g2) --; CHECK: ret i32 %1 --define i32 @Function3_merged_with_1(i32 %x) { -- %l1 = load i32, i32* @g2, align 4 -- %sum = add i32 %x, %l1 -- %l2 = load i32, i32* @g2, align 4 -- %sum2 = add i32 %sum, %l2 -- %l3 = load i32, i32* @g3, align 4 -- %sum3 = add i32 %sum2, %l2 -- %l4 = load i32, i32* @g4, align 4 -- %sum4 = add i32 %sum3, %l2 -- %l5 = load i32, i32* @g5, align 4 -- %sum5 = add i32 %sum4, %l2 -- %c = call fastcc i32 @get_int(i32 %sum5) -- ret i32 %c --} -- --; CHECK-LABEL: define internal i32 @Function1_merged_with_3.Tm(i32 %0, ptr %1) --; CHECK: load --; CHECK: load --; CHECK: load --; CHECK: load --; CHECK: %c = call fastcc i32 @get_int --; CHECK: ret i32 %c -- --; CHECK-LABEL: define i32 @Function4_not_merged(i32 %x) { --; CHECK: load --; CHECK: load --; CHECK: load --; CHECK: load --; CHECK: %1 = call fastcc i32 @Function1_merged_with_3.Tm(i32 %sum5, ptr @g1) --; CHECK: ret i32 %1 --define i32 @Function4_not_merged(i32 %x) { -- %l1 = load i32, i32* @g1, align 4 -- %sum = add i32 %x, %l1 -- %l2 = load i32, i32* @g3, align 4 -- %sum2 = add i32 %sum, %l2 -- %l3 = load i32, i32* @g4, align 4 -- %sum3 = add i32 %sum2, %l2 -- %l4 = load i32, i32* @g5, align 4 -- %sum4 = add i32 %sum3, %l2 -- %l5 = load i32, i32* @g1, align 4 -- %sum5 = add i32 %sum4, %l2 -- %c = call fastcc i32 @Function1_merged_with_3(i32 %sum5) -- ret i32 %c --} -- -- --; Test a call chain: caller -> callee1 -> callee2. --; Functions should be merged in bottom-up order: callee2, callee1, caller. --; Also check that the calling convention is preserved. 
-- --; CHECK-LABEL: define fastcc i32 @callee1_a(i32 %x, i32 %y) --; CHECK: %1 = tail call fastcc i32 @callee1_a.Tm(i32 %x, i32 %y, ptr @g1) --; CHECK: ret i32 %1 --define fastcc i32 @callee1_a(i32 %x, i32 %y) { -- %sum = add i32 %x, %y -- %sum2 = add i32 %sum, %y -- %c = call i32 @callee2_a(i32 %sum2, i32 %y) -- %sum3 = add i32 %sum2, %c -- ret i32 %sum3 --} -- --; CHECK-LABEL: define fastcc i32 @callee1_b(i32 %x, i32 %y) --; CHECK: %1 = tail call fastcc i32 @callee1_a.Tm(i32 %x, i32 %y, ptr @g2) --; CHECK: ret i32 %1 --define fastcc i32 @callee1_b(i32 %x, i32 %y) { -- %sum = add i32 %x, %y -- %sum2 = add i32 %sum, %y -- %c = call i32 @callee2_b(i32 %sum2, i32 %y) -- %sum3 = add i32 %sum2, %c -- ret i32 %sum3 --} -- --; CHECK-LABEL: define internal fastcc i32 @callee1_a.Tm(i32 %0, i32 %1, ptr %2) --; CHECK: call i32 @callee2_a.Tm(i32 %sum2, i32 %1, ptr %2) --; CHECK: ret -- --; CHECK-NOT: @callee2_a( --define internal i32 @callee2_a(i32 %x, i32 %y) { -- %sum = add i32 %x, %y -- %sum2 = sub i32 %sum, %y -- %l = load i32, i32* @g1, align 4 -- %sum3 = add i32 %sum2, %y -- ret i32 %sum3 --} -- --; CHECK-NOT: @callee2_b( --define internal i32 @callee2_b(i32 %x, i32 %y) { -- %sum = add i32 %x, %y -- %sum2 = sub i32 %sum, %y -- %l = load i32, i32* @g2, align 4 -- %sum3 = add i32 %sum2, %y -- ret i32 %sum3 --} -- --; CHECK-LABEL: define i32 @caller_a(i32 %x, i32 %y) --; CHECK: %1 = tail call i32 @caller_a.Tm(i32 %x, i32 %y, ptr @g1) --; CHECK: ret i32 %1 --define i32 @caller_a(i32 %x, i32 %y) { -- %sum = add i32 %x, %y -- %sum2 = add i32 %sum, %y -- %c = call fastcc i32 @callee1_a(i32 %sum2, i32 %y) -- %sum3 = add i32 %sum2, %c -- ret i32 %sum3 --} -- --; CHECK-LABEL: define i32 @caller_b(i32 %x, i32 %y) --; CHECK: %1 = tail call i32 @caller_a.Tm(i32 %x, i32 %y, ptr @g2) --; CHECK: ret i32 %1 --define i32 @caller_b(i32 %x, i32 %y) { -- %sum = add i32 %x, %y -- %sum2 = add i32 %sum, %y -- %c = call fastcc i32 @callee1_b(i32 %sum2, i32 %y) -- %sum3 = add i32 %sum2, %c -- ret i32 %sum3 --} -- --; CHECK-LABEL: define internal i32 @caller_a.Tm(i32 %0, i32 %1, ptr %2) --; CHECK: call fastcc i32 @callee1_a.Tm(i32 %sum2, i32 %1, ptr %2) --; CHECK: ret -- -- --; Ensure that we do not merge functions that are identical with the --; exception of the order of the incoming blocks to a phi. 
-- --; CHECK-LABEL: define linkonce_odr hidden i1 @first(i2 %0) --define linkonce_odr hidden i1 @first(i2) { --entry: --; CHECK: switch i2 -- switch i2 %0, label %default [ -- i2 0, label %L1 -- i2 1, label %L2 -- i2 -2, label %L3 -- ] --default: -- unreachable --L1: -- br label %done --L2: -- br label %done --L3: -- br label %done --done: -- %result = phi i1 [ true, %L1 ], [ false, %L2 ], [ false, %L3 ] --; CHECK: ret i1 -- ret i1 %result --} -- --; CHECK-LABEL: define linkonce_odr hidden i1 @second(i2 %0) --define linkonce_odr hidden i1 @second(i2) { --entry: --; CHECK: switch i2 -- switch i2 %0, label %default [ -- i2 0, label %L1 -- i2 1, label %L2 -- i2 -2, label %L3 -- ] --default: -- unreachable --L1: -- br label %done --L2: -- br label %done --L3: -- br label %done --done: -- %result = phi i1 [ true, %L3 ], [ false, %L2 ], [ false, %L1 ] --; CHECK: ret i1 -- ret i1 %result --} -- --; Check self recursive functions -- --; CHECK-LABEL: define internal void @recursive1(i32 %x, i32 %y) --; CHECK: tail call void @recursive1.Tm(i32 %x, i32 %y, ptr @g1, ptr @recursive1) --; CHECK: ret void --define internal void @recursive1(i32 %x, i32 %y) { -- br i1 undef, label %bb1, label %bb2 -- --bb1: -- %l = load i32, i32* @g1, align 4 -- call void @recursive1(i32 %x, i32 %y) -- br label %bb2 -- --bb2: -- ret void --} -- --; CHECK-LABEL: define internal void @recursive2(i32 %x, i32 %y) --; CHECK: tail call void @recursive1.Tm(i32 %x, i32 %y, ptr @g2, ptr @recursive2) --; CHECK: ret void --define internal void @recursive2(i32 %x, i32 %y) { -- br i1 undef, label %bb1, label %bb2 -- --bb1: -- %l = load i32, i32* @g2, align 4 -- call void @recursive2(i32 %x, i32 %y) -- br label %bb2 -- --bb2: -- ret void --} --; CHECK-LABEL: define internal void @recursive1.Tm(i32 %0, i32 %1, ptr %2, ptr %3) --; CHECK: load i32, ptr %2 --; CHECK: call void %3(i32 %0, i32 %1) --; CHECK: ret void -- -- --; CHECK-LABEL: define internal void @another_recursive_func(i32 %x) --; CHECK: tail call void @another_recursive_func.Tm(i32 %x, ptr @g1, ptr @another_recursive_func) --; CHECK: ret void --define internal void @another_recursive_func(i32 %x) { -- br i1 undef, label %bb1, label %bb2 -- --bb1: -- store i32 %x, i32 *@g1, align 4 -- call void @another_recursive_func(i32 %x) -- br label %bb2 -- --bb2: -- ret void --} --; CHECK-NOT: @not_really_recursive( -- --; CHECK-LABEL: define internal void @another_recursive_func.Tm(i32 %0, ptr %1, ptr %2) --; CHECK: store i32 %0, ptr %1 --; CHECK: call void %2(i32 %0) --; CHECK: ret void --define internal void @not_really_recursive(i32 %x) { -- br i1 undef, label %bb1, label %bb2 -- --bb1: -- store i32 %x, i32 *@g2, align 4 -- call void @callee1(i32 %x) -- br label %bb2 -- --bb2: -- ret void --} --; CHECK-NOT: @not_really_recursive( -- --; CHECK-LABEL: define void @call_recursive_funcs(i32 %x) --; CHECK: call void @recursive1(i32 %x, i32 %x) --; CHECK: call void @recursive2(i32 %x, i32 %x) --; CHECK: call void @another_recursive_func(i32 %x) --; CHECK: call void @another_recursive_func.Tm(i32 %x, ptr @g2, ptr @callee1) --; CHECK: ret void --define void @call_recursive_funcs(i32 %x) { -- call void @recursive1(i32 %x, i32 %x) -- call void @recursive2(i32 %x, i32 %x) -- call void @another_recursive_func(i32 %x) -- call void @not_really_recursive(i32 %x) -- ret void --} -- --; Ensure that we do not merge functions which make use of distinct dtrace --; probes. Each call to a dtrace probe must resolve to a unique patchpoint. 
-- --declare void @"__dtrace_probe$Apple$Probe1$v1$696e74"(i32) local_unnamed_addr -- --; CHECK-LABEL: define i32 @use_dtrace_probe1 --; CHECK: call void @"__dtrace_probe$Apple$Probe1$v1$696e74" --define i32 @use_dtrace_probe1(i32 %x, i32 %y) { -- %sum = add i32 %x, %y -- %sum2 = add i32 %sum, %y -- %l = load i32, i32* @g1, align 4 -- %sum3 = add i32 %sum2, %y -- tail call void @"__dtrace_probe$Apple$Probe1$v1$696e74"(i32 undef) -- ret i32 %sum3 --} -- --declare void @"__dtrace_probe$Apple$Probe2$v1$696e74"(i32) local_unnamed_addr -- --; CHECK-LABEL: define i32 @use_dtrace_probe2 --; CHECK: call void @"__dtrace_probe$Apple$Probe2$v1$696e74" --define i32 @use_dtrace_probe2(i32 %x, i32 %y) { -- %sum = add i32 %x, %y -- %sum2 = add i32 %sum, %y -- %l = load i32, i32* @g2, align 4 -- %sum3 = add i32 %sum2, %y -- tail call void @"__dtrace_probe$Apple$Probe2$v1$696e74"(i32 undef) -- ret i32 %sum3 --} -diff -ruN --strip-trailing-cr a/llvm/test/Transforms/MergeFuncIgnoringConst/merge_with_exception.ll b/llvm/test/Transforms/MergeFuncIgnoringConst/merge_with_exception.ll ---- a/llvm/test/Transforms/MergeFuncIgnoringConst/merge_with_exception.ll -+++ b/llvm/test/Transforms/MergeFuncIgnoringConst/merge_with_exception.ll -@@ -1,190 +0,0 @@ --; RUN: opt -S -enable-aggressive-mergefunc-ignoringconst -passes=mergefunc-ignoring-const %s -o - | FileCheck %s -- --%4 = type opaque --%10 = type opaque --%"struct.SearchSpec::State" = type { %4* } --%"struct.PointerList" = type { i8*, i8*, i8*, i8*, i8* } --%"struct.DynamicCallback" = type { %10* } -- --; CHECK: define ptr @invoke_foo(ptr nocapture readonly %.block_descriptor, ptr %stateWrapper) --; CHECK: %1 = {{.*}}call ptr @invoke_foo.Tm --; CHECK: define ptr @invoke_bar(ptr nocapture readonly %.block_descriptor, ptr %stateWrapper) { --; CHECK: %1 = {{.*}}call ptr @invoke_foo.Tm --; CHECK: define {{.*}}.Tm(ptr nocapture readonly %0, ptr %1, ptr %2, ptr %3) -- --; Function Attrs: minsize optsize ssp uwtable --define i8* @invoke_foo(i8* nocapture readonly %.block_descriptor, i8* %stateWrapper) #1 personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) { --entry: -- %state = alloca %"struct.SearchSpec::State", align 8 -- %agg.tmp = alloca %"struct.PointerList", align 8 -- %0 = tail call i8* @llvm.objc.retain(i8* %stateWrapper) #2 -- %1 = bitcast %"struct.SearchSpec::State"* %state to i8* -- call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %1) #2 -- %2 = getelementptr inbounds i8, i8* %stateWrapper, i64 16 -- %3 = bitcast i8* %2 to %"struct.SearchSpec::State"* (i8*)** -- %4 = load %"struct.SearchSpec::State"* (i8*)*, %"struct.SearchSpec::State"* (i8*)** %3, align 8 -- %call.i4 = invoke nonnull align 8 dereferenceable(8) %"struct.SearchSpec::State"* %4(i8* nonnull %stateWrapper) #31 -- to label %invoke.cont unwind label %lpad -- --invoke.cont: ; preds = %entry -- %initialText.i.i = getelementptr inbounds %"struct.SearchSpec::State", %"struct.SearchSpec::State"* %state, i64 0, i32 0 -- %initialText2.i.i = getelementptr inbounds %"struct.SearchSpec::State", %"struct.SearchSpec::State"* %call.i4, i64 0, i32 0 -- %5 = load %4*, %4** %initialText2.i.i, align 8 -- %6 = bitcast %4* %5 to i8* -- %7 = tail call i8* @llvm.objc.retain(i8* %6) #2 -- store %4* %5, %4** %initialText.i.i, align 8 -- %block.capture.addr = getelementptr inbounds i8, i8* %.block_descriptor, i64 32 -- %8 = bitcast i8* %block.capture.addr to i8** -- %9 = load i8*, i8** %8, align 8 -- invoke void @callee2(%"struct.PointerList"* nonnull sret(%"struct.PointerList") align 8 %agg.tmp, i8* 
%9, i1 zeroext false) #31 -- to label %invoke.cont2 unwind label %lpad1 -- --invoke.cont2: ; preds = %invoke.cont -- %block.capture.addr3 = getelementptr inbounds i8, i8* %.block_descriptor, i64 40 -- %10 = bitcast i8* %block.capture.addr3 to %4** -- %agg.tmp6.sroa.3.0..sroa_idx12 = getelementptr inbounds %"struct.PointerList", %"struct.PointerList"* %agg.tmp, i64 0, i32 3 -- %agg.tmp6.sroa.3.0.copyload = load i8*, i8** %agg.tmp6.sroa.3.0..sroa_idx12, align 8 -- %11 = load %4*, %4** %10, align 8 -- invoke void @callee1(%"struct.SearchSpec::State"* nonnull align 8 dereferenceable(8) %state, %4* %11) #31 -- to label %invoke.cont4 unwind label %lpad.i -- --lpad.i: ; preds = %invoke.cont2 -- %12 = landingpad { i8*, i32 } -- cleanup -- call void @llvm.objc.release(i8* %agg.tmp6.sroa.3.0.copyload) #2 -- %.phi.trans.insert = bitcast %"struct.SearchSpec::State"* %state to i8** -- %.pre = load i8*, i8** %.phi.trans.insert, align 8 -- br label %lpad1.body -- --invoke.cont4: ; preds = %invoke.cont2 -- call void @llvm.objc.release(i8* %agg.tmp6.sroa.3.0.copyload) #2 -- %13 = load %4*, %4** %initialText.i.i, align 8 -- store %4* null, %4** %initialText.i.i, align 8 -- %call78 = call fastcc i8* @callee3(%4* %13) #31 [ "clang.arc.attachedcall"(i8* (i8*)* @llvm.objc.retainAutoreleasedReturnValue) ] -- call void (...) @llvm.objc.clang.arc.noop.use(i8* %call78) #2 -- %14 = bitcast %"struct.SearchSpec::State"* %state to i8** -- %15 = load i8*, i8** %14, align 8 -- call void @llvm.objc.release(i8* %15) #2 -- call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %1) #2 -- call void @llvm.objc.release(i8* nonnull %stateWrapper) #2, !clang.imprecise_release !1 -- %16 = tail call i8* @llvm.objc.autoreleaseReturnValue(i8* %call78) #2 -- ret i8* %call78 -- --lpad: ; preds = %entry -- %17 = landingpad { i8*, i32 } -- cleanup -- br label %ehcleanup -- --lpad1: ; preds = %invoke.cont -- %18 = landingpad { i8*, i32 } -- cleanup -- br label %lpad1.body -- --lpad1.body: ; preds = %lpad1, %lpad.i -- %19 = phi i8* [ %6, %lpad1 ], [ %.pre, %lpad.i ] -- %eh.lpad-body = phi { i8*, i32 } [ %18, %lpad1 ], [ %12, %lpad.i ] -- call void @llvm.objc.release(i8* %19) #2 -- br label %ehcleanup -- --ehcleanup: ; preds = %lpad1.body, %lpad -- %.pn = phi { i8*, i32 } [ %eh.lpad-body, %lpad1.body ], [ %17, %lpad ] -- call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %1) #2 -- call void @llvm.objc.release(i8* nonnull %stateWrapper) #2, !clang.imprecise_release !1 -- resume { i8*, i32 } %.pn --} -- --; Function Attrs: minsize optsize ssp uwtable --define i8* @invoke_bar(i8* nocapture readonly %.block_descriptor, i8* %stateWrapper) #1 personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) { --entry: -- %state = alloca %"struct.DynamicCallback", align 8 -- %agg.tmp = alloca %"struct.PointerList", align 8 -- %0 = tail call i8* @llvm.objc.retain(i8* %stateWrapper) #2 -- %1 = bitcast %"struct.DynamicCallback"* %state to i8* -- call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %1) #2 -- %2 = getelementptr inbounds i8, i8* %stateWrapper, i64 16 -- %3 = bitcast i8* %2 to %"struct.DynamicCallback"* (i8*)** -- %4 = load %"struct.DynamicCallback"* (i8*)*, %"struct.DynamicCallback"* (i8*)** %3, align 8 -- %call.i4 = invoke nonnull align 8 dereferenceable(8) %"struct.DynamicCallback"* %4(i8* nonnull %stateWrapper) #31 -- to label %invoke.cont unwind label %lpad -- --invoke.cont: ; preds = %entry -- %call.i.i = getelementptr inbounds %"struct.DynamicCallback", %"struct.DynamicCallback"* %state, i64 0, i32 0 -- %call2.i.i = 
getelementptr inbounds %"struct.DynamicCallback", %"struct.DynamicCallback"* %call.i4, i64 0, i32 0 -- %5 = load %10*, %10** %call2.i.i, align 8 -- %6 = bitcast %10* %5 to i8* -- %7 = tail call i8* @llvm.objc.retain(i8* %6) #2 -- store %10* %5, %10** %call.i.i, align 8 -- %block.capture.addr = getelementptr inbounds i8, i8* %.block_descriptor, i64 32 -- %8 = bitcast i8* %block.capture.addr to i8** -- %9 = load i8*, i8** %8, align 8 -- invoke void @callee2(%"struct.PointerList"* nonnull sret(%"struct.PointerList") align 8 %agg.tmp, i8* %9, i1 zeroext false) #31 -- to label %invoke.cont2 unwind label %lpad1 -- --invoke.cont2: ; preds = %invoke.cont -- %block.capture.addr3 = getelementptr inbounds i8, i8* %.block_descriptor, i64 40 -- %10 = bitcast i8* %block.capture.addr3 to %10** -- %agg.tmp6.sroa.3.0..sroa_idx12 = getelementptr inbounds %"struct.PointerList", %"struct.PointerList"* %agg.tmp, i64 0, i32 3 -- %agg.tmp6.sroa.3.0.copyload = load i8*, i8** %agg.tmp6.sroa.3.0..sroa_idx12, align 8 -- %11 = load %10*, %10** %10, align 8 -- invoke void @callee5(%"struct.DynamicCallback"* nonnull align 8 dereferenceable(8) %state, %10* %11) #31 -- to label %invoke.cont4 unwind label %lpad.i -- --lpad.i: ; preds = %invoke.cont2 -- %12 = landingpad { i8*, i32 } -- cleanup -- call void @llvm.objc.release(i8* %agg.tmp6.sroa.3.0.copyload) #2 -- %.phi.trans.insert = bitcast %"struct.DynamicCallback"* %state to i8** -- %.pre = load i8*, i8** %.phi.trans.insert, align 8 -- br label %lpad1.body -- --invoke.cont4: ; preds = %invoke.cont2 -- call void @llvm.objc.release(i8* %agg.tmp6.sroa.3.0.copyload) #2 -- %13 = load %10*, %10** %call.i.i, align 8 -- store %10* null, %10** %call.i.i, align 8 -- %call78 = call fastcc i8* @callee4(%10* %13) #31 [ "clang.arc.attachedcall"(i8* (i8*)* @llvm.objc.retainAutoreleasedReturnValue) ] -- call void (...) @llvm.objc.clang.arc.noop.use(i8* %call78) #2 -- %14 = bitcast %"struct.DynamicCallback"* %state to i8** -- %15 = load i8*, i8** %14, align 8 -- call void @llvm.objc.release(i8* %15) #2 -- call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %1) #2 -- call void @llvm.objc.release(i8* nonnull %stateWrapper) #2, !clang.imprecise_release !1 -- %16 = tail call i8* @llvm.objc.autoreleaseReturnValue(i8* %call78) #2 -- ret i8* %call78 -- --lpad: ; preds = %entry -- %17 = landingpad { i8*, i32 } -- cleanup -- br label %ehcleanup -- --lpad1: ; preds = %invoke.cont -- %18 = landingpad { i8*, i32 } -- cleanup -- br label %lpad1.body -- --lpad1.body: ; preds = %lpad1, %lpad.i -- %19 = phi i8* [ %6, %lpad1 ], [ %.pre, %lpad.i ] -- %eh.lpad-body = phi { i8*, i32 } [ %18, %lpad1 ], [ %12, %lpad.i ] -- call void @llvm.objc.release(i8* %19) #2 -- br label %ehcleanup -- --ehcleanup: ; preds = %lpad1.body, %lpad -- %.pn = phi { i8*, i32 } [ %eh.lpad-body, %lpad1.body ], [ %17, %lpad ] -- call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %1) #2 -- call void @llvm.objc.release(i8* nonnull %stateWrapper) #2, !clang.imprecise_release !1 -- resume { i8*, i32 } %.pn --} --declare void @callee1(%"struct.SearchSpec::State"* nonnull align 8 dereferenceable(8), %4*) --declare void @callee2(%"struct.PointerList"* sret(%"struct.PointerList") align 8, i8*, i1 zeroext) --declare i8* @callee3(%4* %state.coerce) --declare i8* @callee4(%10* %state.coerce) --declare void @callee5(%"struct.DynamicCallback"* nonnull align 8 dereferenceable(8), %10*) --declare i32 @__gxx_personality_v0(...) 
--declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) --declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) --declare i8* @llvm.objc.autoreleaseReturnValue(i8*) --declare void @llvm.objc.clang.arc.noop.use(...) --declare void @llvm.objc.release(i8*) --declare i8* @llvm.objc.retain(i8*) --declare i8* @llvm.objc.retainAutoreleasedReturnValue(i8*) -- --!1 = !{} diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 967a6224edd455..5015e65c2d7640 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "bcb685e11945946335c2dc6265779f0226491b49" - LLVM_SHA256 = "dbeb744a9656b7e7035b350ea6b2d303db26da8da000bc85a13f517c5a13195b" + LLVM_COMMIT = "67d7903262ce5c35bb23d599040dff29b9d7759e" + LLVM_SHA256 = "49062de6219c30871d4dd11c047ed1d70783c345adaacca2199c0830849b06a7" tf_http_archive( name = name, diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch old mode 100644 new mode 100755 index be1c1f0838e9d7..a476720fd2dbd6 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,39 +1,14 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel --- stablehlo/BUILD.bazel +++ stablehlo/BUILD.bazel -@@ -279,6 +279,24 @@ - ) - - cc_library( -+ name = "experimental_ops", -+ srcs = [ -+ "stablehlo/dialect/ExperimentalOps.cpp", -+ ], -+ hdrs = [ -+ "stablehlo/dialect/ExperimentalOps.h", -+ ], -+ strip_include_prefix = ".", -+ deps = [ -+ ":stablehlo_ops", -+ "@llvm-project//llvm:Support", -+ "@llvm-project//mlir:FuncDialect", -+ "@llvm-project//mlir:IR", -+ "@llvm-project//mlir:Support", -+ ], -+) -+ -+cc_library( - name = "interpreter_ops", - srcs = [ - "stablehlo/reference/InterpreterOps.cpp", -@@ -780,6 +798,7 @@ +@@ -890,6 +890,7 @@ + hdrs = [ + "stablehlo/transforms/MapStablehloToVhlo.h", + "stablehlo/transforms/Passes.h", ++ "stablehlo/transforms/StablehloRefineShapes.h", + ], + strip_include_prefix = ".", deps = [ - ":base", - ":chlo_ops", -+ ":experimental_ops", - ":stablehlo_ops", - ":stablehlo_ops_inc_gen", - ":stablehlo_pass_inc_gen", diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt --- stablehlo/CMakeLists.txt +++ stablehlo/CMakeLists.txt @@ -181,32 +156,198 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt #------------------------------------------------------------------------------- # Directory setup -diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir ---- stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir -+++ stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir -@@ -19,6 +19,7 @@ - func.func @iota_dimension_0() -> tensor<4x8xf32> { - // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() - // CHECK-SAME{LITERAL}: <{value = dense<[[0.000000e+00], [1.000000e+00], [2.000000e+00], [3.000000e+00]]> : tensor<4x1xf32>}> -+ // CHECK-DAG: %[[VAR1:.*]] = tosa.tile %[[VAR0]] {multiples = array} - %0 = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> (tensor<4x8xf32>) - return %0 : tensor<4x8xf32> - } -@@ -27,6 +28,7 @@ - func.func @iota_dimension_1() -> tensor<4x8xi32> { - // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() - // CHECK-SAME{LITERAL}: <{value = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi32>}> -+ // CHECK-DAG: %[[VAR1:.*]] = tosa.tile %[[VAR0]] {multiples = array} - %0 = "stablehlo.iota"() {iota_dimension = 1 : 
i64} : () -> (tensor<4x8xi32>) - return %0 : tensor<4x8xi32> - } -diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp ---- stablehlo/stablehlo/dialect/Base.cpp -+++ stablehlo/stablehlo/dialect/Base.cpp -@@ -600,5 +600,18 @@ - return UnrankedTensorType::get(components.getElementType()); - } +diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt +--- stablehlo/stablehlo/CMakeLists.txt ++++ stablehlo/stablehlo/CMakeLists.txt +@@ -15,6 +15,7 @@ + add_subdirectory(api) + add_subdirectory(conversions) + add_subdirectory(dialect) ++add_subdirectory(experimental) + add_subdirectory(integrations) + add_subdirectory(reference) + add_subdirectory(tests) +diff --ruN a/stablehlo/stablehlo/api/PortableApi.h b/stablehlo/stablehlo/api/PortableApi.h +--- stablehlo/stablehlo/api/PortableApi.h ++++ stablehlo/stablehlo/api/PortableApi.h +@@ -27,7 +27,8 @@ + /// Return the current version for portable API. + /// Increments on all meaningful changes to this file. +-inline int64_t getApiVersion() { return 4; } ++/// Or on large breaking source changes that are difficult to integrate. ++inline int64_t getApiVersion() { return 5; } + + // Get the current StableHLO version. + // +diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel +--- stablehlo/stablehlo/experimental/BUILD.bazel ++++ stablehlo/stablehlo/experimental/BUILD.bazel +@@ -0,0 +1,114 @@ ++# Copyright 2023 The StableHLO Authors. All Rights Reserved. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") ++ ++package( ++ default_visibility = ["//visibility:public"], ++ licenses = ["notice"], ++) ++ ++cc_library( ++ name = "experimental_base", ++ srcs = [ ++ "dialect/Base.cpp", ++ ], ++ hdrs = [ ++ "dialect/Base.h", ++ ], ++ deps = [ ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ ], ++) ++ ++cc_library( ++ name = "experimental_stablehlo_ops", ++ srcs = [ ++ "dialect/StablehloOps.cpp", ++ ], ++ hdrs = [ ++ "dialect/StablehloOps.h", ++ ], ++ deps = [ ++ ":experimental_base", ++ "//:stablehlo_ops", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "experimental_stablehlo_pass_inc_gen", ++ tbl_outs = [ ++ ( ++ [ ++ "-gen-pass-decls", ++ ], ++ "transforms/Passes.h.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "transforms/Passes.td", ++ deps = ["@llvm-project//mlir:PassBaseTdFiles"], ++) ++ ++cc_library( ++ name = "experimental_stablehlo_passes", ++ srcs = [ ++ "transforms/StablehloCanonicalizeDynamism.cpp", ++ "transforms/StablehloRefineShapes.cpp", ++ ], ++ hdrs = [ ++ "transforms/Passes.h", ++ ], ++ deps = [ ++ ":experimental_stablehlo_ops", ++ ":experimental_stablehlo_pass_inc_gen", ++ "//:base", ++ "//:chlo_ops", ++ "//:stablehlo_ops", ++ "//:stablehlo_ops_inc_gen", ++ "//:stablehlo_passes", ++ "//:stablehlo_type_inference", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:InferTypeOpInterface", ++ "@llvm-project//mlir:Pass", ++ "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:TransformUtils", ++ "@llvm-project//mlir:Transforms", ++ ], ++) ++ ++cc_binary( ++ name = "experimental-stablehlo-opt", ++ srcs = [ ++ "tools/StablehloOptMain.cpp", ++ ], ++ deps = [ ++ ":experimental_stablehlo_passes", ++ "//:interpreter_ops", ++ "//:register", ++ "//:stablehlo_passes", ++ "//:test_utils", ++ "//:tosa_passes", ++ "@llvm-project//mlir:AllExtensions", ++ "@llvm-project//mlir:AllPassesAndDialects", ++ "@llvm-project//mlir:MlirOptLib", ++ "@llvm-project//mlir:TosaDialect", ++ ], ++) +diff --ruN a/stablehlo/stablehlo/experimental/CMakeLists.txt b/stablehlo/stablehlo/experimental/CMakeLists.txt +--- stablehlo/stablehlo/experimental/CMakeLists.txt ++++ stablehlo/stablehlo/experimental/CMakeLists.txt +@@ -0,0 +1,18 @@ ++# Copyright 2023 The StableHLO Authors. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++add_subdirectory(dialect) ++add_subdirectory(tests) ++add_subdirectory(tools) ++add_subdirectory(transforms) +diff --ruN a/stablehlo/stablehlo/experimental/dialect/Base.cpp b/stablehlo/stablehlo/experimental/dialect/Base.cpp +--- stablehlo/stablehlo/experimental/dialect/Base.cpp ++++ stablehlo/stablehlo/experimental/dialect/Base.cpp +@@ -0,0 +1,39 @@ ++/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. ++ Copyright 2022 The StableHLO Authors. 
++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#include "stablehlo/experimental/dialect/Base.h" ++ ++#include "mlir/IR/BuiltinAttributes.h" ++#include "mlir/IR/BuiltinTypes.h" ++ ++namespace mlir { ++namespace hlo { ++ +DenseIntElementsAttr getPaddingAttr(MLIRContext* context, + ArrayRef values) { + return DenseIntElementsAttr::get( @@ -220,50 +361,97 @@ diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/ + return getPaddingAttr(builder->getContext(), values); +} + - } // namespace hlo - } // namespace mlir -diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Base.h ---- stablehlo/stablehlo/dialect/Base.h -+++ stablehlo/stablehlo/dialect/Base.h -@@ -194,6 +194,10 @@ - - ShapedType createShapedType(ShapedTypeComponents components); - ++} // namespace hlo ++} // namespace mlir +diff --ruN a/stablehlo/stablehlo/experimental/dialect/Base.h b/stablehlo/stablehlo/experimental/dialect/Base.h +--- stablehlo/stablehlo/experimental/dialect/Base.h ++++ stablehlo/stablehlo/experimental/dialect/Base.h +@@ -0,0 +1,35 @@ ++/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. ++ Copyright 2022 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#ifndef STABLEHLO_EXPERIMENTAL_DIALECT_BASE_H ++#define STABLEHLO_EXPERIMENTAL_DIALECT_BASE_H ++ ++#include "llvm/ADT/ArrayRef.h" ++#include "mlir/IR/Builders.h" ++#include "mlir/IR/BuiltinAttributes.h" ++#include "mlir/IR/MLIRContext.h" ++ ++namespace mlir { ++namespace hlo { ++ +DenseIntElementsAttr getPaddingAttr(MLIRContext *context, + ArrayRef value); +DenseIntElementsAttr getPaddingAttr(Builder *builder, ArrayRef value); + - // This interface is implemented by both StableHLO and MHLO dialects - // and is used as the foundation for sharing verification, type inference and - // prettyprinting logic between them. 
-diff --ruN a/stablehlo/stablehlo/dialect/CMakeLists.txt b/stablehlo/stablehlo/dialect/CMakeLists.txt ---- stablehlo/stablehlo/dialect/CMakeLists.txt -+++ stablehlo/stablehlo/dialect/CMakeLists.txt -@@ -77,6 +77,20 @@ - target_include_directories(ChloOps INTERFACE - $ - $ ++} // namespace hlo ++} // namespace mlir ++ ++#endif // STABLEHLO_EXPERIMENTAL_DIALECT_BASE_H +diff --ruN a/stablehlo/stablehlo/experimental/dialect/CMakeLists.txt b/stablehlo/stablehlo/experimental/dialect/CMakeLists.txt +--- stablehlo/stablehlo/experimental/dialect/CMakeLists.txt ++++ stablehlo/stablehlo/experimental/dialect/CMakeLists.txt +@@ -0,0 +1,42 @@ ++# Copyright 2020 The TensorFlow Authors. All Rights Reserved. ++# Copyright 2023 The StableHLO Authors. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++add_mlir_library(ExperimentalStablehloBase ++ PARTIAL_SOURCES_INTENDED ++ Base.cpp ++ ++ LINK_LIBS PUBLIC ++ MLIRIR +) + -+add_mlir_dialect_library(ExperimentalOps ++add_mlir_dialect_library(ExperimentalStablehloOps + PARTIAL_SOURCES_INTENDED -+ ExperimentalOps.cpp ++ StablehloOps.cpp + + DEPENDS + StablehloOpsIncGen + + LINK_LIBS PUBLIC ++ ExperimentalStablehloBase + MLIRFuncDialect + MLIRIR + MLIRSupport + StablehloOps - ) - - add_mlir_dialect_library(StablehloRegister -diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stablehlo/dialect/ExperimentalOps.cpp ---- stablehlo/stablehlo/dialect/ExperimentalOps.cpp -+++ stablehlo/stablehlo/dialect/ExperimentalOps.cpp -@@ -0,0 +1,504 @@ ++) ++ ++target_include_directories(ExperimentalStablehloOps INTERFACE ++ $ ++ $ ++) +diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp +--- stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp ++++ stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp +@@ -0,0 +1,615 @@ +/* Copyright 2023 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); @@ -279,8 +467,9 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stableh +limitations under the License. +==============================================================================*/ + -+#include "stablehlo/dialect/ExperimentalOps.h" ++#include "stablehlo/experimental/dialect/StablehloOps.h" + ++#include +#include + +#include "llvm/ADT/ArrayRef.h" @@ -293,6 +482,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stableh + +namespace mlir { +namespace stablehlo { ++namespace experimental { + +LogicalResult DynamicReduceWindowOpAdaptor::verify() { + // Before checking the constraints inherited from ReduceWindowOp, @@ -306,8 +496,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stableh + // api_version and backend_config have default values. + // call_target_name should be "stablehlo.dynamic_reduce_window". + // called_computations carries the body. 
-+ if (attr.getName() != "api_version" && -+ attr.getName() != "backend_config" && ++ if (attr.getName() != "api_version" && attr.getName() != "backend_config" && + attr.getName() != "call_target_name" && + attr.getName() != "called_computations") + return op_.emitError() @@ -688,8 +877,8 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stableh + + // dynamic_top_k_i2 + auto kType = k.getType().dyn_cast(); -+ if (!kType || !kType.hasRank() || -+ kType.getRank() != 0 || !kType.getElementType().isIntOrIndex()) ++ if (!kType || !kType.hasRank() || kType.getRank() != 0 || ++ !kType.getElementType().isIntOrIndex()) + return op_.emitError() + << "expects k (operand #1) " + << "to be a 0-dimensional tensor of integer or index type"; @@ -751,7 +940,6 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stableh + return op_.getInputs()[1].cast>(); +} + -+ +TypedValue DynamicTopKOpAdaptor::getValues() { + return op_.getResults()[0].cast>(); +} @@ -760,18 +948,129 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stableh + return op_.getResults()[1].cast>(); +} + -+std::optional getDynamicTopKOp( -+ CustomCallOp op) { ++std::optional getDynamicTopKOp(CustomCallOp op) { + if (op.getCallTargetName() != "stablehlo.dynamic_top_k") return {}; + return DynamicTopKOpAdaptor(op); +} + ++LogicalResult TopKOpAdaptor::verify() { ++ if (op_->getNumOperands() != 1) ++ return op_.emitError("expects size(operands) = 1"); ++ if (op_->getNumResults() != 2) ++ return op_.emitError("expects size(results) = 2"); ++ if (!op_.getBackendConfig().empty()) ++ return op_.emitError() << "expects an empty backend_config"; ++ if (op_.getCallTargetName() != "mhlo.topk") ++ return op_.emitError() << "expects @mhlo.topk"; ++ ++ auto operand = op_.getInputs()[0]; ++ auto values = op_.getResults()[0]; ++ auto indices = op_.getResults()[1]; ++ DictionaryAttr topkAttributes = ++ op_->getAttrOfType("mhlo.attributes"); ++ if (!topkAttributes) { ++ return op_.emitError() ++ << "mhlo.attributes missing or not a dictionary attribute"; ++ } ++ ++ IntegerAttr k_attr = topkAttributes.get("k").dyn_cast_or_null(); ++ if (!k_attr) { ++ return op_.emitError() << "mhlo.attributes.k not present or not an integer"; ++ } ++ int64_t k = k_attr.getInt(); ++ ++ // mhlo.topk_c5 ++ if (k < 0) return op_.emitError() << "expects k >= 0"; ++ ++ // mhlo.topk_i1 ++ auto operandType = operand.getType().dyn_cast(); ++ if (!operandType || !operandType.hasRank() || operandType.getRank() < 1 || ++ !operandType.getElementType().isIntOrFloat()) ++ return op_.emitError() ++ << "expects operand #0 " ++ << "to be a tensor of integer or floating-point type " ++ << "of rank at least 1"; ++ ++ // mhlo.topk_o1 ++ auto valuesType = values.getType().dyn_cast(); ++ if (!valuesType || !valuesType.hasRank() || valuesType.getRank() < 1 || ++ !valuesType.getElementType().isIntOrFloat()) ++ return op_.emitError() ++ << "expects values (result #0) " ++ << "to be a tensor of integer or floating-point type " ++ << "of rank at least 1"; ++ ++ // mhlo.topk_o2 ++ auto indicesType = indices.getType().dyn_cast(); ++ if (!indicesType || !indicesType.hasRank() || indicesType.getRank() < 1 || ++ !indicesType.getElementType().isSignlessInteger(32)) ++ return op_.emitError() << "expects indices (result #1) " ++ << "to be a tensor of si32 of rank at least 1"; ++ ++ // mhlo.topk_c1 && mhlo.topk_c2 ++ auto operandLastDim = operandType.getRank() - 1; ++ SmallVector expectedValuesShape(operandType.getShape()); ++ 
expectedValuesShape[operandLastDim] = k; ++ if (failed(verifyCompatibleShape(expectedValuesShape, valuesType.getShape()))) ++ return op_.emitError() << "expects the values shape to match the operand " ++ "shape in all but the last dimension, and " ++ "that the last dimension of the values shape " ++ "has a size k"; ++ ++ // mhlo.topk_c3 ++ if (valuesType.getElementType() != operandType.getElementType()) ++ return op_.emitError() ++ << "expects the values element type to be the same as the operand " ++ << "element type"; ++ ++ // mhlo.topk_c4 ++ if (failed( ++ verifyCompatibleShape(indicesType.getShape(), valuesType.getShape()))) ++ return op_.emitError() ++ << "expects the indices shape to match the values shape"; ++ ++ return success(); ++} ++ ++TypedValue TopKOpAdaptor::getOperand() { ++ return op_.getInputs()[0].cast>(); ++} ++ ++TypedValue TopKOpAdaptor::getValues() { ++ return op_.getResults()[0].cast>(); ++} ++ ++TypedValue TopKOpAdaptor::getIndices() { ++ return op_.getResults()[1].cast>(); ++} ++ ++int64_t TopKOpAdaptor::getK() { ++ DictionaryAttr topkAttributes = ++ op_->getAttrOfType("mhlo.attributes"); ++ return topkAttributes.get("k").cast().getInt(); ++} ++ ++bool TopKOpAdaptor::getLargest() { ++ DictionaryAttr topkAttributes = ++ op_->getAttrOfType("mhlo.attributes"); ++ IntegerAttr largest = ++ topkAttributes.get("largest").dyn_cast_or_null(); ++ ++ return (!largest) ? true : largest.getInt(); ++} ++ ++std::optional getTopKOp(CustomCallOp op) { ++ if (op.getCallTargetName() != "mhlo.topk") return {}; ++ return TopKOpAdaptor(op); ++} ++ ++} // namespace experimental +} // namespace stablehlo +} // namespace mlir -diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo/dialect/ExperimentalOps.h ---- stablehlo/stablehlo/dialect/ExperimentalOps.h -+++ stablehlo/stablehlo/dialect/ExperimentalOps.h -@@ -0,0 +1,227 @@ +diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.h b/stablehlo/stablehlo/experimental/dialect/StablehloOps.h +--- stablehlo/stablehlo/experimental/dialect/StablehloOps.h ++++ stablehlo/stablehlo/experimental/dialect/StablehloOps.h +@@ -0,0 +1,299 @@ +/* Copyright 2023 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); @@ -787,8 +1086,8 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo +limitations under the License. +==============================================================================*/ + -+#ifndef STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H -+#define STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H ++#ifndef STABLEHLO_EXPERIMENTAL_DIALECT_STABLEHLO_OPS_H ++#define STABLEHLO_EXPERIMENTAL_DIALECT_STABLEHLO_OPS_H + +// This file supports XLA-specific experiments with the StableHLO opset. +// These experiments are not yet ready to be upstreamed to openxla/stablehlo @@ -805,9 +1104,11 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LogicalResult.h" +#include "stablehlo/dialect/StablehloOps.h" ++#include "stablehlo/experimental/dialect/Base.h" + +namespace mlir { +namespace stablehlo { ++namespace experimental { + +// The DynamicReduceWindowOp experiment provides a dynamic version of +// ReduceWindowOp. Once the dynamism RFC is figured out, we expect to have an @@ -995,55 +1296,253 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo +// "stablehlo.dynamic_top_k". 
+std::optional getDynamicTopKOp(CustomCallOp op); + ++/////////////////// ++// MHLO Op Wrappers ++// There are some ops in MHLO which have experimental support in StableHLO ++// programs by representing them as custom_calls with the target `mhlo.op_name`. ++// The level of support of these ops is similar to the other custom_calls in ++// this file. Generally these ops will be added to StableHLO and their ++// experimental support can be deprecated in favor of op's type inference. ++/////////////////// ++ ++// The TopK experiment provides a StableHLO adapter to MHLO TopKOp. ++// In the future we expect stablehlo.top_k to be added which will use the same ++// refinement rules. ++// ++// Within this experiment, TopKOp is represented via the serialized MHLO ++// `stablehlo.custom_call @mhlo.topk` custom call. ++// ++// The semantics of experimental TopKOp are inherited from the semantics of ++// mhlo.topk. ++// ++// #### Inputs ++// ++// | Label | Name | Type | ++// |-------|-----------------|----------------------------------------------| ++// | (I1) | `operand` | tensor of integer or floating-point type | ++// | (I2) | `k` | constant of type si64 | ++// | (I3) | `largest` | constant of type i1 | ++// ++// #### Outputs ++// ++// | Name | Type | ++// |----------------|------------------------------------------| ++// | `values` | tensor of integer or floating-point type | ++// | `indices` | tensor of si32 type | ++// ++// #### Constraints ++// ++// * (C1) `shape(values)[:-1] = shape(operand)[:-1]` ++// * (C2) `shape(values)[-1] = k` ++// * (C3) `element_type(values) = element_type(operand)` ++// * (C4) `shape(indices) = shape(values)` ++// * (C5) `k >= 0` ++// ++class TopKOpAdaptor { ++ public: ++ TopKOpAdaptor(CustomCallOp op) : op_(op) {} ++ operator Operation*() { return op_; } ++ Operation* operator->() { return op_; } ++ ++ // These accessors assume that the operation is well-formed (i.e. that it ++ // can pass verification). ++ TypedValue getOperand(); ++ TypedValue getValues(); ++ TypedValue getIndices(); ++ int64_t getK(); ++ bool getLargest(); ++ ++ // Verifies the constraints documented above. ++ // Emits errors if errors are detected. ++ LogicalResult verify(); ++ ++ private: ++ CustomCallOp op_; ++}; ++ ++// Wraps a custom call in a TopKOpAdaptor. ++// Fails if the call_target_name of the custom call doesn't match ++// "mhlo.topk". 
++std::optional getTopKOp(CustomCallOp op); ++ ++} // namespace experimental +} // namespace stablehlo +} // namespace mlir + -+#endif // STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H -diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp ---- stablehlo/stablehlo/dialect/StablehloOps.cpp -+++ stablehlo/stablehlo/dialect/StablehloOps.cpp -@@ -1543,6 +1543,7 @@ - p << " across dimensions = ["; - llvm::interleaveComma(getDimensions().getValues(), p); - p << "]"; -+ p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"}); - p << " : "; - p.printFunctionalType(*this); - } else { -@@ -1705,6 +1706,7 @@ - if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || - parser.parseEqual() || - parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) || -+ parser.parseOptionalAttrDict(result.attributes) || - parser.parseColon() || parser.parseType(reduceOpFnType) || - parser.parseOptionalLocationSpecifier(explicitLoc)) - return failure(); -diff --ruN a/stablehlo/stablehlo/tests/print_reduce.mlir b/stablehlo/stablehlo/tests/print_reduce.mlir ---- stablehlo/stablehlo/tests/print_reduce.mlir -+++ stablehlo/stablehlo/tests/print_reduce.mlir -@@ -168,3 +168,15 @@ - - func.return %0: tensor<4xf32> - } ++#endif // STABLEHLO_EXPERIMENTAL_DIALECT_STABLEHLO_OPS_H +diff --ruN a/stablehlo/stablehlo/experimental/tests/BUILD.bazel b/stablehlo/stablehlo/experimental/tests/BUILD.bazel +--- stablehlo/stablehlo/experimental/tests/BUILD.bazel ++++ stablehlo/stablehlo/experimental/tests/BUILD.bazel +@@ -0,0 +1,59 @@ ++# Copyright 2023 The StableHLO Authors. All Rights Reserved. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++load("@bazel_skylib//rules:expand_template.bzl", "expand_template") ++load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") ++ ++package( ++ default_visibility = ["//visibility:public"], ++ licenses = ["notice"], ++) + -+// The test case makes sure any custom attrs set on the reduce-op are -+// printed/parsed when pretty-printed. ++# Equivalent of configure_lit_site_cfg from CMakeLists.txt. ++expand_template( ++ name = "lit_site_cfg_py_gen", ++ testonly = True, ++ out = "lit.site.cfg.py", ++ substitutions = { ++ "@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.", ++ "@LLVM_TOOLS_DIR@": package_path("@llvm-project//llvm:BUILD"), ++ "\"@STABLEHLO_TOOLS_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')", ++ "\"@STABLEHLO_SOURCE_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')", ++ }, ++ template = "lit.site.cfg.py.in", ++) + -+// CHECK-LABEL: func @pretty_print_with_custom_attr -+// CHECK: applies stablehlo.add across dimensions = [1] {custom_user_attr = 1 : i64} ++# Equivalent of add_lit_testsuite from CMakeLists.txt. 
++[ ++ lit_test( ++ name = "%s.test" % src, ++ size = "small", ++ srcs = [src], ++ data = [ ++ "lit.cfg.py", ++ "lit.site.cfg.py", ++ "//:stablehlo-opt", ++ "//:stablehlo-translate", ++ "//stablehlo/experimental:experimental-stablehlo-opt", ++ "@llvm-project//llvm:FileCheck", ++ "@llvm-project//llvm:not", ++ ] + glob(["%s.bc" % src]), ++ tags = ["stablehlo_tests"], ++ ) ++ for src in glob(["**/*.mlir"]) ++] ++ ++test_suite( ++ name = "experimental_stablehlo_tests", ++ tags = ["experimental_stablehlo_tests"], ++) +diff --ruN a/stablehlo/stablehlo/experimental/tests/CMakeLists.txt b/stablehlo/stablehlo/experimental/tests/CMakeLists.txt +--- stablehlo/stablehlo/experimental/tests/CMakeLists.txt ++++ stablehlo/stablehlo/experimental/tests/CMakeLists.txt +@@ -0,0 +1,29 @@ ++# Copyright 2020 The TensorFlow Authors. All Rights Reserved. ++# Copyright 2023 The StableHLO Authors. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++configure_lit_site_cfg( ++ ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in ++ ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py ++ MAIN_CONFIG ++ ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py ++) ++add_lit_testsuite(check-experimental-stablehlo-tests "Running the experimental/tests/ suite" ++ ${CMAKE_CURRENT_BINARY_DIR} ++ DEPENDS ++ FileCheck ++ experimental-stablehlo-opt ++ stablehlo-translate ++) ++add_dependencies(check-stablehlo-quick check-experimental-stablehlo-tests) +diff --ruN a/stablehlo/stablehlo/experimental/tests/lit.cfg.py b/stablehlo/stablehlo/experimental/tests/lit.cfg.py +--- stablehlo/stablehlo/experimental/tests/lit.cfg.py ++++ stablehlo/stablehlo/experimental/tests/lit.cfg.py +@@ -0,0 +1,42 @@ ++"""Lit configuration to drive test in this repo.""" ++# Copyright 2020 The TensorFlow Authors. All Rights Reserved. ++# Copyright 2023 The StableHLO Authors. ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++# -*- Python -*- ++# pylint: disable=undefined-variable ++ ++import os ++ ++import lit.formats ++from lit.llvm import llvm_config ++ ++# Populate Lit configuration with the minimal required metadata. ++# Some metadata is populated in lit.site.cfg.py.in. 
++config.name = 'STABLEHLO_TESTS_SUITE' ++config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) ++config.suffixes = ['.mlir'] ++config.test_source_root = os.path.dirname(__file__) ++ ++# Make LLVM and StableHLO tools available in RUN directives ++tools = [ ++ 'FileCheck', ++ 'experimental-stablehlo-opt', ++ 'stablehlo-translate', ++ 'not', ++] ++tool_dirs = [ ++ config.llvm_tools_dir, ++ config.stablehlo_tools_dir, ++] ++llvm_config.add_tool_substitutions(tools, tool_dirs) +diff --ruN a/stablehlo/stablehlo/experimental/tests/lit.site.cfg.py.in b/stablehlo/stablehlo/experimental/tests/lit.site.cfg.py.in +--- stablehlo/stablehlo/experimental/tests/lit.site.cfg.py.in ++++ stablehlo/stablehlo/experimental/tests/lit.site.cfg.py.in +@@ -0,0 +1,21 @@ ++# Copyright 2020 The TensorFlow Authors. All Rights Reserved. ++# Copyright 2023 The StableHLO Authors. ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++@LIT_SITE_CFG_IN_HEADER@ ++ ++import lit.llvm ++lit.llvm.initialize(lit_config, config) ++config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" ++config.stablehlo_tools_dir = "@STABLEHLO_TOOLS_DIR@" ++lit_config.load_config(config, "@STABLEHLO_SOURCE_DIR@" + "/stablehlo/experimental/tests/lit.cfg.py") +diff --ruN a/stablehlo/stablehlo/experimental/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/experimental/tests/stablehlo_canonicalize_dynamism.mlir +--- stablehlo/stablehlo/experimental/tests/stablehlo_canonicalize_dynamism.mlir ++++ stablehlo/stablehlo/experimental/tests/stablehlo_canonicalize_dynamism.mlir +@@ -0,0 +1,344 @@ ++// RUN: experimental-stablehlo-opt --experimental-stablehlo-canonicalize-dynamism --split-input-file --verify-diagnostics %s | FileCheck %s + -+func.func @pretty_print_with_custom_attr(%arg0: tensor<2x64x13xf32>) -> tensor<2x13xf32> { -+ %0 = stablehlo.constant dense<0.000000e+00> : tensor -+ %1 = stablehlo.reduce(%arg0 init: %0) applies stablehlo.add across dimensions = [1] {custom_user_attr = 1 : i64} : (tensor<2x64x13xf32>, tensor) -> tensor<2x13xf32> -+ return %1 : tensor<2x13xf32> -+} -diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir ---- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir -+++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir -@@ -426,6 +426,172 @@ - - // ----- - +// CHECK-LABEL: func @dynamic_reduce_window_success_static_result_type +func.func @dynamic_reduce_window_success_static_result_type(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<2x2xf32> { + // CHECK-NOT: stablehlo.dynamic_reduce_window @@ -1209,17 +1708,6 @@ diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/st +} + +// ----- -+ - // CHECK-LABEL: func @dynamic_reshape_success - func.func @dynamic_reshape_success(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NOT: stablehlo.dynamic_reshape -@@ -452,6 +618,185 @@ - %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> - %1 = stablehlo.dynamic_reshape %arg0, %0 : 
(tensor<4xf32>, tensor<2xi64>) -> tensor<1x?xf32> - return %1 : tensor<1x?xf32> -+} -+ -+// ----- + +// CHECK-LABEL: func @dynamic_rng_bit_generator_success +func.func @dynamic_rng_bit_generator_success(%arg0: tensor<2xui64>) -> tensor<1x4xf32> { @@ -1396,16 +1884,13 @@ diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/st + %k = stablehlo.constant dense<3> : tensor + %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor<3xf32>, tensor<4xi32>) + return %1#0, %1#1 : tensor<3xf32>, tensor<4xi32> - } - - // ----- -diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir ---- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -@@ -607,12 +607,55 @@ - - // ----- - ++} +diff --ruN a/stablehlo/stablehlo/experimental/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/experimental/tests/stablehlo_refine_shapes.mlir +--- stablehlo/stablehlo/experimental/tests/stablehlo_refine_shapes.mlir ++++ stablehlo/stablehlo/experimental/tests/stablehlo_refine_shapes.mlir +@@ -0,0 +1,152 @@ ++// RUN: experimental-stablehlo-opt --experimental-stablehlo-refine-shapes --split-input-file --verify-diagnostics %s | FileCheck %s ++ +// CHECK-LABEL: @main +func.func @main(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<*xf32> { + // CHECK: stablehlo.dynamic_reduce_window{{.*}} -> tensor<2x2xf32> @@ -1426,16 +1911,6 @@ diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/ +} + +// ----- -+ - // CHECK-LABEL: @refine_dynamic_reshape - func.func @refine_dynamic_reshape(%arg0: tensor<4xf32>) -> tensor<*xf32> { - // CHECK: stablehlo.dynamic_reshape{{.*}} -> tensor<1x4xf32> - %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> - %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> -+} -+ -+// ----- + +// CHECK-LABEL: @refine_dynamic_rng_bit_generator +func.func @refine_dynamic_rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor, tensor<*xf32>) { @@ -1455,36 +1930,374 @@ diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/ + %k = stablehlo.constant dense<4> : tensor + %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor, tensor) + return %1#0, %1#1 : tensor, tensor - } - - // ----- -diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td ---- stablehlo/stablehlo/transforms/Passes.td -+++ stablehlo/stablehlo/transforms/Passes.td -@@ -25,6 +25,7 @@ - For example, if the output_shape operand of DynamicReshapeOp is a constant - value, then the operation can be transformed to ReshapeOp. 
- }]; ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_topk ++func.func @refine_mhlo_topk(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // CHECK: mhlo.topk{{.*}} -> (tensor<5x4xf32>, tensor<5x4xi32>) ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_too_many_operands ++func.func @refine_mhlo_error_too_many_operands(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects size(operands) = 1}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0, %arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>, tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_too_few_results ++func.func @refine_mhlo_error_too_few_results(%arg0: tensor<5x16xf32>) -> (tensor) { ++ // expected-error@+1{{expects size(results) = 2}} ++ %0 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor) ++ return %0 : tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_wrong_output_1_type ++func.func @refine_mhlo_error_wrong_output_1_type(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects values (result #0) to be a tensor of integer or floating-point type of rank at least 1}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_wrong_output_2_type ++func.func @refine_mhlo_error_wrong_output_2_type(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects indices (result #1) to be a tensor of si32 of rank at least 1}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_c1_wrong_output_shape ++func.func @refine_mhlo_error_c1_wrong_output_shape(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects the values shape to match the operand}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_c2_last_dim_not_k ++func.func @refine_mhlo_error_c2_last_dim_not_k(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects the values shape to match the operand}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_c3_wrong_output_type ++func.func @refine_mhlo_error_c3_wrong_output_type(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects the values element type to be the same as the operand element type}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// 
----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_c4_outputs_shape_mismatch ++func.func @refine_mhlo_error_c4_outputs_shape_mismatch(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects the indices shape to match the values shape}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_c5_negative_k ++func.func @refine_mhlo_error_c5_negative_k(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects k >= 0}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = -4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} +diff --ruN a/stablehlo/stablehlo/experimental/tools/CMakeLists.txt b/stablehlo/stablehlo/experimental/tools/CMakeLists.txt +--- stablehlo/stablehlo/experimental/tools/CMakeLists.txt ++++ stablehlo/stablehlo/experimental/tools/CMakeLists.txt +@@ -0,0 +1,41 @@ ++# Copyright 2020 The TensorFlow Authors. All Rights Reserved. ++# Copyright 2023 The StableHLO Authors. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++set(LLVM_OPTIONAL_SOURCES ++ StablehloOptMain.cpp ++) ++ ++# stablehlo-opt ++get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) ++get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) ++get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) ++set(LIBS ++ ${dialect_libs} ++ ${conversion_libs} ++ ${extension_libs} ++ ExperimentalStablehloPasses ++ MLIROptLib ++ StablehloRegister ++ StablehloTestUtils ++ StablehloPasses ++ InterpreterOps ++ StablehloTOSATransforms ++ ) ++add_llvm_executable(experimental-stablehlo-opt StablehloOptMain.cpp) ++llvm_update_compile_flags(experimental-stablehlo-opt) ++target_link_libraries(experimental-stablehlo-opt PRIVATE ${LIBS}) ++ ++mlir_check_all_link_libraries(experimental-stablehlo-opt) ++ +diff --ruN a/stablehlo/stablehlo/experimental/tools/StablehloOptMain.cpp b/stablehlo/stablehlo/experimental/tools/StablehloOptMain.cpp +--- stablehlo/stablehlo/experimental/tools/StablehloOptMain.cpp ++++ stablehlo/stablehlo/experimental/tools/StablehloOptMain.cpp +@@ -0,0 +1,46 @@ ++/* Copyright 2023 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. 
++==============================================================================*/ ++ ++#include "mlir/Dialect/Tosa/IR/TosaOps.h" ++#include "mlir/Dialect/Tosa/Transforms/Passes.h" ++#include "mlir/InitAllDialects.h" ++#include "mlir/InitAllExtensions.h" ++#include "mlir/InitAllPasses.h" ++#include "mlir/Tools/mlir-opt/MlirOptMain.h" ++#include "stablehlo/conversions/tosa/transforms/Passes.h" ++#include "stablehlo/dialect/Register.h" ++#include "stablehlo/experimental/transforms/Passes.h" ++#include "stablehlo/reference/InterpreterOps.h" ++#include "stablehlo/tests/TestUtils.h" ++#include "stablehlo/transforms/Passes.h" ++ ++int main(int argc, char **argv) { ++ mlir::registerAllPasses(); ++ mlir::hlo::registerAllTestPasses(); ++ mlir::stablehlo::registerPassPipelines(); ++ mlir::stablehlo::registerPasses(); ++ mlir::stablehlo::experimental::registerPasses(); ++ mlir::tosa::registerStablehloLegalizeToTosaPassPass(); ++ mlir::tosa::registerStablehloPrepareForTosaPassPass(); ++ ++ mlir::DialectRegistry registry; ++ mlir::registerAllDialects(registry); ++ mlir::registerAllExtensions(registry); ++ mlir::stablehlo::registerAllDialects(registry); ++ registry.insert(); ++ ++ return failed( ++ mlir::MlirOptMain(argc, argv, "Experimental StableHLO optimizer driver\n", registry)); ++} +diff --ruN a/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt b/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt +--- stablehlo/stablehlo/experimental/transforms/CMakeLists.txt ++++ stablehlo/stablehlo/experimental/transforms/CMakeLists.txt +@@ -0,0 +1,39 @@ ++# Copyright 2023 The StableHLO Authors. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++set(LLVM_TARGET_DEFINITIONS Passes.td) ++mlir_tablegen(Passes.h.inc -gen-pass-decls) ++add_public_tablegen_target(ExperimentalPassesIncGen) ++ ++add_mlir_dialect_library(ExperimentalStablehloPasses ++ PARTIAL_SOURCES_INTENDED ++ StablehloCanonicalizeDynamism.cpp ++ StablehloRefineShapes.cpp ++ ++ DEPENDS ++ ExperimentalPassesIncGen ++ ++ LINK_LIBS PUBLIC ++ ChloOps ++ MLIRFuncDialect ++ MLIRIR ++ MLIRInferTypeOpInterface ++ MLIRSupport ++ MLIRTransformUtils ++ ExperimentalStablehloOps ++ StablehloBase ++ StablehloOps ++ StablehloPasses ++ StablehloTypeInference ++) +diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.h b/stablehlo/stablehlo/experimental/transforms/Passes.h +--- stablehlo/stablehlo/experimental/transforms/Passes.h ++++ stablehlo/stablehlo/experimental/transforms/Passes.h +@@ -0,0 +1,37 @@ ++/* Copyright 2023 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#ifndef STABLEHLO_EXPERIMENTAL_TRANSFORMS_PASSES_H ++#define STABLEHLO_EXPERIMENTAL_TRANSFORMS_PASSES_H ++ ++#include ++ ++#include "mlir/Pass/Pass.h" ++#include "mlir/Transforms/DialectConversion.h" ++ ++namespace mlir { ++namespace stablehlo { ++namespace experimental { ++ ++#define GEN_PASS_DECL_STABLEHLOCANONICALIZEDYNAMISMPASS ++#define GEN_PASS_DECL_STABLEHLOREFINESHAPESPASS ++#define GEN_PASS_REGISTRATION ++#include "stablehlo/experimental/transforms/Passes.h.inc" ++ ++} // namespace experimental ++} // namespace stablehlo ++} // namespace mlir ++ ++#endif // STABLEHLO_EXPERIMENTAL_TRANSFORMS_PASSES_H +diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.td b/stablehlo/stablehlo/experimental/transforms/Passes.td +--- stablehlo/stablehlo/experimental/transforms/Passes.td ++++ stablehlo/stablehlo/experimental/transforms/Passes.td +@@ -0,0 +1,31 @@ ++/* Copyright 2023 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++include "mlir/Pass/PassBase.td" ++ ++def StablehloCanonicalizeDynamismPass : Pass<"experimental-stablehlo-canonicalize-dynamism", "func::FuncOp"> { ++ let summary = "(Experimental) Canonicalizes dynamic StableHLO ops into static ops."; ++ let description = [{ ++ Experimental version of the --stablehlo-canonicalize-dynamism pass. ++ }]; + let dependentDialects = ["mlir::chlo::ChloDialect"]; - } - - def StablehloLegalizeToVhloPass : Pass<"stablehlo-legalize-to-vhlo", "ModuleOp"> { -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp ---- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp -+++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp -@@ -24,6 +24,8 @@ - #include "mlir/Interfaces/InferTypeOpInterface.h" - #include "mlir/Support/LogicalResult.h" - #include "mlir/Transforms/GreedyPatternRewriteDriver.h" ++} ++ ++def StablehloRefineShapesPass : Pass<"experimental-stablehlo-refine-shapes", "ModuleOp"> { ++ let summary = "(Experimental) Refines shapes across a StableHLO program."; ++ let description = [{ ++ Experimental version of the --stablehlo-refine-shapes pass. ++ }]; ++} +diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp +--- stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp ++++ stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp +@@ -0,0 +1,167 @@ ++/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. ++ Copyright 2023 The StableHLO Authors. 
++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#include ++ ++#include "llvm/ADT/STLExtras.h" ++#include "llvm/ADT/SmallVector.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/PatternMatch.h" ++#include "mlir/Interfaces/InferTypeOpInterface.h" ++#include "mlir/Support/LogicalResult.h" ++#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/ChloOps.h" -+#include "stablehlo/dialect/ExperimentalOps.h" - #include "stablehlo/dialect/StablehloOps.h" - #include "stablehlo/transforms/Passes.h" - -@@ -198,6 +200,54 @@ - } - }; - ++#include "stablehlo/dialect/StablehloOps.h" ++#include "stablehlo/experimental/dialect/StablehloOps.h" ++#include "stablehlo/experimental/transforms/Passes.h" ++#include "stablehlo/transforms/Passes.h" ++ ++namespace mlir { ++namespace stablehlo { ++namespace experimental { ++ ++#define GEN_PASS_DEF_STABLEHLOCANONICALIZEDYNAMISMPASS ++#include "stablehlo/experimental/transforms/Passes.h.inc" ++ ++namespace { ++ +struct CanonicalizeDynamicReduceWindowOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; @@ -1532,17 +2345,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/ + return success(); + } +}; -+ - struct CanonicalizeDynamicReshapeOpPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; -@@ -210,6 +260,56 @@ - if (!op.getType().hasStaticShape()) - return rewriter.notifyMatchFailure(op, "expected static result type"); - rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); -+ return success(); -+ } -+}; + +struct CanonicalizeDynamicRngBitGeneratorOpPattern + : public OpRewritePattern { @@ -1590,35 +2392,84 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/ + + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getOperand(), k[0]); - return success(); - } - }; -@@ -320,7 +420,10 @@ - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); ++ return success(); ++ } ++}; ++ ++struct StablehloCanonicalizeDynamismPass ++ : public impl::StablehloCanonicalizeDynamismPassBase< ++ StablehloCanonicalizeDynamismPass> { ++ using StablehloCanonicalizeDynamismPassBase:: ++ StablehloCanonicalizeDynamismPassBase; ++ ++ void runOnOperation() override { ++ GreedyRewriteConfig config; ++ config.useTopDownTraversal = true; ++ config.enableRegionSimplification = true; ++ config.maxIterations = 2; ++ config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; ++ config.strictMode = GreedyRewriteStrictness::AnyOp; ++ ++ RewritePatternSet patterns(&getContext()); ++ populateStablehloCanonicalizeDynamismPatterns(&patterns, &getContext()); + patterns.add(&getContext()); - patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); - patterns.add( - &getContext()); - patterns.add(&getContext()); -diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp 
b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -@@ -43,6 +43,7 @@ - #include "mlir/Transforms/GreedyPatternRewriteDriver.h" - #include "stablehlo/dialect/Base.h" - #include "stablehlo/dialect/ChloOps.h" -+#include "stablehlo/dialect/ExperimentalOps.h" - #include "stablehlo/dialect/StablehloOps.h" - #include "stablehlo/dialect/TypeInference.h" - #include "stablehlo/transforms/Passes.h" -@@ -844,12 +845,97 @@ - } - }; - ++ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), ++ config))) { ++ return signalPassFailure(); ++ } ++ } ++}; ++ ++} // namespace ++} // namespace experimental ++} // namespace stablehlo ++} // namespace mlir +diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp +--- stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp ++++ stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp +@@ -0,0 +1,178 @@ ++/* Copyright 2022 The StableHLO Authors. ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#include "stablehlo/transforms/StablehloRefineShapes.h" ++ ++#include ++ ++#include "llvm/ADT/SmallVector.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/PatternMatch.h" ++#include "mlir/Interfaces/InferTypeOpInterface.h" ++#include "mlir/Support/LogicalResult.h" ++#include "mlir/Transforms/GreedyPatternRewriteDriver.h" ++#include "stablehlo/dialect/Base.h" ++#include "stablehlo/dialect/StablehloOps.h" ++#include "stablehlo/dialect/TypeInference.h" ++#include "stablehlo/experimental/dialect/StablehloOps.h" ++#include "stablehlo/experimental/transforms/Passes.h" ++#include "stablehlo/transforms/Passes.h" ++ ++namespace mlir { ++namespace stablehlo { ++namespace experimental { ++ ++#define GEN_PASS_DEF_STABLEHLOREFINESHAPESPASS ++#include "stablehlo/experimental/transforms/Passes.h.inc" ++ ++namespace { ++ +struct RefineDynamicReduceWindowOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; @@ -1660,15 +2511,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl + return refineReturnTypes(rewriter, op, inferredReturnTypes); + } +}; -+ - struct RefineDynamicReshapeOpPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(DynamicReshapeOp op, - PatternRewriter& rewriter) const override { - return refineReturnShape(rewriter, op, op.getOutputShape()); -+ } -+}; + +struct RefineDynamicRngBitGeneratorOpPattern + : public OpRewritePattern { @@ -1710,18 +2552,908 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl + + outputShape[operandType.getRank() - 1] = k[0]; + return refineReturnTypes(rewriter, op, {{outputShape}, {outputShape}}); - } - }; - -@@ -1181,7 +1267,10 
@@ - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); ++ } ++}; ++ ++struct RefineTopKOpPattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(CustomCallOp impl, ++ PatternRewriter& rewriter) const override { ++ auto maybeOp = getTopKOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ TopKOpAdaptor op = *maybeOp; ++ ++ auto operandType = op.getOperand().getType().cast(); ++ SmallVector outputShape(operandType.getShape()); ++ outputShape.back() = op.getK(); ++ return refineReturnTypes(rewriter, op, {{outputShape}, {outputShape}}); ++ } ++}; ++ ++struct StablehloRefineShapesPass ++ : public impl::StablehloRefineShapesPassBase { ++ using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; ++ ++ void runOnOperation() override { ++ auto func = getStablehloRefineShapesTarget(getOperation()); ++ if (!func) return signalPassFailure(); ++ ++ // The algorithm behind this pass consists of a single traversal of the ++ // function. This is sufficient because we only support one function per ++ // program at the moment. ++ // TODO(#1048): Find out why .maxIterations = 1 no longer works. ++ // There have been recent refactors to applyPatternsAndFoldGreedily ++ // upstream, and that might be the reason. ++ GreedyRewriteConfig config; ++ config.useTopDownTraversal = true; ++ config.enableRegionSimplification = true; ++ config.maxIterations = 2; ++ config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; ++ config.strictMode = GreedyRewriteStrictness::AnyOp; ++ ++ RewritePatternSet patterns(&getContext()); ++ populateStablehloRefineShapesPatterns(&patterns, &getContext()); + patterns.add(&getContext()); - patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); ++ patterns.add(&getContext()); ++ if (failed( ++ applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { ++ return signalPassFailure(); ++ } ++ } ++}; ++ ++} // namespace ++} // namespace experimental ++} // namespace stablehlo ++} // namespace mlir +diff --ruN a/stablehlo/stablehlo/tests/infer_chlo.mlir b/stablehlo/stablehlo/tests/infer_chlo.mlir +--- stablehlo/stablehlo/tests/infer_chlo.mlir ++++ stablehlo/stablehlo/tests/infer_chlo.mlir +@@ -120,10 +120,10 @@ + // ----- + // CHECK-LABEL: @broadcast_select_reify + func.func @broadcast_select_reify(%arg0: tensor<2xi1>, %arg1: tensor, %arg2: tensor) -> tensor<1xindex> { +- // CHECK: %0 = shape.const_shape [2] : tensor<1xindex> ++ // CHECK: %0 = shape.shape_of %arg0 : tensor<2xi1> -> tensor<1xindex> + // CHECK-NEXT: %1 = shape.shape_of %arg1 : tensor -> tensor<1xindex> + // CHECK-NEXT: %2 = shape.shape_of %arg2 : tensor -> tensor<1xindex> +- // CHECK-NEXT: %3 = shape.broadcast %1, %2, %0 : tensor<1xindex>, tensor<1xindex>, tensor<1xindex> -> tensor<1xindex> ++ // CHECK-NEXT: %3 = shape.broadcast %0, %1, %2 : tensor<1xindex>, tensor<1xindex>, tensor<1xindex> -> tensor<1xindex> + %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor, tensor) -> tensor + %1 = "hlo_test_infer.reify_return_type_shapes"(%0) : (tensor) -> tensor<1xindex> + return %1: tensor<1xindex> +diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h +--- stablehlo/stablehlo/transforms/Passes.h ++++ stablehlo/stablehlo/transforms/Passes.h +@@ -18,9 +18,12 @@ + + #include + ++#include 
"mlir/Dialect/Func/IR/FuncOps.h" + #include "mlir/Dialect/Quant/QuantOps.h" + #include "mlir/Dialect/Shape/IR/Shape.h" ++#include "mlir/IR/BuiltinOps.h" + #include "mlir/Pass/Pass.h" ++#include "mlir/Support/LogicalResult.h" + #include "mlir/Transforms/DialectConversion.h" + + namespace mlir { +@@ -34,6 +37,14 @@ + #define GEN_PASS_DECL_VHLOTOVERSIONPASS + #define GEN_PASS_REGISTRATION + #include "stablehlo/transforms/Passes.h.inc" ++ ++// Populates --stablehlo-canonicalize-dynamism patterns. ++void populateStablehloCanonicalizeDynamismPatterns(RewritePatternSet *patterns, ++ MLIRContext *context); ++ ++// Populates --stablehlo-refine-shapes patterns. ++void populateStablehloRefineShapesPatterns(RewritePatternSet *patterns, ++ MLIRContext *context); + + // Populates StableHLO ops to VHLO ops rewriting patterns. + void populateStablehloToVhloPatterns(RewritePatternSet *patterns, +diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +--- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp ++++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +@@ -307,16 +307,7 @@ + config.strictMode = GreedyRewriteStrictness::AnyOp; + + RewritePatternSet patterns(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add( +- &getContext()); +- patterns.add(&getContext()); ++ populateStablehloCanonicalizeDynamismPatterns(&patterns, &getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + return signalPassFailure(); +@@ -325,5 +316,19 @@ + }; + + } // namespace ++ ++void populateStablehloCanonicalizeDynamismPatterns(RewritePatternSet* patterns, ++ MLIRContext* context) { ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++} ++ + } // namespace stablehlo + } // namespace mlir +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +@@ -11,6 +11,8 @@ + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ ++ ++#include "stablehlo/transforms/StablehloRefineShapes.h" + + #include + #include +@@ -53,6 +55,193 @@ + #define GEN_PASS_DEF_STABLEHLOREFINESHAPESPASS + #include "stablehlo/transforms/Passes.h.inc" + ++LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, ++ ValueRange values, TypeRange types) { ++ if (values.size() != types.size()) ++ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { ++ diag << "refineValues failed for " << types << ": expected " ++ << values.size() << " types, got " << types.size(); ++ }); ++ ++ // Check whether `types` contain any new information with respect to existing ++ // return types. Even if just a single dimension size out of an entire tensor ++ // type got updated, using `inferMostSpecificType` ensures that we don't ++ // miss that. 
++ bool needsRefinement = false; ++ SmallVector refinedTypes; ++ for (auto it : llvm::zip(values.getTypes(), types)) { ++ // Cannot use structured bindings to simplify this because capturing ++ // structured bindings in a lambda is a C++ 20 extension. ++ auto currentType = std::get<0>(it); ++ auto refinement = std::get<1>(it); ++ auto refinedType = hlo::inferMostSpecificType( ++ /*location=*/{}, {currentType, refinement}); ++ if (failed(refinedType)) ++ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { ++ diag << "inferMostSpecificType failed for " << currentType << " and " ++ << refinement; ++ }); ++ refinedTypes.push_back(*refinedType); ++ needsRefinement |= (currentType != *refinedType); ++ } ++ if (!needsRefinement) ++ return rewriter.notifyMatchFailure(op, "doesn't need refinement"); ++ ++ for (auto it : llvm::zip(values, refinedTypes)) { ++ // Cannot use structured bindings to simplify this because capturing ++ // structured bindings in a lambda is a C++ 20 extension. ++ auto value = std::get<0>(it); ++ auto refinedType = std::get<1>(it); ++ if (value.getType() == refinedType) continue; ++ ++ // Check whether the users of this value are ready for the type of the ++ // value to be refined. ++ for (Operation* user : value.getUsers()) { ++ // CHLO and StableHLO ops are designed to support type refinements of ++ // their operands and results. Any operand type in these ops can change ++ // within what's supported by `inferMostSpecificType` without breaking ++ // verification of the op. ++ if (isa(user->getDialect())) ++ continue; ++ ++ // Simply changing operand type of `func.return` won't work because ++ // that won't update the FunctionType of the enclosing `func.func`. ++ // Nonetheless, we still want to support these ops because they are widely ++ // used in StableHLO programs (although the plan of record is to replace ++ // `func.return` ops in StableHLO programs with `stablehlo.return`: ++ // https://github.com/openxla/stablehlo/issues/425). ++ if (isa(user)) continue; ++ ++ // Unlike in TensorFlow's type inference pass, here we work only with ++ // allowlisted ops to focus our support on well-defined semantics of ++ // StableHLO programs. ++ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { ++ diag << "unsupported refinement: tried to refine " << value.getType() ++ << " to " << refinedType << " for user " << user; ++ }); ++ } ++ ++ // Happy path: simply call setType here because most of our users are ++ // fine with that. ++ auto unrefinedType = value.getType(); ++ value.setType(refinedType); ++ ++ // Special case: for `func.return`, guard the refinement with a cast ++ // and leave propagation of the refined return type to a dedicated pattern. ++ auto isFuncReturn = [](OpOperand& use) -> bool { ++ return isa(use.getOwner()); ++ }; ++ if (llvm::none_of(value.getUses(), isFuncReturn)) continue; ++ rewriter.setInsertionPointAfter(op); ++ auto castToUnrefinedType = rewriter.create( ++ op->getLoc(), unrefinedType, value); ++ value.replaceUsesWithIf(castToUnrefinedType.getOutputs()[0], isFuncReturn); ++ } ++ ++ return success(); ++} ++ ++LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, ++ ArrayRef types) { ++ if (failed(refineValues(rewriter, op, op->getResults(), types))) ++ return failure(); ++ ++ // This `replaceOpWithIf` call doesn't actually change the IR, but ++ // it does ask the rewriter to visit all the users of this op. 
There is no ++ // upstream API to achieve this directly, but if it's introduced in the ++ // future, we could use it here. ++ rewriter.replaceOpWithIf(op, op->getResults(), ++ [](OpOperand& use) { return false; }); ++ return success(); ++} ++ ++LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, ++ ArrayRef refinements) { ++ SmallVector flattenedTypes; ++ hlo::flattenTupleTypes(op->getResultTypes(), flattenedTypes); ++ auto flattenedSize = flattenedTypes.size(); ++ if (flattenedSize != refinements.size()) ++ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { ++ diag << "refineReturnTypes failed: expected " << flattenedSize ++ << " refinements, got " << refinements.size(); ++ }); ++ ++ SmallVector flattenedRefinedTypes; ++ for (auto it : llvm::zip(flattenedTypes, refinements)) { ++ // Cannot use structured bindings to simplify this because capturing ++ // structured bindings in a lambda is a C++ 20 extension. ++ ShapedType currentType = std::get<0>(it).dyn_cast(); ++ ShapedTypeComponents refinement = std::get<1>(it); ++ auto failWithReason = [&](StringRef reason) { ++ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { ++ diag << "refineTypes failed: refining " << currentType ++ << "with refinement: {"; ++ if (refinement.hasRank()) { ++ diag << "shape = [" << refinement.getDims() << "]"; ++ if (refinement.getAttribute()) ++ diag << "attribute = " << refinement.getAttribute(); ++ } else { ++ diag << "hasRank = false"; ++ } ++ diag << ", elementType = " << refinement.getElementType(); ++ diag << "} failed: " << reason; ++ }); ++ }; ++ ++ // If the current type is not a shaped type, then the refinement must ++ // be completely empty. ++ if (!currentType) { ++ if (refinement.hasRank() || refinement.getElementType() || ++ refinement.getAttribute()) ++ return failWithReason("unsupported refinement"); ++ flattenedRefinedTypes.push_back(currentType); ++ continue; ++ } ++ ++ // If the refinement has an element type, then it must be the same as ++ // the current element type. ++ Type currentElementType = currentType.getElementType(); ++ if (refinement.getElementType() && ++ currentElementType != refinement.getElementType()) ++ return failWithReason("expected compatible element types"); ++ ++ // If neither the current type nor the refinement are ranked, then there's ++ // nothing to refine, and we return the current type. ++ bool hasRank = currentType.hasRank() || refinement.hasRank(); ++ if (!hasRank) { ++ flattenedRefinedTypes.push_back(currentType); ++ continue; ++ } ++ ++ // If either the current type or the refinement have encodings, then ++ // we fail. Encodings are left for future work. ++ Attribute currentEncoding = nullptr; ++ if (auto currentRankedType = currentType.dyn_cast()) { ++ currentEncoding = currentRankedType.getEncoding(); ++ } ++ Attribute refinedEncoding = refinement.getAttribute(); ++ if (currentEncoding || refinedEncoding) ++ return failWithReason("expected compatible encodings"); ++ ++ // If both the current type and the refinement have shapes, use the shape ++ // from the refinement. Otherwise, pick whatever is available. ++ // Make sure that the resulting type is compatible with the current type ++ // to avoid creating invalid code. ++ auto refinedShape = ++ refinement.hasRank() ? 
refinement.getDims() : currentType.getShape(); ++ auto refinedType = RankedTensorType::get(refinedShape, currentElementType); ++ if (!hlo::isCompatibleForHloTypeInference(currentType, refinedType)) ++ return failWithReason("expected compatible shapes"); ++ flattenedRefinedTypes.push_back(refinedType); ++ } ++ ++ SmallVector refinedTypes; ++ if (failed(hlo::unflattenTupleTypes(op->getResultTypes(), ++ flattenedRefinedTypes, refinedTypes))) ++ return failure(); ++ return refineReturnTypes(rewriter, op, refinedTypes); ++} ++ + namespace { + + // DenseElementsAttr can be constructed from ArrayRef but not from +@@ -422,245 +611,6 @@ + // StableHLO-specific extension to refine return types based on potentially + // refined operands. + +-// Refines the values using the given types. +-// Tricky implementation details: +-// 1) Need to support partial shape refinements, e.g. if just a single +-// dimension size out of an entire tensor type got refined. This is done +-// via inferMostSpecificType. +-// 2) Need to signal propagation of the refined shapes across the +-// StableHLO program. Different callers of this function have different +-// propagation needs, so this function doesn't signal anything on its own +-// and leaves that to the callers. +-LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, +- ValueRange values, TypeRange types) { +- if (values.size() != types.size()) +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "refineValues failed for " << types << ": expected " +- << values.size() << " types, got " << types.size(); +- }); +- +- // Check whether `types` contain any new information with respect to existing +- // return types. Even if just a single dimension size out of an entire tensor +- // type got updated, using `inferMostSpecificType` ensures that we don't +- // miss that. +- bool needsRefinement = false; +- SmallVector refinedTypes; +- for (auto it : llvm::zip(values.getTypes(), types)) { +- // Cannot use structured bindings to simplify this because capturing +- // structured bindings in a lambda is a C++ 20 extension. +- auto currentType = std::get<0>(it); +- auto refinement = std::get<1>(it); +- auto refinedType = hlo::inferMostSpecificType( +- /*location=*/{}, {currentType, refinement}); +- if (failed(refinedType)) +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "inferMostSpecificType failed for " << currentType << " and " +- << refinement; +- }); +- refinedTypes.push_back(*refinedType); +- needsRefinement |= (currentType != *refinedType); +- } +- if (!needsRefinement) +- return rewriter.notifyMatchFailure(op, "doesn't need refinement"); +- +- for (auto it : llvm::zip(values, refinedTypes)) { +- // Cannot use structured bindings to simplify this because capturing +- // structured bindings in a lambda is a C++ 20 extension. +- auto value = std::get<0>(it); +- auto refinedType = std::get<1>(it); +- if (value.getType() == refinedType) continue; +- +- // Check whether the users of this value are ready for the type of the +- // value to be refined. +- for (Operation* user : value.getUsers()) { +- // CHLO and StableHLO ops are designed to support type refinements of +- // their operands and results. Any operand type in these ops can change +- // within what's supported by `inferMostSpecificType` without breaking +- // verification of the op. 
+- if (isa(user->getDialect())) +- continue; +- +- // Simply changing operand type of `func.return` won't work because +- // that won't update the FunctionType of the enclosing `func.func`. +- // Nonetheless, we still want to support these ops because they are widely +- // used in StableHLO programs (although the plan of record is to replace +- // `func.return` ops in StableHLO programs with `stablehlo.return`: +- // https://github.com/openxla/stablehlo/issues/425). +- if (isa(user)) continue; +- +- // Unlike in TensorFlow's type inference pass, here we work only with +- // allowlisted ops to focus our support on well-defined semantics of +- // StableHLO programs. +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "unsupported refinement: tried to refine " << value.getType() +- << " to " << refinedType << " for user " << user; +- }); +- } +- +- // Happy path: simply call setType here because most of our users are +- // fine with that. +- auto unrefinedType = value.getType(); +- value.setType(refinedType); +- +- // Special case: for `func.return`, guard the refinement with a cast +- // and leave propagation of the refined return type to a dedicated pattern. +- auto isFuncReturn = [](OpOperand& use) -> bool { +- return isa(use.getOwner()); +- }; +- if (llvm::none_of(value.getUses(), isFuncReturn)) continue; +- rewriter.setInsertionPointAfter(op); +- auto castToUnrefinedType = rewriter.create( +- op->getLoc(), unrefinedType, value); +- value.replaceUsesWithIf(castToUnrefinedType.getOutputs()[0], isFuncReturn); +- } +- +- return success(); +-} +- +-// Refines the return types of the given operation using the given types. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. +-LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, +- ArrayRef types) { +- if (failed(refineValues(rewriter, op, op->getResults(), types))) +- return failure(); +- +- // This `replaceOpWithIf` call doesn't actually change the IR, but +- // it does ask the rewriter to visit all the users of this op. There is no +- // upstream API to achieve this directly, but if it's introduced in the +- // future, we could use it here. +- rewriter.replaceOpWithIf(op, op->getResults(), +- [](OpOperand& use) { return false; }); +- return success(); +-} +- +-// Refines the return types of the given operation using the given types. +-// Tricky implementation details: +-// 1) `types` can include non-shaped types. If there are tuple types, +-// then they are first flattened into non-tuple types using in-order +-// traversal, and only then we apply the refinements. If there are other +-// types, then the corresponding refinements must be completely empty. +-// 2) Encodings are not supported. In principle, TypeExtensions should be +-// supportable, but this needs careful thinking through. Given that no one +-// asked for support for bounded dynamism in this pass yet, this is left +-// for future work. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. 
+-LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, +- ArrayRef refinements) { +- SmallVector flattenedTypes; +- hlo::flattenTupleTypes(op->getResultTypes(), flattenedTypes); +- auto flattenedSize = flattenedTypes.size(); +- if (flattenedSize != refinements.size()) +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "refineReturnTypes failed: expected " << flattenedSize +- << " refinements, got " << refinements.size(); +- }); +- +- SmallVector flattenedRefinedTypes; +- for (auto it : llvm::zip(flattenedTypes, refinements)) { +- // Cannot use structured bindings to simplify this because capturing +- // structured bindings in a lambda is a C++ 20 extension. +- ShapedType currentType = std::get<0>(it).dyn_cast(); +- ShapedTypeComponents refinement = std::get<1>(it); +- auto failWithReason = [&](StringRef reason) { +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "refineTypes failed: refining " << currentType +- << "with refinement: {"; +- if (refinement.hasRank()) { +- diag << "shape = [" << refinement.getDims() << "]"; +- if (refinement.getAttribute()) +- diag << "attribute = " << refinement.getAttribute(); +- } else { +- diag << "hasRank = false"; +- } +- diag << ", elementType = " << refinement.getElementType(); +- diag << "} failed: " << reason; +- }); +- }; +- +- // If the current type is not a shaped type, then the refinement must +- // be completely empty. +- if (!currentType) { +- if (refinement.hasRank() || refinement.getElementType() || +- refinement.getAttribute()) +- return failWithReason("unsupported refinement"); +- flattenedRefinedTypes.push_back(currentType); +- continue; +- } +- +- // If the refinement has an element type, then it must be the same as +- // the current element type. +- Type currentElementType = currentType.getElementType(); +- if (refinement.getElementType() && +- currentElementType != refinement.getElementType()) +- return failWithReason("expected compatible element types"); +- +- // If neither the current type nor the refinement are ranked, then there's +- // nothing to refine, and we return the current type. +- bool hasRank = currentType.hasRank() || refinement.hasRank(); +- if (!hasRank) { +- flattenedRefinedTypes.push_back(currentType); +- continue; +- } +- +- // If either the current type or the refinement have encodings, then +- // we fail. Encodings are left for future work. +- Attribute currentEncoding = nullptr; +- if (auto currentRankedType = currentType.dyn_cast()) { +- currentEncoding = currentRankedType.getEncoding(); +- } +- Attribute refinedEncoding = refinement.getAttribute(); +- if (currentEncoding || refinedEncoding) +- return failWithReason("expected compatible encodings"); +- +- // If both the current type and the refinement have shapes, use the shape +- // from the refinement. Otherwise, pick whatever is available. +- // Make sure that the resulting type is compatible with the current type +- // to avoid creating invalid code. +- auto refinedShape = +- refinement.hasRank() ? 
refinement.getDims() : currentType.getShape(); +- auto refinedType = RankedTensorType::get(refinedShape, currentElementType); +- if (!hlo::isCompatibleForHloTypeInference(currentType, refinedType)) +- return failWithReason("expected compatible shapes"); +- flattenedRefinedTypes.push_back(refinedType); +- } +- +- SmallVector refinedTypes; +- if (failed(hlo::unflattenTupleTypes(op->getResultTypes(), +- flattenedRefinedTypes, refinedTypes))) +- return failure(); +- return refineReturnTypes(rewriter, op, refinedTypes); +-} +- +-// Refines the return type of the given operation using the given shape. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. +-template +-LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, +- ArrayRef shape) { +- return refineReturnTypes(rewriter, op, ShapedTypeComponents(shape)); +-} +- +-// Refines the return type of the given operation using the given shape. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. +-template +-LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, +- Value shapeValue) { +- // At the moment, we only support refining return types using fully static +- // shape values which serves the current use cases well. +- // Support for partially static shape values is left for future work. +- SmallVector shape; +- if (failed(hlo::matchInts(shapeValue, shape))) +- return rewriter.notifyMatchFailure(op, "expected constant output shape"); +- return refineReturnShape(rewriter, op, shape); +-} +- + struct RefineAllGatherOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AllGatherOp op, +@@ -1115,39 +1065,8 @@ + using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; + + void runOnOperation() override { +- // Only one function per module is supported at the moment to avoid the need +- // to think about iterative type inference algorithms. +- // Current use cases are served well by inlining multiple functions into +- // a single function, so we leave native support for multiple functions to +- // future work. +- // To enable modules that contain CustomCallOp::called_computations, +- // we allow multiple functions, in which case we only refine the main +- // function called "main", assuming that the called computations will have +- // static shapes. Lifting this assumption and expanding refinement to +- // multiple functions is left for future work. +- ModuleOp module = getOperation(); +- auto funcs = llvm::to_vector(module.getOps()); +- if (funcs.empty()) return; +- func::FuncOp func; +- if (funcs.size() == 1) { +- func = funcs[0]; +- } else { +- func = module.lookupSymbol("main"); +- } +- if (!func) { +- module.emitOpError() +- << "must have no more than one function or a `main`" +- << " function to clearly identify which function will be refined"; +- return signalPassFailure(); +- } +- +- // Similarly, only one block per function is supported at the moment. +- // At the StableHLO level, functions are expected to only have one block, +- // so supporting more is out of scope for this pass. 
+- if (!func.getRegion().hasOneBlock()) { +- func.emitOpError() << "must have exactly one block"; +- return signalPassFailure(); +- } ++ auto func = getStablehloRefineShapesTarget(getOperation()); ++ if (!func) return signalPassFailure(); + + // The algorithm behind this pass consists of a single traversal of the + // function. This is sufficient because we only support one function per +@@ -1163,44 +1082,7 @@ + config.strictMode = GreedyRewriteStrictness::AnyOp; + + RewritePatternSet patterns(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); ++ populateStablehloRefineShapesPatterns(&patterns, &getContext()); + if (failed( + applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { + return signalPassFailure(); +@@ -1209,5 +1091,86 @@ + }; + + } // namespace ++ ++func::FuncOp getStablehloRefineShapesTarget(ModuleOp module) { ++ // Only one function per module is supported at the moment to avoid the need ++ // to think about iterative type inference algorithms. ++ // Current use cases are served well by inlining multiple functions into ++ // a single function, so we leave native support for multiple functions to ++ // future work. ++ // To enable modules that contain CustomCallOp::called_computations, ++ // we allow multiple functions, in which case we only refine the main ++ // function called "main", assuming that the called computations will have ++ // static shapes. Lifting this assumption and expanding refinement to ++ // multiple functions is left for future work. ++ auto funcs = llvm::to_vector(module.getOps()); ++ if (funcs.empty()) return nullptr; ++ ++ func::FuncOp result; ++ if (funcs.size() == 1) { ++ result = funcs[0]; ++ } else { ++ result = module.lookupSymbol("main"); ++ } ++ if (!result) { ++ module.emitOpError() ++ << "must have no more than one function or a `main`" ++ << " function to clearly identify which function will be refined"; ++ return nullptr; ++ } ++ ++ // Similarly, only one block per function is supported at the moment. ++ // At the StableHLO level, functions are expected to only have one block, ++ // so supporting more is out of scope for this pass. 
++ if (!result.getRegion().hasOneBlock()) { ++ result.emitOpError() << "must have exactly one block"; ++ return nullptr; ++ } ++ ++ return result; ++} ++ ++void populateStablehloRefineShapesPatterns(RewritePatternSet* patterns, ++ MLIRContext* context) { ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++} ++ + } // namespace stablehlo + } // namespace mlir +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.h b/stablehlo/stablehlo/transforms/StablehloRefineShapes.h +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.h ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.h +@@ -0,0 +1,102 @@ ++/* Copyright 2022 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#ifndef STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H ++#define STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H ++ ++#include "llvm/ADT/SmallVector.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/BuiltinOps.h" ++#include "mlir/IR/Operation.h" ++#include "mlir/IR/PatternMatch.h" ++#include "mlir/IR/Types.h" ++#include "mlir/IR/Value.h" ++#include "mlir/Interfaces/InferTypeOpInterface.h" ++#include "mlir/Support/LogicalResult.h" ++#include "stablehlo/dialect/Base.h" ++ ++namespace mlir { ++namespace stablehlo { ++ ++// Gets a FuncOp that --stablehlo-refine-shapes will run on. ++// Returns a nullptr and emits appropriate errors if such a function cannot ++// be obtained from the module. ++func::FuncOp getStablehloRefineShapesTarget(ModuleOp module); ++ ++// Refines the values using the given types. ++// Tricky implementation details: ++// 1) Need to support partial shape refinements, e.g. if just a single ++// dimension size out of an entire tensor type got refined. This is done ++// via inferMostSpecificType. ++// 2) Need to signal propagation of the refined shapes across the ++// StableHLO program. 
Different callers of this function have different ++// propagation needs, so this function doesn't signal anything on its own ++// and leaves that to the callers. ++LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, ++ ValueRange values, TypeRange types); ++ ++// Refines the return types of the given operation using the given types. ++// This function also signals PatternRewriter that it needs to visit all the ++// users of this op if any updates to its results have happened during execution ++// of the function. ++LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, ++ ArrayRef types); ++ ++// Refines the return types of the given operation using the given types. ++// Tricky implementation details: ++// 1) `types` can include non-shaped types. If there are tuple types, ++// then they are first flattened into non-tuple types using in-order ++// traversal, and only then we apply the refinements. If there are other ++// types, then the corresponding refinements must be completely empty. ++// 2) Encodings are not supported. In principle, TypeExtensions should be ++// supportable, but this needs careful thinking through. Given that no one ++// asked for support for bounded dynamism in this pass yet, this is left ++// for future work. ++// This function also signals PatternRewriter that it needs to visit all the ++// users of this op if any updates to its results have happened during execution ++// of the function. ++LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, ++ ArrayRef refinements); ++ ++// Refines the return type of the given operation using the given shape. ++// This function also signals PatternRewriter that it needs to visit all the ++// users of this op if any updates to its results have happened during execution ++// of the function. ++template ++LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, ++ ArrayRef shape) { ++ return refineReturnTypes(rewriter, op, ShapedTypeComponents(shape)); ++} ++ ++// Refines the return type of the given operation using the given shape. ++// This function also signals PatternRewriter that it needs to visit all the ++// users of this op if any updates to its results have happened during execution ++// of the function. ++template ++LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, ++ Value shapeValue) { ++ // At the moment, we only support refining return types using fully static ++ // shape values which serves the current use cases well. ++ // Support for partially static shape values is left for future work. 
++ SmallVector<int64_t> shape; ++ if (failed(hlo::matchInts(shapeValue, shape))) ++ return rewriter.notifyMatchFailure(op, "expected constant output shape"); ++ return refineReturnShape(rewriter, op, shape); ++} ++ ++} // namespace stablehlo ++} // namespace mlir ++ ++#endif // STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H +diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +--- stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp ++++ stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +@@ -430,9 +430,20 @@ + SmallVector& stablehloAttrs) { + auto tensorAttr = dyn_cast(vhloAttr); + if (!tensorAttr) return specialFailure(); +- ArrayRef<int64_t> data( +- reinterpret_cast<const int64_t*>(tensorAttr.getData().data()), +- tensorAttr.getData().size() / sizeof(int64_t)); ++ ++ auto data = ArrayRef<int64_t>( ++ reinterpret_cast<const int64_t*>(tensorAttr.getData().data()), ++ tensorAttr.getData().size() / sizeof(int64_t)) ++ .vec(); ++ ++ // Handle splats ++ if (data.size() == 1) { ++ auto tensorType = tensorAttr.getType().dyn_cast(); ++ if (!tensorType || (tensorType.getShape().size() != 1)) ++ return specialFailure(); ++ auto size = tensorType.getShape()[0]; ++ data.resize(size, data[0]); ++ } + + stablehloAttrs.emplace_back( + vhloName, DenseI64ArrayAttr::get(vhloAttr.getContext(), data)); diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 8d7054dda8b2c0..f175093e925b74 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "04291aea6b50d9573e6f4de184938d83b9564cd0" - STABLEHLO_SHA256 = "2f57b2cb8eeadebe8430e294f88919b392cf472c62fdd40d4713680b283d64e5" + STABLEHLO_COMMIT = "ab709fe48de88c67717abfbd7ef17425eb95ddaf" + STABLEHLO_SHA256 = "a469ecc3d6747f9effdc1c7813568953dd1dc30070ca8f4f6f8a4d405e8c687e" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 6dd0e178ec09b7..9fca8c020bf276 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines.
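For illustration only (not part of this patch): the helpers declared in `StablehloRefineShapes.h` above are meant to be called from the individual rewrite patterns registered by `populateStablehloRefineShapesPatterns`. A minimal sketch of such a pattern follows; it assumes `stablehlo::DynamicIotaOp` and its generated `getOutputShape()` accessor, so treat the names as illustrative rather than as the exact patterns used by the pass.

```
// Illustrative sketch only -- assumes stablehlo::DynamicIotaOp and its
// getOutputShape() accessor; the real pass registers many such patterns.
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/transforms/StablehloRefineShapes.h"

namespace mlir {
namespace stablehlo {

struct RefineDynamicIotaOpPattern : public OpRewritePattern<DynamicIotaOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicIotaOp op,
                                PatternRewriter& rewriter) const override {
    // refineReturnShape matches the constant output_shape operand via
    // hlo::matchInts and forwards it to refineReturnTypes, which also tells
    // the rewriter to revisit the users of the refined op.
    return refineReturnShape(rewriter, op, op.getOutputShape());
  }
};

}  // namespace stablehlo
}  // namespace mlir
```

Registering such a pattern with `patterns->add<RefineDynamicIotaOpPattern>(context)` inside `populateStablehloRefineShapesPatterns` is how the pass would pick it up.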
- TFRT_COMMIT = "e45cd275068c87cbd1d42d0dc89475d72798a9e8" - TFRT_SHA256 = "dd4a1440fdc8bf142c5ac00bd6227e41999a0912b2f847e932b57307f97138dd" + TFRT_COMMIT = "dbd8da33ab49ed8aa5f08ebe85bacb91341f5d61" + TFRT_SHA256 = "b95b1d17eb2e28ee0f00ae672c7377767a17e7dadde169b335aa481bb07883c7" tf_http_archive( name = "tf_runtime", diff --git a/third_party/triton/cl577369732.patch b/third_party/triton/cl577369732.patch deleted file mode 100644 index e63b9f3804974b..00000000000000 --- a/third_party/triton/cl577369732.patch +++ /dev/null @@ -1,116 +0,0 @@ -==== triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp#19 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp 2023-10-19 14:55:11.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -759,7 +759,7 @@ - OpBuilder builder(forOp); - // Get init operands for loop carried values - for (BlockArgument &arg : forOp.getRegionIterArgs()) { -- OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); -+ OpOperand &operand = *forOp.getTiedLoopInit(arg); - setValueMapping(arg, operand.get(), 0); - } - -==== triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp#10 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp 2023-10-19 14:55:11.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -188,7 +188,7 @@ - auto getIncomingOp = [this](Value v) -> Value { - if (auto arg = v.dyn_cast()) - if (arg.getOwner()->getParentOp() == forOp.getOperation()) -- return forOp.getOpOperandForRegionIterArg(arg).get(); -+ return forOp.getTiedLoopInit(arg)->get(); - return Value(); - }; - -@@ -298,10 +298,10 @@ - Operation *firstDot = builder.clone(*dot, mapping); - if (Value a = operand2headPrefetch.lookup(dot.getA())) - firstDot->setOperand( -- 0, newForOp.getRegionIterArgForOpOperand(*a.use_begin())); -+ 0, newForOp.getTiedLoopRegionIterArg(&*a.use_begin())); - if (Value b = operand2headPrefetch.lookup(dot.getB())) - firstDot->setOperand( -- 1, newForOp.getRegionIterArgForOpOperand(*b.use_begin())); -+ 1, newForOp.getTiedLoopRegionIterArg(&*b.use_begin())); - - // remaining part - int64_t kOff = prefetchWidth; -==== triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp#18 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp 2023-10-24 18:31:01.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -245,7 +245,7 @@ - for (OpOperand &use : value.getUses()) { - Operation *user = use.getOwner(); - if (auto forOp = dyn_cast(user)) { -- Value arg = forOp.getRegionIterArgForOpOperand(use); -+ Value arg = forOp.getTiedLoopRegionIterArg(&use); - Value result = forOp.getResultForOpOperand(use); - setEncoding({arg, result}, info, changed, user); - continue; -@@ -767,7 +767,7 @@ - SmallVector newOperands; - for (auto arg : forOp.getRegionIterArgs()) { - if (slice.count(arg)) { -- OpOperand &initVal = 
forOp.getOpOperandForRegionIterArg(arg); -+ OpOperand &initVal = *forOp.getTiedLoopInit(arg); - argMapping.push_back(std::make_pair( - forOp.getResultForOpOperand(initVal).getResultNumber(), - forOp.getInitArgs().size() + newOperands.size())); -==== triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp#16 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2023-10-24 18:31:01.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -430,10 +430,10 @@ - Block *block = blockArg.getOwner(); - Operation *parentOp = block->getParentOp(); - if (auto forOp = dyn_cast(parentOp)) { -- OpOperand &initOperand = forOp.getOpOperandForRegionIterArg(blockArg); -+ OpOperand *initOperand = forOp.getTiedLoopInit(blockArg); - Value yieldOperand = forOp.getBody()->getTerminator()->getOperand( - blockArg.getArgNumber() - forOp.getNumInductionVars()); -- queue.push_back({initOperand.get(), encoding}); -+ queue.push_back({initOperand->get(), encoding}); - queue.push_back({yieldOperand, encoding}); - continue; - } -==== triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp#1 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp 2023-10-12 01:35:16.000000000 -0700 -+++ triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -88,9 +88,8 @@ - auto parentOp = blockArg.getOwner()->getParentOp(); - if (auto forOp = dyn_cast(parentOp)) { - if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { -- if (failed(getDependentPointers( -- forOp.getOpOperandForRegionIterArg(blockArg).get(), -- dependentSet, processedSet))) -+ if (failed(getDependentPointers(forOp.getTiedLoopInit(blockArg)->get(), -+ dependentSet, processedSet))) - return failure(); - - unsigned operandIdx = -@@ -383,7 +382,7 @@ - if (failed(addControlOperandsForForOp(forOp))) - return failure(); - if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { -- Value operand = forOp.getOpOperandForRegionIterArg(blockArg).get(); -+ Value operand = forOp.getTiedLoopInit(blockArg)->get(); - if (failed(tryInsertAndPropagate(operand))) - return failure(); - -==== triton/test/lib/Analysis/TestAlias.cpp#5 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/test/lib/Analysis/TestAlias.cpp ==== -# action=edit type=text ---- triton/test/lib/Analysis/TestAlias.cpp 2023-10-19 14:55:11.000000000 -0700 -+++ triton/test/lib/Analysis/TestAlias.cpp 2023-10-27 20:17:47.000000000 -0700 -@@ -87,7 +87,7 @@ - } - if (auto forOp = dyn_cast(op)) { - for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) { -- auto operand = forOp.getOpOperandForRegionIterArg(arg.value()).get(); -+ auto operand = forOp.getTiedLoopInit(arg.value())->get(); - auto opNames = getAllocOpNames(operand); - auto argName = getValueOperandName(arg.value(), state); - print(argName, opNames, os); diff --git a/third_party/triton/cl577379396.patch b/third_party/triton/cl577379396.patch deleted file mode 100644 index ee569f9b8f55c3..00000000000000 --- a/third_party/triton/cl577379396.patch +++ /dev/null @@ -1,33 +0,0 @@ -diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp 
b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ---- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp -+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp -@@ -246,7 +246,7 @@ SmallVector LayoutPropagation::pr - Operation *user = use.getOwner(); - if (auto forOp = dyn_cast(user)) { - Value arg = forOp.getTiedLoopRegionIterArg(&use); -- Value result = forOp.getResultForOpOperand(use); -+ Value result = forOp.getTiedLoopResult(&use); - setEncoding({arg, result}, info, changed, user); - continue; - } -@@ -769,7 +769,7 @@ static void rewriteSlice(SetVector()) { - auto result = value.cast(); -- OpOperand &forOperand = nestedFor.getOpOperandForResult(result); -+ OpOperand &forOperand = *nestedFor.getTiedLoopInit(result); - markLive(forOperand.get()); - auto nestedYieldOp = - cast(nestedFor.getBody()->getTerminator()); diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index c0c6207f85da73..b864617b503f3e 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -5,8 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl578837341" - TRITON_SHA256 = "0d8112bb31d48b5beadbfc2e13c52770a95d3759b312b15cf26dd72e71410568" + TRITON_COMMIT = "cl588045313" + TRITON_SHA256 = "14cb6ddccc3139b2e8d77af08bb232eb06536d5c715c4bbc720a752af40ba2dc" tf_http_archive( name = "triton", @@ -15,7 +15,7 @@ def repo(): urls = tf_mirror_urls("https://github.com/openxla/triton/archive/{commit}.tar.gz".format(commit = TRITON_COMMIT)), # For temporary changes which haven't landed upstream yet. patch_file = [ - "//third_party/triton:cl568176943.patch", "//third_party/triton:b304456327.patch", + "//third_party/triton:cl568176943.patch", ], ) diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index e9fc2d4eb20a55..9de6b6e0c2bd54 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -526,34 +526,9 @@ build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.16-clang_c build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_config_nccl" test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda_nvcc --config=cuda +build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 -build:rbe_linux_cuda_nvcc --@local_xla//xla/python:enable_gpu=true -build:rbe_linux_cuda_nvcc --@local_xla//xla/python:jax_cuda_pip_rpaths=true -build:rbe_linux_cuda_nvcc --define=xla_python_enable_gpu=true -build:rbe_linux_cuda_nvcc --config=tensorrt -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_75,compute_80" -build:rbe_linux_cuda_nvcc --action_env=TF_CUDA_VERSION="12" -build:rbe_linux_cuda_nvcc --action_env=TF_CUDNN_VERSION="8" -build:rbe_linux_cuda_nvcc --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2" -build:rbe_linux_cuda_nvcc --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" -build:rbe_linux_cuda_nvcc --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc --config=rbe_linux -build:rbe_linux_cuda_nvcc --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc 
--extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_python3.9" -build:rbe_linux_cuda_nvcc --python_path="/usr/bin/python3" -# These you may need to change for your own GCP project. -common:rbe_linux_cuda_nvcc --remote_instance_name=projects/tensorflow-testing/instances/default_instance -build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_cuda" -build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_tensorrt" -build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_nccl" -test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed build:rbe_win --config=rbe_base @@ -692,19 +667,39 @@ build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda build:release_cpu_macos --config=avx_linux test:release_cpu_macos --config=release_base -# Build configs for macOS ARM CPUs +# Base build configs for macOS +build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer +build:release_macos_base --define=no_nccl_support=true --output_filter=^$ + +# Build configs for macOS x86 +build:release_macos_x86 --config=release_macos_base +# Build with the AVX instruction set when on macOS x86 +build:release_macos_x86 --config=avx_linux +build:release_macos_x86 --cpu=darwin +# Target Catalina as the minimum compatible OS version +build:release_macos_x86 --macos_minimum_os=10.15 +build:release_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 + +# Build configs for macOS Arm64 +build:release_macos_arm64 --config=release_macos_base build:release_macos_arm64 --cpu=darwin_arm64 -# Set DEVELOPER_DIR to select a version of Xcode. 
-build:release_macos_arm64 --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer -build:release_macos_arm64 --define=no_nccl_support=true -# Suppress all warning messages -build:release_macos_arm64 --output_filter=^$ -# Disable MKL build:release_macos_arm64 --define=tensorflow_mkldnn_contraction_kernel=0 # Target Moneterey as the minimum compatible OS version build:release_macos_arm64 --macos_minimum_os=12.0 build:release_macos_arm64 --action_env MACOSX_DEPLOYMENT_TARGET=12.0 +# Base test configs for macOS +test:release_macos_base --verbose_failures=true --local_test_jobs=HOST_CPUS +test:release_macos_base --test_timeout=300,450,1200,3600 --test_output=errors +test:release_macos_base --build_tests_only --keep_going +test:release_macos_base --flaky_test_attempts=3 + +# Test configs for macOS x86 +test:release_macos_x86 --config=release_macos_base + +# Test configs for macOS Arm64 +test:release_macos_arm64 --config=release_macos_base + # TODO(kanglan): Update windows configs after b/289091160 is fixed build:release_cpu_windows --config=avx_win build:release_cpu_windows --define=no_tensorflow_py_deps=true @@ -723,10 +718,14 @@ build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compil # Use --config=tf_public_cache to try and use the TensorFlow public build cache # to build TensorFlow. Look at ci/official/envs to find which types of jobs -# push to the cache. +# push to the cache. For macOS, use --config=tf_public_macos_cache build:tf_public_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false # Cache pushes are limited to TF's CI system. build:tf_public_cache_push --config=tf_public_cache --remote_upload_local_results=true --google_default_credentials +# Public cache for macOS builds +build:tf_public_macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false +# Cache pushes are limited to TF's CI system. +build:tf_public_macos_cache_push --config=tf_public_macos_cache --remote_upload_local_results=true --google_default_credentials # END TF CACHE HELPER OPTIONS # BEGIN TF TEST SUITE OPTIONS @@ -743,22 +742,27 @@ build:linux_libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow. test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... 
# CUDA WHEEL -test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... 
-//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_wheel_test_filters --test_lang_filters=py -test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... -//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test +test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +# MACOS X86 WHEEL +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. @@ -766,21 +770,53 @@ test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorf test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... 
-//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA PYCPP: test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 PYCPP test:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... 
-//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test +# CROSS-COMPILE ARM64 PYCPP +test:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test +# Tests that fail only when cross-compiled +test:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py -test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... -//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test +test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium +test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test # END TF TEST SUITE OPTIONS + +# START LINUX AARCH64 CROSS-COMPILE CONFIGS +# Set execution platform to Linux x86 +# Note: Lot of the "host_" flags such as "host_cpu" and "host_crosstool_top" +# flags seem to be actually used to specify the execution platform details. It +# seems it is this way because these flags are old and predate the distinction +# between host and execution platform. 
+build:cross_compile_linux_arm64 --host_cpu=k8 +build:cross_compile_linux_arm64 --host_crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_linux_arm64 --extra_execution_platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_x86_64 + +# Set the target CPU to Aarch64 +build:cross_compile_linux_arm64 --platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_aarch64 +build:cross_compile_linux_arm64 --cpu=aarch64 +build:cross_compile_linux_arm64 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite + +# RBE configs +build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 +build:rbe_cross_compile_linux_arm64 --config=rbe_base +build:rbe_cross_compile_linux_arm64 --remote_instance_name=projects/tensorflow-testing/instances/default_instance + +# Test-related settings below this point +# We cannot run cross-compiled tests on the remote Linux x86 VMs so we need to +# force all tests to run locally on the Aarch64 host. +test:rbe_cross_compile_linux_arm64 --strategy=TestRunner=local +test:rbe_cross_compile_linux_arm64 --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors +test:rbe_cross_compile_linux_arm64 --flaky_test_attempts=3 --build_tests_only +# END LINUX AARCH64 CROSS-COMPILE CONFIGS diff --git a/third_party/xla/.github/workflows/trusted_partners.js b/third_party/xla/.github/workflows/trusted_partners.js index fcb1551059cc73..75a1ff082592b7 100644 --- a/third_party/xla/.github/workflows/trusted_partners.js +++ b/third_party/xla/.github/workflows/trusted_partners.js @@ -53,7 +53,7 @@ const get_email_domain = async ({github, username}) => { const filter_action = async ({github, context, domain}) => { const labels = ['kokoro:force-run']; - let assignees = ['radhakrishnaba', 'xla-rotation']; + let assignees = ['kamaljeeti', 'xla-rotation']; const title = context.payload.pull_request && context.payload.pull_request.title; const lowercased_title = (title || '').toLowerCase(); diff --git a/third_party/xla/.kokoro/jax/build.sh b/third_party/xla/.kokoro/jax/build.sh index 417b515a4b4898..4cfd6d12426b87 100644 --- a/third_party/xla/.kokoro/jax/build.sh +++ b/third_party/xla/.kokoro/jax/build.sh @@ -37,11 +37,12 @@ prelude() { if is_linux_gpu_job ; then export JAX_CUDA_VERSION=12 export JAX_CUDNN_VERSION=8.9 - nvidia-smi + setup_env_vars_py39 + else + setup_env_vars_py312 fi - setup_env_vars_py312 cd "${KOKORO_ARTIFACTS_DIR}" use_local_or_install_python @@ -50,52 +51,49 @@ prelude() { # Install bazel update_bazel_linux - chmod +x "${KOKORO_GFILE_DIR}/bazel_wrapper.py" cd jax } build_and_test_on_rbe_cpu() { # Run the tests. - "${KOKORO_GFILE_DIR}/bazel_wrapper.py" \ + bazel \ test \ --verbose_failures=true \ --override_repository=xla="${KOKORO_ARTIFACTS_DIR}"/github/xla \ --config=avx_posix \ - --config=tpu \ --config=mkl_open_source_only \ - --config="$NOCUDA_RBE_CONFIG_NAME" \ + --config="rbe_cpu_linux_py3.12" \ --config=tensorflow_testing_rbe_linux \ --test_env=JAX_NUM_GENERATED_CASES=25 \ - //tests:cpu_tests //tests:backend_independent_tests \ - --test_output=errors + --test_output=errors \ + -- //tests:cpu_tests //tests:backend_independent_tests } build_and_test_on_rbe_gpu() { # Runs non-multiaccelerator tests with one GPU apiece. # It appears --run_under needs an absolute path. 
- "${KOKORO_GFILE_DIR}/bazel_wrapper.py" \ + + bazel \ test \ --verbose_failures=true \ - //tests:gpu_tests //tests:backend_independent_tests \ --override_repository=xla="${KOKORO_ARTIFACTS_DIR}"/github/xla \ --config=avx_posix \ --config=mkl_open_source_only \ - --config="$CUDA_RBE_CONFIG_NAME" \ + --config="rbe_linux_cuda12.2_nvcc_py3.9" \ + --config=tensorflow_testing_rbe_linux \ --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ --test_output=errors \ --test_env=JAX_SKIP_SLOW_TESTS=1 \ --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ - --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ - --test_tag_filters=-multiaccelerator + --test_env=JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow" \ + --test_tag_filters=-multiaccelerator \ + -- //tests:gpu_tests //tests:backend_independent_tests } # Generate a templated results file to make output accessible to everyone "$KOKORO_ARTIFACTS_DIR"/github/xla/.kokoro/generate_index_html.sh "$KOKORO_ARTIFACTS_DIR"/index.html -NOCUDA_RBE_CONFIG_NAME="rbe_cpu_linux_py312" -CUDA_RBE_CONFIG_NAME="rbe_linux_cuda12.2_nvcc_py3.12" - prelude if is_linux_gpu_job ; then diff --git a/third_party/xla/.kokoro/linux/build.sh b/third_party/xla/.kokoro/linux/build.sh index 49b10b04a899ca..635af61a6d3ed5 100644 --- a/third_party/xla/.kokoro/linux/build.sh +++ b/third_party/xla/.kokoro/linux/build.sh @@ -26,10 +26,6 @@ function is_linux_gpu_job() { [[ "$KOKORO_JOB_NAME" =~ tensorflow/xla/linux/.*gpu.* ]] } -function is_use_nvcc() { - [[ -z "${USE_NVCC:-}" ]] || [[ "$USE_NVCC" == "true" ]] -} - # Pull the container (in case it was updated since the instance started) and # store its SHA in the Sponge log. docker pull "$DOCKER_IMAGE" @@ -54,11 +50,7 @@ if is_linux_gpu_job ; then TAGS_FILTER="$TAGS_FILTER,gpu,requires-gpu-nvidia,-no_gpu" ADDITIONAL_FLAGS="$ADDITIONAL_FLAGS --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute" RC_FILE="/usertools/gpu.bazelrc" - if is_use_nvcc ; then - RBE_CONFIG="rbe_linux_cuda_nvcc" - else - RBE_CONFIG="rbe_linux_cuda" - fi + RBE_CONFIG="rbe_linux_cuda_nvcc" echo "***NOTE: nvidia-smi lists the highest CUDA version the driver supports, which may be different than the version of CUDA actually used!!***" nvidia-smi else diff --git a/third_party/xla/build_tools/lint/BUILD b/third_party/xla/build_tools/lint/BUILD index 8ca1872bb1b064..0270b76421a545 100644 --- a/third_party/xla/build_tools/lint/BUILD +++ b/third_party/xla/build_tools/lint/BUILD @@ -13,7 +13,7 @@ # limitations under the License. 
# ============================================================================ -load("//xla:pytype.default.bzl", "pytype_strict_library") +load("//xla:pytype.default.bzl", "pytype_strict_binary", "pytype_strict_library") # Placeholder: load py_test package( @@ -34,6 +34,11 @@ pytype_strict_library( visibility = ["//visibility:public"], ) +pytype_strict_binary( + name = "generate_compile_commands", + srcs = ["generate_compile_commands.py"], +) + py_test( name = "check_contents_test", srcs = ["check_contents_test.py"], diff --git a/third_party/xla/build_tools/lint/check_contents.py b/third_party/xla/build_tools/lint/check_contents.py index 1649152148d1a4..5d09ec074b3b1e 100644 --- a/third_party/xla/build_tools/lint/check_contents.py +++ b/third_party/xla/build_tools/lint/check_contents.py @@ -22,7 +22,7 @@ import logging # Intended to run on vanilla Github Actions runner import re import sys -from typing import Iterable, Optional, Sequence +from typing import Iterable, Sequence from xla.build_tools.lint import diff_parser @@ -92,7 +92,7 @@ def check_diffs( hunks: Iterable[diff_parser.Hunk], *, prohibited_regex: str, - suppression_regex: Optional[str] = None, # TODO(ddunleavy): CI not on 3.10 + suppression_regex: str | None = None, ) -> list[RegexLocation]: """Checks FileDiffs for prohibited regexes. diff --git a/third_party/xla/build_tools/lint/generate_compile_commands.py b/third_party/xla/build_tools/lint/generate_compile_commands.py new file mode 100644 index 00000000000000..735fc53f8aa8a6 --- /dev/null +++ b/third_party/xla/build_tools/lint/generate_compile_commands.py @@ -0,0 +1,129 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +r"""Produces a `compile_commands.json` from the output of `bazel aquery`. + +This tool requires that a build has been completed for all targets in the +query (e.g., for the example usage below `bazel build //xla/...`). This is due +to generated files like proto headers and files generated via tablegen. So if +LSP or other tools get out of date, it may be necessary to rebuild or regenerate +`compile_commands.json`, or both. + +Example usage: + bazel aquery "mnemonic(CppCompile, //xla/...)" --output=jsonproto | \ + python3 build_tools/lint/generate_compile_commands.py +""" +import dataclasses +import json +import logging +import pathlib +import sys +from typing import Any + +_JSONDict = dict[Any, Any] # Approximates parsed JSON + +_DISALLOWED_ARGS = frozenset(["-fno-canonical-system-headers"]) +_XLA_SRC_ROOT = pathlib.Path(__file__).absolute().parent.parent.parent + + +@dataclasses.dataclass +class CompileCommand: + """Represents a compilation command with options on a specific file.""" + + file: str + arguments: list[str] + + @classmethod + def from_args_list(cls, args_list: list[str]) -> "CompileCommand": + """Alternative constructor which uses the args_list from `bazel aquery`. 
+ + This collects arguments and the file being run on from the output of + `bazel aquery`. Also filters out arguments which break clang-tidy. + + Arguments: + args_list: List of arguments generated by `bazel aquery` + + Returns: + The corresponding ClangTidyCommand. + """ + cc_file = None + filtered_args = [] + + for arg in args_list: + if arg in _DISALLOWED_ARGS: + continue + + if arg.endswith(".cc"): + cc_file = arg + + filtered_args.append(arg) + + return cls(cc_file, filtered_args) + + def to_dumpable_json(self, directory: str) -> _JSONDict: + return { + "directory": directory, + "file": self.file, + "arguments": self.arguments, + } + + +def extract_compile_commands( + parsed_aquery_output: _JSONDict, +) -> list[CompileCommand]: + """Gathers compile commands to run from `bazel aquery` JSON output. + + Arguments: + parsed_aquery_output: Parsed JSON representing the output of `bazel aquery + --output=jsonproto`. + + Returns: + The list of CompileCommands that should be executed. + """ + actions = parsed_aquery_output["actions"] + + commands = [] + for action in actions: + command = CompileCommand.from_args_list(action["arguments"]) + commands.append(command) + return commands + + +def main(): + # Setup logging + logging.basicConfig() + logging.getLogger().setLevel(logging.INFO) + + # Setup external symlink if necessary so headers can be found in include paths + if not (external := _XLA_SRC_ROOT / "external").exists(): + logging.info("Symlinking `xla/bazel-xla/external` to `xla/external`") + external.symlink_to(_XLA_SRC_ROOT / "bazel-xla" / "external") + + logging.info("Reading `bazel aquery` output from stdin...") + parsed_aquery_output = json.loads(sys.stdin.read()) + + commands = extract_compile_commands(parsed_aquery_output) + + with (_XLA_SRC_ROOT / "compile_commands.json").open("w") as f: + json.dump( + [ + command.to_dumpable_json(directory=str(_XLA_SRC_ROOT)) + for command in commands + ], + f, + ) + + +if __name__ == "__main__": + main() diff --git a/third_party/xla/docs/_book.yaml b/third_party/xla/docs/_book.yaml new file mode 100644 index 00000000000000..a6030d45a9949f --- /dev/null +++ b/third_party/xla/docs/_book.yaml @@ -0,0 +1,47 @@ +upper_tabs: +# Tabs left of dropdown menu +- include: /_upper_tabs_left.yaml +- include: /api_docs/_upper_tabs_api.yaml +# Dropdown menu +- name: Resources + path: /resources + is_default: true + menu: + - include: /resources/_menu_toc.yaml + lower_tabs: + # Subsite tabs + other: + - name: Overview + contents: + - heading: OpenXLA + - title: Overview + path: /xla + - title: XLA architecture + path: /xla/architecture + - title: Broadcasting semantics + path: /xla/broadcasting + - title: Develop a new backend for XLA + path: /xla/developing_new_backend + - title: Code Reviews Guide + path: /xla/code_reviews + - title: Operation semantics + path: /xla/operation_semantics + - title: Shapes and layout + path: /xla/shapes + - title: Aliasing + path: /xla/aliasing + - title: Tiled layout + path: /xla/tiled_layout + - title: Writing custom calls + path: /xla/custom_call + - heading: TensorFlow - XLA + - title: Known issues + path: /xla/known_issues + - title: Use AOT compilation + path: /xla/tfcompile + - title: XLA autoclustering + path: /xla/tutorials/autoclustering_xla + - title: Use XLA with tf.function + path: /xla/tutorials/jit_compile + +- include: /_upper_tabs_right.yaml diff --git a/third_party/xla/docs/async_ops.md b/third_party/xla/docs/async_ops.md new file mode 100644 index 00000000000000..889272eecc4411 --- /dev/null +++ 
b/third_party/xla/docs/async_ops.md @@ -0,0 +1,121 @@ +# Async HLO Instructions + +1. Adding async operations to HLO is cumbersome (i.e. `all-reduce-start` and + `all-reduce-done`). +2. The start and done split may be inadequate for some of the asynchronous use + cases. + +To target the first shortcoming, we propose to introduce one last set of new +asynchronous opcodes: `kAsyncStart`, `kAsyncUpdate`, and `kAsyncDone`. The idea +is to create a generic asynchronous opcode that can wrap any HLO instruction. +The actual operation that will be performed asynchronously will be encoded using +a called computation that only has the instruction as its root and any +parameters for inputs. The in-flight input/output buffer handling and aliasing +can then be shared for any asynchronous operation. The async-start instruction’s +output shape will then be a tuple of the input operands, output values, and any +intermediate state that is needed for the `async-update` or `async-done` +instructions. + +``` +%async_op { + %param0 = f32[64] parameter(0) + ROOT %op = f32[32] op(f32[64] %param0), op_specific_attr=”foo” +} + +%async-start = (f32[64], f32[32], s32[]) async-start(f32[64] %operand), + calls=%async_op +%async-done = f32[32] async-done((f32[64], f32[32], s32[]) %async-start), + calls=%async_op +``` + +In the representation above, only `async-start` has a called computation since +it is trivial to find what the `async-done` does by following its operand to +find the corresponding `async-start` to find the called computation. + +Today both `async-start` and `async-done` have a called computation attribute, +but long term we plan to keep it only for `async-start`, since it is trivial +to find what the `async-done` does by following its operand to find the +corresponding `async-start` to find the called computation. + +> [!NOTE] +> Tracked as b/302594825 internally. + +Also note +that the first element in the output tuple of `async-start` aliases with the +operand, so the buffer stays alive until at least the async-done instruction. +Similarly, the second element aliases with the output of `async-done`, and the +third element is the context state that is used to keep track of the +asynchronous operation. This representation also supports multiple tensors in +the asynchronous operation input and/or output and the aliasing works the same +way: + +``` +%async_op { + %param0 = f32[64] parameter(0) + %param1 = f32[64] parameter(1) + ROOT %op = (f32[32], f32[32]) op(f32[64] %param0, f32[64] %param1), + op_specific_attr=”foo” +} + +%async-start = ((f32[64], f32[64]), (f32[32], f32[32]), s32[]) + async-start(f32[64] %operand0, f32[64] %operand1), + calls=%async_op +%async-done = (f32[32], f32[32]) async-done(%async-start) +``` + +In addition, the op can further be decomposed into zero or more `async-update` +steps that perform intermediate computations. 
The input/output aliasing works +the same way with the `async-update` instruction and each `async-start` and +`async-update` instructions must have one user that is either another +`async-update` or an `async-done`: + +``` +%async_op { + %param0 = f32[64] parameter(0) + ROOT %op = f32[32] op(f32[64] %param0), op_specific_attr=”foo” +} + +%async-start = (f32[64], f32[32], s32[]) async-start(f32[64] %operand), + calls=%async_op +%async-update0 = (f32[64], f32[32], s32[]) async-update( + (f32[64], f32[32], s32[]) %async-start) +%async-update1 = (f32[64], f32[32], s32[]) async-update( + (f32[64], f32[32], s32[]) %async-update0) +%async-done = f32[32] async-done((f32[64], f32[32], s32[]) %async-update1) + +``` + +## Syntax sugar + +Since having a separate computation to define the operation that will be +performed asynchronously is a bit cumbersome, we also propose a syntax sugar to +automatically print and parse asynchronous operations as if they are first-class +opcodes. The idea is to treat the “-start”, “-update”, and “-done” suffixes +specially by automatically creating the computation and instruction (without the +suffix) when parsing. For example, the code snippet above can be pretty-printed +to the following and the two can be parsed to the same representation: + +``` +%op-start = (f32[64], f32[32], s32[]) op-start(f32[64] %operand), + op_specific_attr=”foo” +%op-update0 = (f32[64], f32[32], s32[]) op-update( + (f32[64], f32[32], s32[]) %op-start), + op_specific_attr=”foo” +%op-update1 = (f32[64], f32[32], s32[]) op-update( + (f32[64], f32[32], s32[]) %op-update0), + op_specific_attr=”foo” +%op-done = f32[32] op-done((f32[64], f32[32], s32[]) %op-update1), + op_specific_attr=”foo” + +``` + +In order not to create ambiguities, the verifier will not allow an operation to +be wrapped with async-start if we explicitly defined an opcode for that +operation with the “-start” and/or “-done” suffixes. This is also an escape +hatch in case we have any instructions that require HLO-level treatment that +doesn’t fit in the model described above (e.g. the aliasing input/output +buffers). So, initially, `copy-start`/`copy-done`, +`collective-permute-start`/`collective-permute-done` etc. will continue to use +their respective first-class opcodes instead of the new +`async-start`/`async-done` opcodes until we clean up the code to remove these +“-start”/”-done” opcodes. 
diff --git a/third_party/xla/docs/build_from_source.md b/third_party/xla/docs/build_from_source.md index 9c4cc0e401fd37..f5b2ded3c4cd4e 100644 --- a/third_party/xla/docs/build_from_source.md +++ b/third_party/xla/docs/build_from_source.md @@ -33,7 +33,7 @@ We recommend using a suitable docker container to build/test XLA, such as [TensorFlow's docker container](https://www.tensorflow.org/install/docker): ``` -docker run --name xla -w /xla -it -d --rm -v $PWD:/xla tensorflow/build:latest-python3.9 bash +docker run --name xla -w /xla -it -d --rm -v $PWD:/xla tensorflow/tensorflow:latest-gpu bash ``` Using a docker container you can build XLA with CPU support using the following commands: diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files index 9abb2546fa24ed..9de7578a5801a9 100644 --- a/third_party/xla/opensource_only.files +++ b/third_party/xla/opensource_only.files @@ -26,6 +26,8 @@ tools/toolchains/BUILD: tools/toolchains/clang6/BUILD: tools/toolchains/cpus/py/BUILD: tools/toolchains/cpus/py3/BUILD: +tools/toolchains/cross_compile/cc/BUILD: +tools/toolchains/cross_compile/config/BUILD: tools/toolchains/embedded/arm-linux/BUILD: tools/toolchains/java/BUILD: tools/toolchains/python/BUILD: diff --git a/third_party/xla/third_party/cutlass.BUILD b/third_party/xla/third_party/cutlass.BUILD new file mode 100644 index 00000000000000..923d2f044c395a --- /dev/null +++ b/third_party/xla/third_party/cutlass.BUILD @@ -0,0 +1,24 @@ +# Description: +# CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance +# matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # MIT + +exports_files(["LICENSE.txt"]) + +filegroup( + name = "cutlass_header_files", + srcs = glob([ + "include/**", + ]), +) + +cc_library( + name = "cutlass", + hdrs = [":cutlass_header_files"], + strip_include_prefix = "/include", +) diff --git a/third_party/xla/third_party/gloo/BUILD b/third_party/xla/third_party/gloo/BUILD new file mode 100644 index 00000000000000..3c413807167aeb --- /dev/null +++ b/third_party/xla/third_party/gloo/BUILD @@ -0,0 +1 @@ +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/xla/third_party/gloo/gloo.BUILD b/third_party/xla/third_party/gloo/gloo.BUILD new file mode 100644 index 00000000000000..e960fc518a7699 --- /dev/null +++ b/third_party/xla/third_party/gloo/gloo.BUILD @@ -0,0 +1,97 @@ +# Description: +# Gloo is a collective communications library + +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +substitions = { + "@GLOO_VERSION_MAJOR@": "9999", + "@GLOO_VERSION_MINOR@": "0", + "@GLOO_VERSION_PATCH@": "0", + "#cmakedefine01 GLOO_USE_CUDA": "#define GLOO_USE_CUDA 0", + "#cmakedefine01 GLOO_USE_NCCL": "#define GLOO_USE_NCCL 0", + "#cmakedefine01 GLOO_USE_ROCM": "#define GLOO_USE_ROCM 0", + "#cmakedefine01 GLOO_USE_RCCL": "#define GLOO_USE_RCCL 0", + "#cmakedefine01 GLOO_USE_REDIS": "#define GLOO_USE_REDIS 0", + "#cmakedefine01 GLOO_USE_IBVERBS": "#define GLOO_USE_IBVERBS 0", + "#cmakedefine01 GLOO_USE_MPI": "#define GLOO_USE_MPI 0", + "#cmakedefine01 GLOO_USE_LIBUV": "#define GLOO_USE_LIBUV 0", + "#cmakedefine01 GLOO_HAVE_TRANSPORT_TCP": "#define GLOO_HAVE_TRANSPORT_TCP 1", + "#cmakedefine01 
GLOO_HAVE_TRANSPORT_TCP_TLS": "#define GLOO_HAVE_TRANSPORT_TCP_TLS 0", + "#cmakedefine01 GLOO_HAVE_TRANSPORT_IBVERBS": "#define GLOO_HAVE_TRANSPORT_IBVERBS 0", + "#cmakedefine01 GLOO_HAVE_TRANSPORT_UV": "#define GLOO_HAVE_TRANSPORT_UV 0", + "#cmakedefine01 GLOO_USE_AVX": "#define GLOO_USE_AVX __AVX__", +} + +expand_template( + name = "config", + out = "gloo/config.h", + substitutions = substitions, + template = "gloo/config.h.in", +) + +cc_library( + name = "gloo", + srcs = glob( + [ + "gloo/*.cc", + "gloo/common/*.cc", + "gloo/transport/*.cc", + ], + exclude = [ + "gloo/common/linux.cc", + "gloo/common/win.cc", + "gloo/cuda*.cc", + ], + ) + [ + "gloo/rendezvous/context.cc", + "gloo/rendezvous/file_store.cc", + "gloo/rendezvous/hash_store.cc", + "gloo/rendezvous/prefix_store.cc", + "gloo/rendezvous/store.cc", + ] + select({ + "@local_tsl//tsl:macos": [], + "@local_tsl//tsl:windows": [], + "//conditions:default": [ + "gloo/common/linux.cc", + ], + }), + copts = [ + "-fexceptions", + "-Wno-unused-variable", + ], + includes = ["."], + textual_hdrs = glob( + [ + "gloo/*.h", + "gloo/common/*.h", + "gloo/transport/*.h", + ], + exclude = [ + "gloo/cuda*.h", + "gloo/common/win.h", + ], + ) + [ + "gloo/config.h", + "gloo/rendezvous/context.h", + "gloo/rendezvous/file_store.h", + "gloo/rendezvous/hash_store.h", + "gloo/rendezvous/prefix_store.h", + "gloo/rendezvous/store.h", + ], +) + +cc_library( + name = "transport_tcp", + srcs = glob(["gloo/transport/tcp/*.cc"]), + hdrs = glob(["gloo/transport/tcp/*.h"]), + copts = ["-fexceptions"], + deps = [":gloo"], +) diff --git a/third_party/xla/third_party/gloo/workspace.bzl b/third_party/xla/third_party/gloo/workspace.bzl new file mode 100644 index 00000000000000..ede168395acdc5 --- /dev/null +++ b/third_party/xla/third_party/gloo/workspace.bzl @@ -0,0 +1,17 @@ +"""Provides the repository macro to import Gloo.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + """Imports Gloo.""" + + GLOO_COMMIT = "5354032ea08eadd7fc4456477f7f7c6308818509" + GLOO_SHA256 = "5759a06e6c8863c58e8ceadeb56f7c701fec89b2559ba33a103a447207bf69c7" + + tf_http_archive( + name = "gloo", + sha256 = GLOO_SHA256, + strip_prefix = "gloo-{commit}".format(commit = GLOO_COMMIT), + urls = tf_mirror_urls("https://github.com/facebookincubator/gloo/archive/{commit}.tar.gz".format(commit = GLOO_COMMIT)), + build_file = "//third_party/gloo:gloo.BUILD", + ) diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch old mode 100644 new mode 100755 index be1c1f0838e9d7..a476720fd2dbd6 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1,39 +1,14 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel --- stablehlo/BUILD.bazel +++ stablehlo/BUILD.bazel -@@ -279,6 +279,24 @@ - ) - - cc_library( -+ name = "experimental_ops", -+ srcs = [ -+ "stablehlo/dialect/ExperimentalOps.cpp", -+ ], -+ hdrs = [ -+ "stablehlo/dialect/ExperimentalOps.h", -+ ], -+ strip_include_prefix = ".", -+ deps = [ -+ ":stablehlo_ops", -+ "@llvm-project//llvm:Support", -+ "@llvm-project//mlir:FuncDialect", -+ "@llvm-project//mlir:IR", -+ "@llvm-project//mlir:Support", -+ ], -+) -+ -+cc_library( - name = "interpreter_ops", - srcs = [ - "stablehlo/reference/InterpreterOps.cpp", -@@ -780,6 +798,7 @@ +@@ -890,6 +890,7 @@ + hdrs = [ + "stablehlo/transforms/MapStablehloToVhlo.h", + "stablehlo/transforms/Passes.h", ++ 
"stablehlo/transforms/StablehloRefineShapes.h", + ], + strip_include_prefix = ".", deps = [ - ":base", - ":chlo_ops", -+ ":experimental_ops", - ":stablehlo_ops", - ":stablehlo_ops_inc_gen", - ":stablehlo_pass_inc_gen", diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt --- stablehlo/CMakeLists.txt +++ stablehlo/CMakeLists.txt @@ -181,32 +156,198 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt #------------------------------------------------------------------------------- # Directory setup -diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir ---- stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir -+++ stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir -@@ -19,6 +19,7 @@ - func.func @iota_dimension_0() -> tensor<4x8xf32> { - // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() - // CHECK-SAME{LITERAL}: <{value = dense<[[0.000000e+00], [1.000000e+00], [2.000000e+00], [3.000000e+00]]> : tensor<4x1xf32>}> -+ // CHECK-DAG: %[[VAR1:.*]] = tosa.tile %[[VAR0]] {multiples = array} - %0 = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> (tensor<4x8xf32>) - return %0 : tensor<4x8xf32> - } -@@ -27,6 +28,7 @@ - func.func @iota_dimension_1() -> tensor<4x8xi32> { - // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() - // CHECK-SAME{LITERAL}: <{value = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi32>}> -+ // CHECK-DAG: %[[VAR1:.*]] = tosa.tile %[[VAR0]] {multiples = array} - %0 = "stablehlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<4x8xi32>) - return %0 : tensor<4x8xi32> - } -diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp ---- stablehlo/stablehlo/dialect/Base.cpp -+++ stablehlo/stablehlo/dialect/Base.cpp -@@ -600,5 +600,18 @@ - return UnrankedTensorType::get(components.getElementType()); - } +diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt +--- stablehlo/stablehlo/CMakeLists.txt ++++ stablehlo/stablehlo/CMakeLists.txt +@@ -15,6 +15,7 @@ + add_subdirectory(api) + add_subdirectory(conversions) + add_subdirectory(dialect) ++add_subdirectory(experimental) + add_subdirectory(integrations) + add_subdirectory(reference) + add_subdirectory(tests) +diff --ruN a/stablehlo/stablehlo/api/PortableApi.h b/stablehlo/stablehlo/api/PortableApi.h +--- stablehlo/stablehlo/api/PortableApi.h ++++ stablehlo/stablehlo/api/PortableApi.h +@@ -27,7 +27,8 @@ + /// Return the current version for portable API. + /// Increments on all meaningful changes to this file. +-inline int64_t getApiVersion() { return 4; } ++/// Or on large breaking source changes that are difficult to integrate. ++inline int64_t getApiVersion() { return 5; } + + // Get the current StableHLO version. + // +diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel +--- stablehlo/stablehlo/experimental/BUILD.bazel ++++ stablehlo/stablehlo/experimental/BUILD.bazel +@@ -0,0 +1,114 @@ ++# Copyright 2023 The StableHLO Authors. All Rights Reserved. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
++# See the License for the specific language governing permissions and ++# limitations under the License. ++load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") ++ ++package( ++ default_visibility = ["//visibility:public"], ++ licenses = ["notice"], ++) ++ ++cc_library( ++ name = "experimental_base", ++ srcs = [ ++ "dialect/Base.cpp", ++ ], ++ hdrs = [ ++ "dialect/Base.h", ++ ], ++ deps = [ ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ ], ++) ++ ++cc_library( ++ name = "experimental_stablehlo_ops", ++ srcs = [ ++ "dialect/StablehloOps.cpp", ++ ], ++ hdrs = [ ++ "dialect/StablehloOps.h", ++ ], ++ deps = [ ++ ":experimental_base", ++ "//:stablehlo_ops", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "experimental_stablehlo_pass_inc_gen", ++ tbl_outs = [ ++ ( ++ [ ++ "-gen-pass-decls", ++ ], ++ "transforms/Passes.h.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "transforms/Passes.td", ++ deps = ["@llvm-project//mlir:PassBaseTdFiles"], ++) ++ ++cc_library( ++ name = "experimental_stablehlo_passes", ++ srcs = [ ++ "transforms/StablehloCanonicalizeDynamism.cpp", ++ "transforms/StablehloRefineShapes.cpp", ++ ], ++ hdrs = [ ++ "transforms/Passes.h", ++ ], ++ deps = [ ++ ":experimental_stablehlo_ops", ++ ":experimental_stablehlo_pass_inc_gen", ++ "//:base", ++ "//:chlo_ops", ++ "//:stablehlo_ops", ++ "//:stablehlo_ops_inc_gen", ++ "//:stablehlo_passes", ++ "//:stablehlo_type_inference", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:InferTypeOpInterface", ++ "@llvm-project//mlir:Pass", ++ "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:TransformUtils", ++ "@llvm-project//mlir:Transforms", ++ ], ++) ++ ++cc_binary( ++ name = "experimental-stablehlo-opt", ++ srcs = [ ++ "tools/StablehloOptMain.cpp", ++ ], ++ deps = [ ++ ":experimental_stablehlo_passes", ++ "//:interpreter_ops", ++ "//:register", ++ "//:stablehlo_passes", ++ "//:test_utils", ++ "//:tosa_passes", ++ "@llvm-project//mlir:AllExtensions", ++ "@llvm-project//mlir:AllPassesAndDialects", ++ "@llvm-project//mlir:MlirOptLib", ++ "@llvm-project//mlir:TosaDialect", ++ ], ++) +diff --ruN a/stablehlo/stablehlo/experimental/CMakeLists.txt b/stablehlo/stablehlo/experimental/CMakeLists.txt +--- stablehlo/stablehlo/experimental/CMakeLists.txt ++++ stablehlo/stablehlo/experimental/CMakeLists.txt +@@ -0,0 +1,18 @@ ++# Copyright 2023 The StableHLO Authors. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++ ++add_subdirectory(dialect) ++add_subdirectory(tests) ++add_subdirectory(tools) ++add_subdirectory(transforms) +diff --ruN a/stablehlo/stablehlo/experimental/dialect/Base.cpp b/stablehlo/stablehlo/experimental/dialect/Base.cpp +--- stablehlo/stablehlo/experimental/dialect/Base.cpp ++++ stablehlo/stablehlo/experimental/dialect/Base.cpp +@@ -0,0 +1,39 @@ ++/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. ++ Copyright 2022 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#include "stablehlo/experimental/dialect/Base.h" ++ ++#include "mlir/IR/BuiltinAttributes.h" ++#include "mlir/IR/BuiltinTypes.h" ++ ++namespace mlir { ++namespace hlo { ++ +DenseIntElementsAttr getPaddingAttr(MLIRContext* context, + ArrayRef values) { + return DenseIntElementsAttr::get( @@ -220,50 +361,97 @@ diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/ + return getPaddingAttr(builder->getContext(), values); +} + - } // namespace hlo - } // namespace mlir -diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Base.h ---- stablehlo/stablehlo/dialect/Base.h -+++ stablehlo/stablehlo/dialect/Base.h -@@ -194,6 +194,10 @@ - - ShapedType createShapedType(ShapedTypeComponents components); - ++} // namespace hlo ++} // namespace mlir +diff --ruN a/stablehlo/stablehlo/experimental/dialect/Base.h b/stablehlo/stablehlo/experimental/dialect/Base.h +--- stablehlo/stablehlo/experimental/dialect/Base.h ++++ stablehlo/stablehlo/experimental/dialect/Base.h +@@ -0,0 +1,35 @@ ++/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. ++ Copyright 2022 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#ifndef STABLEHLO_EXPERIMENTAL_DIALECT_BASE_H ++#define STABLEHLO_EXPERIMENTAL_DIALECT_BASE_H ++ ++#include "llvm/ADT/ArrayRef.h" ++#include "mlir/IR/Builders.h" ++#include "mlir/IR/BuiltinAttributes.h" ++#include "mlir/IR/MLIRContext.h" ++ ++namespace mlir { ++namespace hlo { ++ +DenseIntElementsAttr getPaddingAttr(MLIRContext *context, + ArrayRef value); +DenseIntElementsAttr getPaddingAttr(Builder *builder, ArrayRef value); + - // This interface is implemented by both StableHLO and MHLO dialects - // and is used as the foundation for sharing verification, type inference and - // prettyprinting logic between them. 
-diff --ruN a/stablehlo/stablehlo/dialect/CMakeLists.txt b/stablehlo/stablehlo/dialect/CMakeLists.txt ---- stablehlo/stablehlo/dialect/CMakeLists.txt -+++ stablehlo/stablehlo/dialect/CMakeLists.txt -@@ -77,6 +77,20 @@ - target_include_directories(ChloOps INTERFACE - $ - $ ++} // namespace hlo ++} // namespace mlir ++ ++#endif // STABLEHLO_EXPERIMENTAL_DIALECT_BASE_H +diff --ruN a/stablehlo/stablehlo/experimental/dialect/CMakeLists.txt b/stablehlo/stablehlo/experimental/dialect/CMakeLists.txt +--- stablehlo/stablehlo/experimental/dialect/CMakeLists.txt ++++ stablehlo/stablehlo/experimental/dialect/CMakeLists.txt +@@ -0,0 +1,42 @@ ++# Copyright 2020 The TensorFlow Authors. All Rights Reserved. ++# Copyright 2023 The StableHLO Authors. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++add_mlir_library(ExperimentalStablehloBase ++ PARTIAL_SOURCES_INTENDED ++ Base.cpp ++ ++ LINK_LIBS PUBLIC ++ MLIRIR +) + -+add_mlir_dialect_library(ExperimentalOps ++add_mlir_dialect_library(ExperimentalStablehloOps + PARTIAL_SOURCES_INTENDED -+ ExperimentalOps.cpp ++ StablehloOps.cpp + + DEPENDS + StablehloOpsIncGen + + LINK_LIBS PUBLIC ++ ExperimentalStablehloBase + MLIRFuncDialect + MLIRIR + MLIRSupport + StablehloOps - ) - - add_mlir_dialect_library(StablehloRegister -diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stablehlo/dialect/ExperimentalOps.cpp ---- stablehlo/stablehlo/dialect/ExperimentalOps.cpp -+++ stablehlo/stablehlo/dialect/ExperimentalOps.cpp -@@ -0,0 +1,504 @@ ++) ++ ++target_include_directories(ExperimentalStablehloOps INTERFACE ++ $ ++ $ ++) +diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp +--- stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp ++++ stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp +@@ -0,0 +1,615 @@ +/* Copyright 2023 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); @@ -279,8 +467,9 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stableh +limitations under the License. +==============================================================================*/ + -+#include "stablehlo/dialect/ExperimentalOps.h" ++#include "stablehlo/experimental/dialect/StablehloOps.h" + ++#include +#include + +#include "llvm/ADT/ArrayRef.h" @@ -293,6 +482,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stableh + +namespace mlir { +namespace stablehlo { ++namespace experimental { + +LogicalResult DynamicReduceWindowOpAdaptor::verify() { + // Before checking the constraints inherited from ReduceWindowOp, @@ -306,8 +496,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stableh + // api_version and backend_config have default values. + // call_target_name should be "stablehlo.dynamic_reduce_window". + // called_computations carries the body. 
-+ if (attr.getName() != "api_version" && -+ attr.getName() != "backend_config" && ++ if (attr.getName() != "api_version" && attr.getName() != "backend_config" && + attr.getName() != "call_target_name" && + attr.getName() != "called_computations") + return op_.emitError() @@ -688,8 +877,8 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stableh + + // dynamic_top_k_i2 + auto kType = k.getType().dyn_cast(); -+ if (!kType || !kType.hasRank() || -+ kType.getRank() != 0 || !kType.getElementType().isIntOrIndex()) ++ if (!kType || !kType.hasRank() || kType.getRank() != 0 || ++ !kType.getElementType().isIntOrIndex()) + return op_.emitError() + << "expects k (operand #1) " + << "to be a 0-dimensional tensor of integer or index type"; @@ -751,7 +940,6 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stableh + return op_.getInputs()[1].cast>(); +} + -+ +TypedValue DynamicTopKOpAdaptor::getValues() { + return op_.getResults()[0].cast>(); +} @@ -760,18 +948,129 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stableh + return op_.getResults()[1].cast>(); +} + -+std::optional getDynamicTopKOp( -+ CustomCallOp op) { ++std::optional getDynamicTopKOp(CustomCallOp op) { + if (op.getCallTargetName() != "stablehlo.dynamic_top_k") return {}; + return DynamicTopKOpAdaptor(op); +} + ++LogicalResult TopKOpAdaptor::verify() { ++ if (op_->getNumOperands() != 1) ++ return op_.emitError("expects size(operands) = 1"); ++ if (op_->getNumResults() != 2) ++ return op_.emitError("expects size(results) = 2"); ++ if (!op_.getBackendConfig().empty()) ++ return op_.emitError() << "expects an empty backend_config"; ++ if (op_.getCallTargetName() != "mhlo.topk") ++ return op_.emitError() << "expects @mhlo.topk"; ++ ++ auto operand = op_.getInputs()[0]; ++ auto values = op_.getResults()[0]; ++ auto indices = op_.getResults()[1]; ++ DictionaryAttr topkAttributes = ++ op_->getAttrOfType("mhlo.attributes"); ++ if (!topkAttributes) { ++ return op_.emitError() ++ << "mhlo.attributes missing or not a dictionary attribute"; ++ } ++ ++ IntegerAttr k_attr = topkAttributes.get("k").dyn_cast_or_null(); ++ if (!k_attr) { ++ return op_.emitError() << "mhlo.attributes.k not present or not an integer"; ++ } ++ int64_t k = k_attr.getInt(); ++ ++ // mhlo.topk_c5 ++ if (k < 0) return op_.emitError() << "expects k >= 0"; ++ ++ // mhlo.topk_i1 ++ auto operandType = operand.getType().dyn_cast(); ++ if (!operandType || !operandType.hasRank() || operandType.getRank() < 1 || ++ !operandType.getElementType().isIntOrFloat()) ++ return op_.emitError() ++ << "expects operand #0 " ++ << "to be a tensor of integer or floating-point type " ++ << "of rank at least 1"; ++ ++ // mhlo.topk_o1 ++ auto valuesType = values.getType().dyn_cast(); ++ if (!valuesType || !valuesType.hasRank() || valuesType.getRank() < 1 || ++ !valuesType.getElementType().isIntOrFloat()) ++ return op_.emitError() ++ << "expects values (result #0) " ++ << "to be a tensor of integer or floating-point type " ++ << "of rank at least 1"; ++ ++ // mhlo.topk_o2 ++ auto indicesType = indices.getType().dyn_cast(); ++ if (!indicesType || !indicesType.hasRank() || indicesType.getRank() < 1 || ++ !indicesType.getElementType().isSignlessInteger(32)) ++ return op_.emitError() << "expects indices (result #1) " ++ << "to be a tensor of si32 of rank at least 1"; ++ ++ // mhlo.topk_c1 && mhlo.topk_c2 ++ auto operandLastDim = operandType.getRank() - 1; ++ SmallVector expectedValuesShape(operandType.getShape()); ++ 
expectedValuesShape[operandLastDim] = k; ++ if (failed(verifyCompatibleShape(expectedValuesShape, valuesType.getShape()))) ++ return op_.emitError() << "expects the values shape to match the operand " ++ "shape in all but the last dimension, and " ++ "that the last dimension of the values shape " ++ "has a size k"; ++ ++ // mhlo.topk_c3 ++ if (valuesType.getElementType() != operandType.getElementType()) ++ return op_.emitError() ++ << "expects the values element type to be the same as the operand " ++ << "element type"; ++ ++ // mhlo.topk_c4 ++ if (failed( ++ verifyCompatibleShape(indicesType.getShape(), valuesType.getShape()))) ++ return op_.emitError() ++ << "expects the indices shape to match the values shape"; ++ ++ return success(); ++} ++ ++TypedValue TopKOpAdaptor::getOperand() { ++ return op_.getInputs()[0].cast>(); ++} ++ ++TypedValue TopKOpAdaptor::getValues() { ++ return op_.getResults()[0].cast>(); ++} ++ ++TypedValue TopKOpAdaptor::getIndices() { ++ return op_.getResults()[1].cast>(); ++} ++ ++int64_t TopKOpAdaptor::getK() { ++ DictionaryAttr topkAttributes = ++ op_->getAttrOfType("mhlo.attributes"); ++ return topkAttributes.get("k").cast().getInt(); ++} ++ ++bool TopKOpAdaptor::getLargest() { ++ DictionaryAttr topkAttributes = ++ op_->getAttrOfType("mhlo.attributes"); ++ IntegerAttr largest = ++ topkAttributes.get("largest").dyn_cast_or_null(); ++ ++ return (!largest) ? true : largest.getInt(); ++} ++ ++std::optional getTopKOp(CustomCallOp op) { ++ if (op.getCallTargetName() != "mhlo.topk") return {}; ++ return TopKOpAdaptor(op); ++} ++ ++} // namespace experimental +} // namespace stablehlo +} // namespace mlir -diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo/dialect/ExperimentalOps.h ---- stablehlo/stablehlo/dialect/ExperimentalOps.h -+++ stablehlo/stablehlo/dialect/ExperimentalOps.h -@@ -0,0 +1,227 @@ +diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.h b/stablehlo/stablehlo/experimental/dialect/StablehloOps.h +--- stablehlo/stablehlo/experimental/dialect/StablehloOps.h ++++ stablehlo/stablehlo/experimental/dialect/StablehloOps.h +@@ -0,0 +1,299 @@ +/* Copyright 2023 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); @@ -787,8 +1086,8 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo +limitations under the License. +==============================================================================*/ + -+#ifndef STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H -+#define STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H ++#ifndef STABLEHLO_EXPERIMENTAL_DIALECT_STABLEHLO_OPS_H ++#define STABLEHLO_EXPERIMENTAL_DIALECT_STABLEHLO_OPS_H + +// This file supports XLA-specific experiments with the StableHLO opset. +// These experiments are not yet ready to be upstreamed to openxla/stablehlo @@ -805,9 +1104,11 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LogicalResult.h" +#include "stablehlo/dialect/StablehloOps.h" ++#include "stablehlo/experimental/dialect/Base.h" + +namespace mlir { +namespace stablehlo { ++namespace experimental { + +// The DynamicReduceWindowOp experiment provides a dynamic version of +// ReduceWindowOp. Once the dynamism RFC is figured out, we expect to have an @@ -995,55 +1296,253 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo +// "stablehlo.dynamic_top_k". 
+std::optional getDynamicTopKOp(CustomCallOp op); + ++/////////////////// ++// MHLO Op Wrappers ++// There are some ops in MHLO which have experimental support in StableHLO ++// programs by representing them as custom_calls with the target `mhlo.op_name`. ++// The level of support of these ops is similar to the other custom_calls in ++// this file. Generally these ops will be added to StableHLO and their ++// experimental support can be deprecated in favor of op's type inference. ++/////////////////// ++ ++// The TopK experiment provides a StableHLO adapter to MHLO TopKOp. ++// In the future we expect stablehlo.top_k to be added which will use the same ++// refinement rules. ++// ++// Within this experiment, TopKOp is represented via the serialized MHLO ++// `stablehlo.custom_call @mhlo.topk` custom call. ++// ++// The semantics of experimental TopKOp are inherited from the semantics of ++// mhlo.topk. ++// ++// #### Inputs ++// ++// | Label | Name | Type | ++// |-------|-----------------|----------------------------------------------| ++// | (I1) | `operand` | tensor of integer or floating-point type | ++// | (I2) | `k` | constant of type si64 | ++// | (I3) | `largest` | constant of type i1 | ++// ++// #### Outputs ++// ++// | Name | Type | ++// |----------------|------------------------------------------| ++// | `values` | tensor of integer or floating-point type | ++// | `indices` | tensor of si32 type | ++// ++// #### Constraints ++// ++// * (C1) `shape(values)[:-1] = shape(operand)[:-1]` ++// * (C2) `shape(values)[-1] = k` ++// * (C3) `element_type(values) = element_type(operand)` ++// * (C4) `shape(indices) = shape(values)` ++// * (C5) `k >= 0` ++// ++class TopKOpAdaptor { ++ public: ++ TopKOpAdaptor(CustomCallOp op) : op_(op) {} ++ operator Operation*() { return op_; } ++ Operation* operator->() { return op_; } ++ ++ // These accessors assume that the operation is well-formed (i.e. that it ++ // can pass verification). ++ TypedValue getOperand(); ++ TypedValue getValues(); ++ TypedValue getIndices(); ++ int64_t getK(); ++ bool getLargest(); ++ ++ // Verifies the constraints documented above. ++ // Emits errors if errors are detected. ++ LogicalResult verify(); ++ ++ private: ++ CustomCallOp op_; ++}; ++ ++// Wraps a custom call in a TopKOpAdaptor. ++// Fails if the call_target_name of the custom call doesn't match ++// "mhlo.topk". 
++std::optional getTopKOp(CustomCallOp op); ++ ++} // namespace experimental +} // namespace stablehlo +} // namespace mlir + -+#endif // STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H -diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp ---- stablehlo/stablehlo/dialect/StablehloOps.cpp -+++ stablehlo/stablehlo/dialect/StablehloOps.cpp -@@ -1543,6 +1543,7 @@ - p << " across dimensions = ["; - llvm::interleaveComma(getDimensions().getValues(), p); - p << "]"; -+ p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"}); - p << " : "; - p.printFunctionalType(*this); - } else { -@@ -1705,6 +1706,7 @@ - if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || - parser.parseEqual() || - parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) || -+ parser.parseOptionalAttrDict(result.attributes) || - parser.parseColon() || parser.parseType(reduceOpFnType) || - parser.parseOptionalLocationSpecifier(explicitLoc)) - return failure(); -diff --ruN a/stablehlo/stablehlo/tests/print_reduce.mlir b/stablehlo/stablehlo/tests/print_reduce.mlir ---- stablehlo/stablehlo/tests/print_reduce.mlir -+++ stablehlo/stablehlo/tests/print_reduce.mlir -@@ -168,3 +168,15 @@ - - func.return %0: tensor<4xf32> - } ++#endif // STABLEHLO_EXPERIMENTAL_DIALECT_STABLEHLO_OPS_H +diff --ruN a/stablehlo/stablehlo/experimental/tests/BUILD.bazel b/stablehlo/stablehlo/experimental/tests/BUILD.bazel +--- stablehlo/stablehlo/experimental/tests/BUILD.bazel ++++ stablehlo/stablehlo/experimental/tests/BUILD.bazel +@@ -0,0 +1,59 @@ ++# Copyright 2023 The StableHLO Authors. All Rights Reserved. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++load("@bazel_skylib//rules:expand_template.bzl", "expand_template") ++load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") ++ ++package( ++ default_visibility = ["//visibility:public"], ++ licenses = ["notice"], ++) + -+// The test case makes sure any custom attrs set on the reduce-op are -+// printed/parsed when pretty-printed. ++# Equivalent of configure_lit_site_cfg from CMakeLists.txt. ++expand_template( ++ name = "lit_site_cfg_py_gen", ++ testonly = True, ++ out = "lit.site.cfg.py", ++ substitutions = { ++ "@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.", ++ "@LLVM_TOOLS_DIR@": package_path("@llvm-project//llvm:BUILD"), ++ "\"@STABLEHLO_TOOLS_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')", ++ "\"@STABLEHLO_SOURCE_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')", ++ }, ++ template = "lit.site.cfg.py.in", ++) + -+// CHECK-LABEL: func @pretty_print_with_custom_attr -+// CHECK: applies stablehlo.add across dimensions = [1] {custom_user_attr = 1 : i64} ++# Equivalent of add_lit_testsuite from CMakeLists.txt. 
++[ ++ lit_test( ++ name = "%s.test" % src, ++ size = "small", ++ srcs = [src], ++ data = [ ++ "lit.cfg.py", ++ "lit.site.cfg.py", ++ "//:stablehlo-opt", ++ "//:stablehlo-translate", ++ "//stablehlo/experimental:experimental-stablehlo-opt", ++ "@llvm-project//llvm:FileCheck", ++ "@llvm-project//llvm:not", ++ ] + glob(["%s.bc" % src]), ++ tags = ["stablehlo_tests"], ++ ) ++ for src in glob(["**/*.mlir"]) ++] ++ ++test_suite( ++ name = "experimental_stablehlo_tests", ++ tags = ["experimental_stablehlo_tests"], ++) +diff --ruN a/stablehlo/stablehlo/experimental/tests/CMakeLists.txt b/stablehlo/stablehlo/experimental/tests/CMakeLists.txt +--- stablehlo/stablehlo/experimental/tests/CMakeLists.txt ++++ stablehlo/stablehlo/experimental/tests/CMakeLists.txt +@@ -0,0 +1,29 @@ ++# Copyright 2020 The TensorFlow Authors. All Rights Reserved. ++# Copyright 2023 The StableHLO Authors. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++configure_lit_site_cfg( ++ ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in ++ ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py ++ MAIN_CONFIG ++ ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py ++) ++add_lit_testsuite(check-experimental-stablehlo-tests "Running the experimental/tests/ suite" ++ ${CMAKE_CURRENT_BINARY_DIR} ++ DEPENDS ++ FileCheck ++ experimental-stablehlo-opt ++ stablehlo-translate ++) ++add_dependencies(check-stablehlo-quick check-experimental-stablehlo-tests) +diff --ruN a/stablehlo/stablehlo/experimental/tests/lit.cfg.py b/stablehlo/stablehlo/experimental/tests/lit.cfg.py +--- stablehlo/stablehlo/experimental/tests/lit.cfg.py ++++ stablehlo/stablehlo/experimental/tests/lit.cfg.py +@@ -0,0 +1,42 @@ ++"""Lit configuration to drive test in this repo.""" ++# Copyright 2020 The TensorFlow Authors. All Rights Reserved. ++# Copyright 2023 The StableHLO Authors. ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++# -*- Python -*- ++# pylint: disable=undefined-variable ++ ++import os ++ ++import lit.formats ++from lit.llvm import llvm_config ++ ++# Populate Lit configuration with the minimal required metadata. ++# Some metadata is populated in lit.site.cfg.py.in. 
++config.name = 'STABLEHLO_TESTS_SUITE' ++config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) ++config.suffixes = ['.mlir'] ++config.test_source_root = os.path.dirname(__file__) ++ ++# Make LLVM and StableHLO tools available in RUN directives ++tools = [ ++ 'FileCheck', ++ 'experimental-stablehlo-opt', ++ 'stablehlo-translate', ++ 'not', ++] ++tool_dirs = [ ++ config.llvm_tools_dir, ++ config.stablehlo_tools_dir, ++] ++llvm_config.add_tool_substitutions(tools, tool_dirs) +diff --ruN a/stablehlo/stablehlo/experimental/tests/lit.site.cfg.py.in b/stablehlo/stablehlo/experimental/tests/lit.site.cfg.py.in +--- stablehlo/stablehlo/experimental/tests/lit.site.cfg.py.in ++++ stablehlo/stablehlo/experimental/tests/lit.site.cfg.py.in +@@ -0,0 +1,21 @@ ++# Copyright 2020 The TensorFlow Authors. All Rights Reserved. ++# Copyright 2023 The StableHLO Authors. ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++@LIT_SITE_CFG_IN_HEADER@ ++ ++import lit.llvm ++lit.llvm.initialize(lit_config, config) ++config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" ++config.stablehlo_tools_dir = "@STABLEHLO_TOOLS_DIR@" ++lit_config.load_config(config, "@STABLEHLO_SOURCE_DIR@" + "/stablehlo/experimental/tests/lit.cfg.py") +diff --ruN a/stablehlo/stablehlo/experimental/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/experimental/tests/stablehlo_canonicalize_dynamism.mlir +--- stablehlo/stablehlo/experimental/tests/stablehlo_canonicalize_dynamism.mlir ++++ stablehlo/stablehlo/experimental/tests/stablehlo_canonicalize_dynamism.mlir +@@ -0,0 +1,344 @@ ++// RUN: experimental-stablehlo-opt --experimental-stablehlo-canonicalize-dynamism --split-input-file --verify-diagnostics %s | FileCheck %s + -+func.func @pretty_print_with_custom_attr(%arg0: tensor<2x64x13xf32>) -> tensor<2x13xf32> { -+ %0 = stablehlo.constant dense<0.000000e+00> : tensor -+ %1 = stablehlo.reduce(%arg0 init: %0) applies stablehlo.add across dimensions = [1] {custom_user_attr = 1 : i64} : (tensor<2x64x13xf32>, tensor) -> tensor<2x13xf32> -+ return %1 : tensor<2x13xf32> -+} -diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir ---- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir -+++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir -@@ -426,6 +426,172 @@ - - // ----- - +// CHECK-LABEL: func @dynamic_reduce_window_success_static_result_type +func.func @dynamic_reduce_window_success_static_result_type(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<2x2xf32> { + // CHECK-NOT: stablehlo.dynamic_reduce_window @@ -1209,17 +1708,6 @@ diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/st +} + +// ----- -+ - // CHECK-LABEL: func @dynamic_reshape_success - func.func @dynamic_reshape_success(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NOT: stablehlo.dynamic_reshape -@@ -452,6 +618,185 @@ - %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> - %1 = stablehlo.dynamic_reshape %arg0, %0 : 
(tensor<4xf32>, tensor<2xi64>) -> tensor<1x?xf32> - return %1 : tensor<1x?xf32> -+} -+ -+// ----- + +// CHECK-LABEL: func @dynamic_rng_bit_generator_success +func.func @dynamic_rng_bit_generator_success(%arg0: tensor<2xui64>) -> tensor<1x4xf32> { @@ -1396,16 +1884,13 @@ diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/st + %k = stablehlo.constant dense<3> : tensor + %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor<3xf32>, tensor<4xi32>) + return %1#0, %1#1 : tensor<3xf32>, tensor<4xi32> - } - - // ----- -diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir ---- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -@@ -607,12 +607,55 @@ - - // ----- - ++} +diff --ruN a/stablehlo/stablehlo/experimental/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/experimental/tests/stablehlo_refine_shapes.mlir +--- stablehlo/stablehlo/experimental/tests/stablehlo_refine_shapes.mlir ++++ stablehlo/stablehlo/experimental/tests/stablehlo_refine_shapes.mlir +@@ -0,0 +1,152 @@ ++// RUN: experimental-stablehlo-opt --experimental-stablehlo-refine-shapes --split-input-file --verify-diagnostics %s | FileCheck %s ++ +// CHECK-LABEL: @main +func.func @main(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<*xf32> { + // CHECK: stablehlo.dynamic_reduce_window{{.*}} -> tensor<2x2xf32> @@ -1426,16 +1911,6 @@ diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/ +} + +// ----- -+ - // CHECK-LABEL: @refine_dynamic_reshape - func.func @refine_dynamic_reshape(%arg0: tensor<4xf32>) -> tensor<*xf32> { - // CHECK: stablehlo.dynamic_reshape{{.*}} -> tensor<1x4xf32> - %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> - %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> -+} -+ -+// ----- + +// CHECK-LABEL: @refine_dynamic_rng_bit_generator +func.func @refine_dynamic_rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor, tensor<*xf32>) { @@ -1455,36 +1930,374 @@ diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/ + %k = stablehlo.constant dense<4> : tensor + %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor, tensor) + return %1#0, %1#1 : tensor, tensor - } - - // ----- -diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td ---- stablehlo/stablehlo/transforms/Passes.td -+++ stablehlo/stablehlo/transforms/Passes.td -@@ -25,6 +25,7 @@ - For example, if the output_shape operand of DynamicReshapeOp is a constant - value, then the operation can be transformed to ReshapeOp. 
- }]; ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_topk ++func.func @refine_mhlo_topk(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // CHECK: mhlo.topk{{.*}} -> (tensor<5x4xf32>, tensor<5x4xi32>) ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_too_many_operands ++func.func @refine_mhlo_error_too_many_operands(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects size(operands) = 1}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0, %arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>, tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_too_few_results ++func.func @refine_mhlo_error_too_few_results(%arg0: tensor<5x16xf32>) -> (tensor) { ++ // expected-error@+1{{expects size(results) = 2}} ++ %0 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor) ++ return %0 : tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_wrong_output_1_type ++func.func @refine_mhlo_error_wrong_output_1_type(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects values (result #0) to be a tensor of integer or floating-point type of rank at least 1}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_wrong_output_2_type ++func.func @refine_mhlo_error_wrong_output_2_type(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects indices (result #1) to be a tensor of si32 of rank at least 1}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_c1_wrong_output_shape ++func.func @refine_mhlo_error_c1_wrong_output_shape(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects the values shape to match the operand}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_c2_last_dim_not_k ++func.func @refine_mhlo_error_c2_last_dim_not_k(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects the values shape to match the operand}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_c3_wrong_output_type ++func.func @refine_mhlo_error_c3_wrong_output_type(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects the values element type to be the same as the operand element type}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// 
----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_c4_outputs_shape_mismatch ++func.func @refine_mhlo_error_c4_outputs_shape_mismatch(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects the indices shape to match the values shape}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = 4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_mhlo_error_c5_negative_k ++func.func @refine_mhlo_error_c5_negative_k(%arg0: tensor<5x16xf32>) -> (tensor, tensor) { ++ // expected-error@+1{{expects k >= 0}} ++ %0:2 = stablehlo.custom_call @mhlo.topk(%arg0) { ++ mhlo.attributes = { k = -4 : i64, largest = true} ++ } : (tensor<5x16xf32>) -> (tensor, tensor) ++ return %0#0, %0#1 : tensor, tensor ++} +diff --ruN a/stablehlo/stablehlo/experimental/tools/CMakeLists.txt b/stablehlo/stablehlo/experimental/tools/CMakeLists.txt +--- stablehlo/stablehlo/experimental/tools/CMakeLists.txt ++++ stablehlo/stablehlo/experimental/tools/CMakeLists.txt +@@ -0,0 +1,41 @@ ++# Copyright 2020 The TensorFlow Authors. All Rights Reserved. ++# Copyright 2023 The StableHLO Authors. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++set(LLVM_OPTIONAL_SOURCES ++ StablehloOptMain.cpp ++) ++ ++# stablehlo-opt ++get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) ++get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) ++get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) ++set(LIBS ++ ${dialect_libs} ++ ${conversion_libs} ++ ${extension_libs} ++ ExperimentalStablehloPasses ++ MLIROptLib ++ StablehloRegister ++ StablehloTestUtils ++ StablehloPasses ++ InterpreterOps ++ StablehloTOSATransforms ++ ) ++add_llvm_executable(experimental-stablehlo-opt StablehloOptMain.cpp) ++llvm_update_compile_flags(experimental-stablehlo-opt) ++target_link_libraries(experimental-stablehlo-opt PRIVATE ${LIBS}) ++ ++mlir_check_all_link_libraries(experimental-stablehlo-opt) ++ +diff --ruN a/stablehlo/stablehlo/experimental/tools/StablehloOptMain.cpp b/stablehlo/stablehlo/experimental/tools/StablehloOptMain.cpp +--- stablehlo/stablehlo/experimental/tools/StablehloOptMain.cpp ++++ stablehlo/stablehlo/experimental/tools/StablehloOptMain.cpp +@@ -0,0 +1,46 @@ ++/* Copyright 2023 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. 
++==============================================================================*/ ++ ++#include "mlir/Dialect/Tosa/IR/TosaOps.h" ++#include "mlir/Dialect/Tosa/Transforms/Passes.h" ++#include "mlir/InitAllDialects.h" ++#include "mlir/InitAllExtensions.h" ++#include "mlir/InitAllPasses.h" ++#include "mlir/Tools/mlir-opt/MlirOptMain.h" ++#include "stablehlo/conversions/tosa/transforms/Passes.h" ++#include "stablehlo/dialect/Register.h" ++#include "stablehlo/experimental/transforms/Passes.h" ++#include "stablehlo/reference/InterpreterOps.h" ++#include "stablehlo/tests/TestUtils.h" ++#include "stablehlo/transforms/Passes.h" ++ ++int main(int argc, char **argv) { ++ mlir::registerAllPasses(); ++ mlir::hlo::registerAllTestPasses(); ++ mlir::stablehlo::registerPassPipelines(); ++ mlir::stablehlo::registerPasses(); ++ mlir::stablehlo::experimental::registerPasses(); ++ mlir::tosa::registerStablehloLegalizeToTosaPassPass(); ++ mlir::tosa::registerStablehloPrepareForTosaPassPass(); ++ ++ mlir::DialectRegistry registry; ++ mlir::registerAllDialects(registry); ++ mlir::registerAllExtensions(registry); ++ mlir::stablehlo::registerAllDialects(registry); ++ registry.insert(); ++ ++ return failed( ++ mlir::MlirOptMain(argc, argv, "Experimental StableHLO optimizer driver\n", registry)); ++} +diff --ruN a/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt b/stablehlo/stablehlo/experimental/transforms/CMakeLists.txt +--- stablehlo/stablehlo/experimental/transforms/CMakeLists.txt ++++ stablehlo/stablehlo/experimental/transforms/CMakeLists.txt +@@ -0,0 +1,39 @@ ++# Copyright 2023 The StableHLO Authors. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++set(LLVM_TARGET_DEFINITIONS Passes.td) ++mlir_tablegen(Passes.h.inc -gen-pass-decls) ++add_public_tablegen_target(ExperimentalPassesIncGen) ++ ++add_mlir_dialect_library(ExperimentalStablehloPasses ++ PARTIAL_SOURCES_INTENDED ++ StablehloCanonicalizeDynamism.cpp ++ StablehloRefineShapes.cpp ++ ++ DEPENDS ++ ExperimentalPassesIncGen ++ ++ LINK_LIBS PUBLIC ++ ChloOps ++ MLIRFuncDialect ++ MLIRIR ++ MLIRInferTypeOpInterface ++ MLIRSupport ++ MLIRTransformUtils ++ ExperimentalStablehloOps ++ StablehloBase ++ StablehloOps ++ StablehloPasses ++ StablehloTypeInference ++) +diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.h b/stablehlo/stablehlo/experimental/transforms/Passes.h +--- stablehlo/stablehlo/experimental/transforms/Passes.h ++++ stablehlo/stablehlo/experimental/transforms/Passes.h +@@ -0,0 +1,37 @@ ++/* Copyright 2023 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#ifndef STABLEHLO_EXPERIMENTAL_TRANSFORMS_PASSES_H ++#define STABLEHLO_EXPERIMENTAL_TRANSFORMS_PASSES_H ++ ++#include ++ ++#include "mlir/Pass/Pass.h" ++#include "mlir/Transforms/DialectConversion.h" ++ ++namespace mlir { ++namespace stablehlo { ++namespace experimental { ++ ++#define GEN_PASS_DECL_STABLEHLOCANONICALIZEDYNAMISMPASS ++#define GEN_PASS_DECL_STABLEHLOREFINESHAPESPASS ++#define GEN_PASS_REGISTRATION ++#include "stablehlo/experimental/transforms/Passes.h.inc" ++ ++} // namespace experimental ++} // namespace stablehlo ++} // namespace mlir ++ ++#endif // STABLEHLO_EXPERIMENTAL_TRANSFORMS_PASSES_H +diff --ruN a/stablehlo/stablehlo/experimental/transforms/Passes.td b/stablehlo/stablehlo/experimental/transforms/Passes.td +--- stablehlo/stablehlo/experimental/transforms/Passes.td ++++ stablehlo/stablehlo/experimental/transforms/Passes.td +@@ -0,0 +1,31 @@ ++/* Copyright 2023 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++include "mlir/Pass/PassBase.td" ++ ++def StablehloCanonicalizeDynamismPass : Pass<"experimental-stablehlo-canonicalize-dynamism", "func::FuncOp"> { ++ let summary = "(Experimental) Canonicalizes dynamic StableHLO ops into static ops."; ++ let description = [{ ++ Experimental version of the --stablehlo-canonicalize-dynamism pass. ++ }]; + let dependentDialects = ["mlir::chlo::ChloDialect"]; - } - - def StablehloLegalizeToVhloPass : Pass<"stablehlo-legalize-to-vhlo", "ModuleOp"> { -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp ---- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp -+++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp -@@ -24,6 +24,8 @@ - #include "mlir/Interfaces/InferTypeOpInterface.h" - #include "mlir/Support/LogicalResult.h" - #include "mlir/Transforms/GreedyPatternRewriteDriver.h" ++} ++ ++def StablehloRefineShapesPass : Pass<"experimental-stablehlo-refine-shapes", "ModuleOp"> { ++ let summary = "(Experimental) Refines shapes across a StableHLO program."; ++ let description = [{ ++ Experimental version of the --stablehlo-refine-shapes pass. ++ }]; ++} +diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp +--- stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp ++++ stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDynamism.cpp +@@ -0,0 +1,167 @@ ++/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. ++ Copyright 2023 The StableHLO Authors. 
++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#include ++ ++#include "llvm/ADT/STLExtras.h" ++#include "llvm/ADT/SmallVector.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/PatternMatch.h" ++#include "mlir/Interfaces/InferTypeOpInterface.h" ++#include "mlir/Support/LogicalResult.h" ++#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/ChloOps.h" -+#include "stablehlo/dialect/ExperimentalOps.h" - #include "stablehlo/dialect/StablehloOps.h" - #include "stablehlo/transforms/Passes.h" - -@@ -198,6 +200,54 @@ - } - }; - ++#include "stablehlo/dialect/StablehloOps.h" ++#include "stablehlo/experimental/dialect/StablehloOps.h" ++#include "stablehlo/experimental/transforms/Passes.h" ++#include "stablehlo/transforms/Passes.h" ++ ++namespace mlir { ++namespace stablehlo { ++namespace experimental { ++ ++#define GEN_PASS_DEF_STABLEHLOCANONICALIZEDYNAMISMPASS ++#include "stablehlo/experimental/transforms/Passes.h.inc" ++ ++namespace { ++ +struct CanonicalizeDynamicReduceWindowOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; @@ -1532,17 +2345,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/ + return success(); + } +}; -+ - struct CanonicalizeDynamicReshapeOpPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; -@@ -210,6 +260,56 @@ - if (!op.getType().hasStaticShape()) - return rewriter.notifyMatchFailure(op, "expected static result type"); - rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); -+ return success(); -+ } -+}; + +struct CanonicalizeDynamicRngBitGeneratorOpPattern + : public OpRewritePattern { @@ -1590,35 +2392,84 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/ + + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getOperand(), k[0]); - return success(); - } - }; -@@ -320,7 +420,10 @@ - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); ++ return success(); ++ } ++}; ++ ++struct StablehloCanonicalizeDynamismPass ++ : public impl::StablehloCanonicalizeDynamismPassBase< ++ StablehloCanonicalizeDynamismPass> { ++ using StablehloCanonicalizeDynamismPassBase:: ++ StablehloCanonicalizeDynamismPassBase; ++ ++ void runOnOperation() override { ++ GreedyRewriteConfig config; ++ config.useTopDownTraversal = true; ++ config.enableRegionSimplification = true; ++ config.maxIterations = 2; ++ config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; ++ config.strictMode = GreedyRewriteStrictness::AnyOp; ++ ++ RewritePatternSet patterns(&getContext()); ++ populateStablehloCanonicalizeDynamismPatterns(&patterns, &getContext()); + patterns.add(&getContext()); - patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); - patterns.add( - &getContext()); - patterns.add(&getContext()); -diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp 
b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -@@ -43,6 +43,7 @@ - #include "mlir/Transforms/GreedyPatternRewriteDriver.h" - #include "stablehlo/dialect/Base.h" - #include "stablehlo/dialect/ChloOps.h" -+#include "stablehlo/dialect/ExperimentalOps.h" - #include "stablehlo/dialect/StablehloOps.h" - #include "stablehlo/dialect/TypeInference.h" - #include "stablehlo/transforms/Passes.h" -@@ -844,12 +845,97 @@ - } - }; - ++ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), ++ config))) { ++ return signalPassFailure(); ++ } ++ } ++}; ++ ++} // namespace ++} // namespace experimental ++} // namespace stablehlo ++} // namespace mlir +diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp +--- stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp ++++ stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp +@@ -0,0 +1,178 @@ ++/* Copyright 2022 The StableHLO Authors. ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#include "stablehlo/transforms/StablehloRefineShapes.h" ++ ++#include ++ ++#include "llvm/ADT/SmallVector.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/PatternMatch.h" ++#include "mlir/Interfaces/InferTypeOpInterface.h" ++#include "mlir/Support/LogicalResult.h" ++#include "mlir/Transforms/GreedyPatternRewriteDriver.h" ++#include "stablehlo/dialect/Base.h" ++#include "stablehlo/dialect/StablehloOps.h" ++#include "stablehlo/dialect/TypeInference.h" ++#include "stablehlo/experimental/dialect/StablehloOps.h" ++#include "stablehlo/experimental/transforms/Passes.h" ++#include "stablehlo/transforms/Passes.h" ++ ++namespace mlir { ++namespace stablehlo { ++namespace experimental { ++ ++#define GEN_PASS_DEF_STABLEHLOREFINESHAPESPASS ++#include "stablehlo/experimental/transforms/Passes.h.inc" ++ ++namespace { ++ +struct RefineDynamicReduceWindowOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; @@ -1660,15 +2511,6 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl + return refineReturnTypes(rewriter, op, inferredReturnTypes); + } +}; -+ - struct RefineDynamicReshapeOpPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(DynamicReshapeOp op, - PatternRewriter& rewriter) const override { - return refineReturnShape(rewriter, op, op.getOutputShape()); -+ } -+}; + +struct RefineDynamicRngBitGeneratorOpPattern + : public OpRewritePattern { @@ -1710,18 +2552,908 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl + + outputShape[operandType.getRank() - 1] = k[0]; + return refineReturnTypes(rewriter, op, {{outputShape}, {outputShape}}); - } - }; - -@@ -1181,7 +1267,10 
@@ - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); ++ } ++}; ++ ++struct RefineTopKOpPattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(CustomCallOp impl, ++ PatternRewriter& rewriter) const override { ++ auto maybeOp = getTopKOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ TopKOpAdaptor op = *maybeOp; ++ ++ auto operandType = op.getOperand().getType().cast(); ++ SmallVector outputShape(operandType.getShape()); ++ outputShape.back() = op.getK(); ++ return refineReturnTypes(rewriter, op, {{outputShape}, {outputShape}}); ++ } ++}; ++ ++struct StablehloRefineShapesPass ++ : public impl::StablehloRefineShapesPassBase { ++ using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; ++ ++ void runOnOperation() override { ++ auto func = getStablehloRefineShapesTarget(getOperation()); ++ if (!func) return signalPassFailure(); ++ ++ // The algorithm behind this pass consists of a single traversal of the ++ // function. This is sufficient because we only support one function per ++ // program at the moment. ++ // TODO(#1048): Find out why .maxIterations = 1 no longer works. ++ // There have been recent refactors to applyPatternsAndFoldGreedily ++ // upstream, and that might be the reason. ++ GreedyRewriteConfig config; ++ config.useTopDownTraversal = true; ++ config.enableRegionSimplification = true; ++ config.maxIterations = 2; ++ config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; ++ config.strictMode = GreedyRewriteStrictness::AnyOp; ++ ++ RewritePatternSet patterns(&getContext()); ++ populateStablehloRefineShapesPatterns(&patterns, &getContext()); + patterns.add(&getContext()); - patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); ++ patterns.add(&getContext()); ++ if (failed( ++ applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { ++ return signalPassFailure(); ++ } ++ } ++}; ++ ++} // namespace ++} // namespace experimental ++} // namespace stablehlo ++} // namespace mlir +diff --ruN a/stablehlo/stablehlo/tests/infer_chlo.mlir b/stablehlo/stablehlo/tests/infer_chlo.mlir +--- stablehlo/stablehlo/tests/infer_chlo.mlir ++++ stablehlo/stablehlo/tests/infer_chlo.mlir +@@ -120,10 +120,10 @@ + // ----- + // CHECK-LABEL: @broadcast_select_reify + func.func @broadcast_select_reify(%arg0: tensor<2xi1>, %arg1: tensor, %arg2: tensor) -> tensor<1xindex> { +- // CHECK: %0 = shape.const_shape [2] : tensor<1xindex> ++ // CHECK: %0 = shape.shape_of %arg0 : tensor<2xi1> -> tensor<1xindex> + // CHECK-NEXT: %1 = shape.shape_of %arg1 : tensor -> tensor<1xindex> + // CHECK-NEXT: %2 = shape.shape_of %arg2 : tensor -> tensor<1xindex> +- // CHECK-NEXT: %3 = shape.broadcast %1, %2, %0 : tensor<1xindex>, tensor<1xindex>, tensor<1xindex> -> tensor<1xindex> ++ // CHECK-NEXT: %3 = shape.broadcast %0, %1, %2 : tensor<1xindex>, tensor<1xindex>, tensor<1xindex> -> tensor<1xindex> + %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor, tensor) -> tensor + %1 = "hlo_test_infer.reify_return_type_shapes"(%0) : (tensor) -> tensor<1xindex> + return %1: tensor<1xindex> +diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h +--- stablehlo/stablehlo/transforms/Passes.h ++++ stablehlo/stablehlo/transforms/Passes.h +@@ -18,9 +18,12 @@ + + #include + ++#include 
"mlir/Dialect/Func/IR/FuncOps.h" + #include "mlir/Dialect/Quant/QuantOps.h" + #include "mlir/Dialect/Shape/IR/Shape.h" ++#include "mlir/IR/BuiltinOps.h" + #include "mlir/Pass/Pass.h" ++#include "mlir/Support/LogicalResult.h" + #include "mlir/Transforms/DialectConversion.h" + + namespace mlir { +@@ -34,6 +37,14 @@ + #define GEN_PASS_DECL_VHLOTOVERSIONPASS + #define GEN_PASS_REGISTRATION + #include "stablehlo/transforms/Passes.h.inc" ++ ++// Populates --stablehlo-canonicalize-dynamism patterns. ++void populateStablehloCanonicalizeDynamismPatterns(RewritePatternSet *patterns, ++ MLIRContext *context); ++ ++// Populates --stablehlo-refine-shapes patterns. ++void populateStablehloRefineShapesPatterns(RewritePatternSet *patterns, ++ MLIRContext *context); + + // Populates StableHLO ops to VHLO ops rewriting patterns. + void populateStablehloToVhloPatterns(RewritePatternSet *patterns, +diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +--- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp ++++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +@@ -307,16 +307,7 @@ + config.strictMode = GreedyRewriteStrictness::AnyOp; + + RewritePatternSet patterns(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add( +- &getContext()); +- patterns.add(&getContext()); ++ populateStablehloCanonicalizeDynamismPatterns(&patterns, &getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + return signalPassFailure(); +@@ -325,5 +316,19 @@ + }; + + } // namespace ++ ++void populateStablehloCanonicalizeDynamismPatterns(RewritePatternSet* patterns, ++ MLIRContext* context) { ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++} ++ + } // namespace stablehlo + } // namespace mlir +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +@@ -11,6 +11,8 @@ + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ ++ ++#include "stablehlo/transforms/StablehloRefineShapes.h" + + #include + #include +@@ -53,6 +55,193 @@ + #define GEN_PASS_DEF_STABLEHLOREFINESHAPESPASS + #include "stablehlo/transforms/Passes.h.inc" + ++LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, ++ ValueRange values, TypeRange types) { ++ if (values.size() != types.size()) ++ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { ++ diag << "refineValues failed for " << types << ": expected " ++ << values.size() << " types, got " << types.size(); ++ }); ++ ++ // Check whether `types` contain any new information with respect to existing ++ // return types. Even if just a single dimension size out of an entire tensor ++ // type got updated, using `inferMostSpecificType` ensures that we don't ++ // miss that. 
++ bool needsRefinement = false; ++ SmallVector refinedTypes; ++ for (auto it : llvm::zip(values.getTypes(), types)) { ++ // Cannot use structured bindings to simplify this because capturing ++ // structured bindings in a lambda is a C++ 20 extension. ++ auto currentType = std::get<0>(it); ++ auto refinement = std::get<1>(it); ++ auto refinedType = hlo::inferMostSpecificType( ++ /*location=*/{}, {currentType, refinement}); ++ if (failed(refinedType)) ++ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { ++ diag << "inferMostSpecificType failed for " << currentType << " and " ++ << refinement; ++ }); ++ refinedTypes.push_back(*refinedType); ++ needsRefinement |= (currentType != *refinedType); ++ } ++ if (!needsRefinement) ++ return rewriter.notifyMatchFailure(op, "doesn't need refinement"); ++ ++ for (auto it : llvm::zip(values, refinedTypes)) { ++ // Cannot use structured bindings to simplify this because capturing ++ // structured bindings in a lambda is a C++ 20 extension. ++ auto value = std::get<0>(it); ++ auto refinedType = std::get<1>(it); ++ if (value.getType() == refinedType) continue; ++ ++ // Check whether the users of this value are ready for the type of the ++ // value to be refined. ++ for (Operation* user : value.getUsers()) { ++ // CHLO and StableHLO ops are designed to support type refinements of ++ // their operands and results. Any operand type in these ops can change ++ // within what's supported by `inferMostSpecificType` without breaking ++ // verification of the op. ++ if (isa(user->getDialect())) ++ continue; ++ ++ // Simply changing operand type of `func.return` won't work because ++ // that won't update the FunctionType of the enclosing `func.func`. ++ // Nonetheless, we still want to support these ops because they are widely ++ // used in StableHLO programs (although the plan of record is to replace ++ // `func.return` ops in StableHLO programs with `stablehlo.return`: ++ // https://github.com/openxla/stablehlo/issues/425). ++ if (isa(user)) continue; ++ ++ // Unlike in TensorFlow's type inference pass, here we work only with ++ // allowlisted ops to focus our support on well-defined semantics of ++ // StableHLO programs. ++ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { ++ diag << "unsupported refinement: tried to refine " << value.getType() ++ << " to " << refinedType << " for user " << user; ++ }); ++ } ++ ++ // Happy path: simply call setType here because most of our users are ++ // fine with that. ++ auto unrefinedType = value.getType(); ++ value.setType(refinedType); ++ ++ // Special case: for `func.return`, guard the refinement with a cast ++ // and leave propagation of the refined return type to a dedicated pattern. ++ auto isFuncReturn = [](OpOperand& use) -> bool { ++ return isa(use.getOwner()); ++ }; ++ if (llvm::none_of(value.getUses(), isFuncReturn)) continue; ++ rewriter.setInsertionPointAfter(op); ++ auto castToUnrefinedType = rewriter.create( ++ op->getLoc(), unrefinedType, value); ++ value.replaceUsesWithIf(castToUnrefinedType.getOutputs()[0], isFuncReturn); ++ } ++ ++ return success(); ++} ++ ++LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, ++ ArrayRef types) { ++ if (failed(refineValues(rewriter, op, op->getResults(), types))) ++ return failure(); ++ ++ // This `replaceOpWithIf` call doesn't actually change the IR, but ++ // it does ask the rewriter to visit all the users of this op. 
There is no ++ // upstream API to achieve this directly, but if it's introduced in the ++ // future, we could use it here. ++ rewriter.replaceOpWithIf(op, op->getResults(), ++ [](OpOperand& use) { return false; }); ++ return success(); ++} ++ ++LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, ++ ArrayRef refinements) { ++ SmallVector flattenedTypes; ++ hlo::flattenTupleTypes(op->getResultTypes(), flattenedTypes); ++ auto flattenedSize = flattenedTypes.size(); ++ if (flattenedSize != refinements.size()) ++ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { ++ diag << "refineReturnTypes failed: expected " << flattenedSize ++ << " refinements, got " << refinements.size(); ++ }); ++ ++ SmallVector flattenedRefinedTypes; ++ for (auto it : llvm::zip(flattenedTypes, refinements)) { ++ // Cannot use structured bindings to simplify this because capturing ++ // structured bindings in a lambda is a C++ 20 extension. ++ ShapedType currentType = std::get<0>(it).dyn_cast(); ++ ShapedTypeComponents refinement = std::get<1>(it); ++ auto failWithReason = [&](StringRef reason) { ++ return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { ++ diag << "refineTypes failed: refining " << currentType ++ << "with refinement: {"; ++ if (refinement.hasRank()) { ++ diag << "shape = [" << refinement.getDims() << "]"; ++ if (refinement.getAttribute()) ++ diag << "attribute = " << refinement.getAttribute(); ++ } else { ++ diag << "hasRank = false"; ++ } ++ diag << ", elementType = " << refinement.getElementType(); ++ diag << "} failed: " << reason; ++ }); ++ }; ++ ++ // If the current type is not a shaped type, then the refinement must ++ // be completely empty. ++ if (!currentType) { ++ if (refinement.hasRank() || refinement.getElementType() || ++ refinement.getAttribute()) ++ return failWithReason("unsupported refinement"); ++ flattenedRefinedTypes.push_back(currentType); ++ continue; ++ } ++ ++ // If the refinement has an element type, then it must be the same as ++ // the current element type. ++ Type currentElementType = currentType.getElementType(); ++ if (refinement.getElementType() && ++ currentElementType != refinement.getElementType()) ++ return failWithReason("expected compatible element types"); ++ ++ // If neither the current type nor the refinement are ranked, then there's ++ // nothing to refine, and we return the current type. ++ bool hasRank = currentType.hasRank() || refinement.hasRank(); ++ if (!hasRank) { ++ flattenedRefinedTypes.push_back(currentType); ++ continue; ++ } ++ ++ // If either the current type or the refinement have encodings, then ++ // we fail. Encodings are left for future work. ++ Attribute currentEncoding = nullptr; ++ if (auto currentRankedType = currentType.dyn_cast()) { ++ currentEncoding = currentRankedType.getEncoding(); ++ } ++ Attribute refinedEncoding = refinement.getAttribute(); ++ if (currentEncoding || refinedEncoding) ++ return failWithReason("expected compatible encodings"); ++ ++ // If both the current type and the refinement have shapes, use the shape ++ // from the refinement. Otherwise, pick whatever is available. ++ // Make sure that the resulting type is compatible with the current type ++ // to avoid creating invalid code. ++ auto refinedShape = ++ refinement.hasRank() ? 
refinement.getDims() : currentType.getShape(); ++ auto refinedType = RankedTensorType::get(refinedShape, currentElementType); ++ if (!hlo::isCompatibleForHloTypeInference(currentType, refinedType)) ++ return failWithReason("expected compatible shapes"); ++ flattenedRefinedTypes.push_back(refinedType); ++ } ++ ++ SmallVector refinedTypes; ++ if (failed(hlo::unflattenTupleTypes(op->getResultTypes(), ++ flattenedRefinedTypes, refinedTypes))) ++ return failure(); ++ return refineReturnTypes(rewriter, op, refinedTypes); ++} ++ + namespace { + + // DenseElementsAttr can be constructed from ArrayRef but not from +@@ -422,245 +611,6 @@ + // StableHLO-specific extension to refine return types based on potentially + // refined operands. + +-// Refines the values using the given types. +-// Tricky implementation details: +-// 1) Need to support partial shape refinements, e.g. if just a single +-// dimension size out of an entire tensor type got refined. This is done +-// via inferMostSpecificType. +-// 2) Need to signal propagation of the refined shapes across the +-// StableHLO program. Different callers of this function have different +-// propagation needs, so this function doesn't signal anything on its own +-// and leaves that to the callers. +-LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, +- ValueRange values, TypeRange types) { +- if (values.size() != types.size()) +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "refineValues failed for " << types << ": expected " +- << values.size() << " types, got " << types.size(); +- }); +- +- // Check whether `types` contain any new information with respect to existing +- // return types. Even if just a single dimension size out of an entire tensor +- // type got updated, using `inferMostSpecificType` ensures that we don't +- // miss that. +- bool needsRefinement = false; +- SmallVector refinedTypes; +- for (auto it : llvm::zip(values.getTypes(), types)) { +- // Cannot use structured bindings to simplify this because capturing +- // structured bindings in a lambda is a C++ 20 extension. +- auto currentType = std::get<0>(it); +- auto refinement = std::get<1>(it); +- auto refinedType = hlo::inferMostSpecificType( +- /*location=*/{}, {currentType, refinement}); +- if (failed(refinedType)) +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "inferMostSpecificType failed for " << currentType << " and " +- << refinement; +- }); +- refinedTypes.push_back(*refinedType); +- needsRefinement |= (currentType != *refinedType); +- } +- if (!needsRefinement) +- return rewriter.notifyMatchFailure(op, "doesn't need refinement"); +- +- for (auto it : llvm::zip(values, refinedTypes)) { +- // Cannot use structured bindings to simplify this because capturing +- // structured bindings in a lambda is a C++ 20 extension. +- auto value = std::get<0>(it); +- auto refinedType = std::get<1>(it); +- if (value.getType() == refinedType) continue; +- +- // Check whether the users of this value are ready for the type of the +- // value to be refined. +- for (Operation* user : value.getUsers()) { +- // CHLO and StableHLO ops are designed to support type refinements of +- // their operands and results. Any operand type in these ops can change +- // within what's supported by `inferMostSpecificType` without breaking +- // verification of the op. 
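The user check in refineValues (present in the new copy added earlier and in the old copy being deleted around this point) is effectively a three-way decision: CHLO/StableHLO users tolerate the refined operand type directly, func.return is guarded with a cast until a dedicated pattern updates the enclosing FunctionType, and any other user blocks the refinement. A minimal standalone sketch of that decision, with hypothetical names and plain strings standing in for dialects and ops (this is not the patch's code):

#include <string>

enum class RefinementAction { kAllow, kGuardWithCast, kReject };

// Classifies a user of a value whose type is about to be refined.
RefinementAction classifyUser(const std::string& dialect,
                              const std::string& opName) {
  // CHLO and StableHLO ops still verify when an operand becomes more static.
  if (dialect == "chlo" || dialect == "stablehlo")
    return RefinementAction::kAllow;
  // func.return needs a cast guard so the FunctionType stays consistent until
  // a dedicated pattern propagates the refined return type.
  if (dialect == "func" && opName == "return")
    return RefinementAction::kGuardWithCast;
  // Everything else is outside the allowlist and blocks the refinement.
  return RefinementAction::kReject;
}

int main() {
  bool ok =
      classifyUser("stablehlo", "add") == RefinementAction::kAllow &&
      classifyUser("func", "return") == RefinementAction::kGuardWithCast &&
      classifyUser("scf", "for") == RefinementAction::kReject;
  return ok ? 0 : 1;
}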
+- if (isa(user->getDialect())) +- continue; +- +- // Simply changing operand type of `func.return` won't work because +- // that won't update the FunctionType of the enclosing `func.func`. +- // Nonetheless, we still want to support these ops because they are widely +- // used in StableHLO programs (although the plan of record is to replace +- // `func.return` ops in StableHLO programs with `stablehlo.return`: +- // https://github.com/openxla/stablehlo/issues/425). +- if (isa(user)) continue; +- +- // Unlike in TensorFlow's type inference pass, here we work only with +- // allowlisted ops to focus our support on well-defined semantics of +- // StableHLO programs. +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "unsupported refinement: tried to refine " << value.getType() +- << " to " << refinedType << " for user " << user; +- }); +- } +- +- // Happy path: simply call setType here because most of our users are +- // fine with that. +- auto unrefinedType = value.getType(); +- value.setType(refinedType); +- +- // Special case: for `func.return`, guard the refinement with a cast +- // and leave propagation of the refined return type to a dedicated pattern. +- auto isFuncReturn = [](OpOperand& use) -> bool { +- return isa(use.getOwner()); +- }; +- if (llvm::none_of(value.getUses(), isFuncReturn)) continue; +- rewriter.setInsertionPointAfter(op); +- auto castToUnrefinedType = rewriter.create( +- op->getLoc(), unrefinedType, value); +- value.replaceUsesWithIf(castToUnrefinedType.getOutputs()[0], isFuncReturn); +- } +- +- return success(); +-} +- +-// Refines the return types of the given operation using the given types. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. +-LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, +- ArrayRef types) { +- if (failed(refineValues(rewriter, op, op->getResults(), types))) +- return failure(); +- +- // This `replaceOpWithIf` call doesn't actually change the IR, but +- // it does ask the rewriter to visit all the users of this op. There is no +- // upstream API to achieve this directly, but if it's introduced in the +- // future, we could use it here. +- rewriter.replaceOpWithIf(op, op->getResults(), +- [](OpOperand& use) { return false; }); +- return success(); +-} +- +-// Refines the return types of the given operation using the given types. +-// Tricky implementation details: +-// 1) `types` can include non-shaped types. If there are tuple types, +-// then they are first flattened into non-tuple types using in-order +-// traversal, and only then we apply the refinements. If there are other +-// types, then the corresponding refinements must be completely empty. +-// 2) Encodings are not supported. In principle, TypeExtensions should be +-// supportable, but this needs careful thinking through. Given that no one +-// asked for support for bounded dynamism in this pass yet, this is left +-// for future work. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. 
+-LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, +- ArrayRef refinements) { +- SmallVector flattenedTypes; +- hlo::flattenTupleTypes(op->getResultTypes(), flattenedTypes); +- auto flattenedSize = flattenedTypes.size(); +- if (flattenedSize != refinements.size()) +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "refineReturnTypes failed: expected " << flattenedSize +- << " refinements, got " << refinements.size(); +- }); +- +- SmallVector flattenedRefinedTypes; +- for (auto it : llvm::zip(flattenedTypes, refinements)) { +- // Cannot use structured bindings to simplify this because capturing +- // structured bindings in a lambda is a C++ 20 extension. +- ShapedType currentType = std::get<0>(it).dyn_cast(); +- ShapedTypeComponents refinement = std::get<1>(it); +- auto failWithReason = [&](StringRef reason) { +- return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { +- diag << "refineTypes failed: refining " << currentType +- << "with refinement: {"; +- if (refinement.hasRank()) { +- diag << "shape = [" << refinement.getDims() << "]"; +- if (refinement.getAttribute()) +- diag << "attribute = " << refinement.getAttribute(); +- } else { +- diag << "hasRank = false"; +- } +- diag << ", elementType = " << refinement.getElementType(); +- diag << "} failed: " << reason; +- }); +- }; +- +- // If the current type is not a shaped type, then the refinement must +- // be completely empty. +- if (!currentType) { +- if (refinement.hasRank() || refinement.getElementType() || +- refinement.getAttribute()) +- return failWithReason("unsupported refinement"); +- flattenedRefinedTypes.push_back(currentType); +- continue; +- } +- +- // If the refinement has an element type, then it must be the same as +- // the current element type. +- Type currentElementType = currentType.getElementType(); +- if (refinement.getElementType() && +- currentElementType != refinement.getElementType()) +- return failWithReason("expected compatible element types"); +- +- // If neither the current type nor the refinement are ranked, then there's +- // nothing to refine, and we return the current type. +- bool hasRank = currentType.hasRank() || refinement.hasRank(); +- if (!hasRank) { +- flattenedRefinedTypes.push_back(currentType); +- continue; +- } +- +- // If either the current type or the refinement have encodings, then +- // we fail. Encodings are left for future work. +- Attribute currentEncoding = nullptr; +- if (auto currentRankedType = currentType.dyn_cast()) { +- currentEncoding = currentRankedType.getEncoding(); +- } +- Attribute refinedEncoding = refinement.getAttribute(); +- if (currentEncoding || refinedEncoding) +- return failWithReason("expected compatible encodings"); +- +- // If both the current type and the refinement have shapes, use the shape +- // from the refinement. Otherwise, pick whatever is available. +- // Make sure that the resulting type is compatible with the current type +- // to avoid creating invalid code. +- auto refinedShape = +- refinement.hasRank() ? 
refinement.getDims() : currentType.getShape(); +- auto refinedType = RankedTensorType::get(refinedShape, currentElementType); +- if (!hlo::isCompatibleForHloTypeInference(currentType, refinedType)) +- return failWithReason("expected compatible shapes"); +- flattenedRefinedTypes.push_back(refinedType); +- } +- +- SmallVector refinedTypes; +- if (failed(hlo::unflattenTupleTypes(op->getResultTypes(), +- flattenedRefinedTypes, refinedTypes))) +- return failure(); +- return refineReturnTypes(rewriter, op, refinedTypes); +-} +- +-// Refines the return type of the given operation using the given shape. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. +-template +-LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, +- ArrayRef shape) { +- return refineReturnTypes(rewriter, op, ShapedTypeComponents(shape)); +-} +- +-// Refines the return type of the given operation using the given shape. +-// This function also signals PatternRewriter that it needs to visit all the +-// users of this op if any updates to its results have happened during execution +-// of the function. +-template +-LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, +- Value shapeValue) { +- // At the moment, we only support refining return types using fully static +- // shape values which serves the current use cases well. +- // Support for partially static shape values is left for future work. +- SmallVector shape; +- if (failed(hlo::matchInts(shapeValue, shape))) +- return rewriter.notifyMatchFailure(op, "expected constant output shape"); +- return refineReturnShape(rewriter, op, shape); +-} +- + struct RefineAllGatherOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AllGatherOp op, +@@ -1115,39 +1065,8 @@ + using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; + + void runOnOperation() override { +- // Only one function per module is supported at the moment to avoid the need +- // to think about iterative type inference algorithms. +- // Current use cases are served well by inlining multiple functions into +- // a single function, so we leave native support for multiple functions to +- // future work. +- // To enable modules that contain CustomCallOp::called_computations, +- // we allow multiple functions, in which case we only refine the main +- // function called "main", assuming that the called computations will have +- // static shapes. Lifting this assumption and expanding refinement to +- // multiple functions is left for future work. +- ModuleOp module = getOperation(); +- auto funcs = llvm::to_vector(module.getOps()); +- if (funcs.empty()) return; +- func::FuncOp func; +- if (funcs.size() == 1) { +- func = funcs[0]; +- } else { +- func = module.lookupSymbol("main"); +- } +- if (!func) { +- module.emitOpError() +- << "must have no more than one function or a `main`" +- << " function to clearly identify which function will be refined"; +- return signalPassFailure(); +- } +- +- // Similarly, only one block per function is supported at the moment. +- // At the StableHLO level, functions are expected to only have one block, +- // so supporting more is out of scope for this pass. 
+- if (!func.getRegion().hasOneBlock()) { +- func.emitOpError() << "must have exactly one block"; +- return signalPassFailure(); +- } ++ auto func = getStablehloRefineShapesTarget(getOperation()); ++ if (!func) return signalPassFailure(); + + // The algorithm behind this pass consists of a single traversal of the + // function. This is sufficient because we only support one function per +@@ -1163,44 +1082,7 @@ + config.strictMode = GreedyRewriteStrictness::AnyOp; + + RewritePatternSet patterns(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); +- patterns.add(&getContext()); ++ populateStablehloRefineShapesPatterns(&patterns, &getContext()); + if (failed( + applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { + return signalPassFailure(); +@@ -1209,5 +1091,86 @@ + }; + + } // namespace ++ ++func::FuncOp getStablehloRefineShapesTarget(ModuleOp module) { ++ // Only one function per module is supported at the moment to avoid the need ++ // to think about iterative type inference algorithms. ++ // Current use cases are served well by inlining multiple functions into ++ // a single function, so we leave native support for multiple functions to ++ // future work. ++ // To enable modules that contain CustomCallOp::called_computations, ++ // we allow multiple functions, in which case we only refine the main ++ // function called "main", assuming that the called computations will have ++ // static shapes. Lifting this assumption and expanding refinement to ++ // multiple functions is left for future work. ++ auto funcs = llvm::to_vector(module.getOps()); ++ if (funcs.empty()) return nullptr; ++ ++ func::FuncOp result; ++ if (funcs.size() == 1) { ++ result = funcs[0]; ++ } else { ++ result = module.lookupSymbol("main"); ++ } ++ if (!result) { ++ module.emitOpError() ++ << "must have no more than one function or a `main`" ++ << " function to clearly identify which function will be refined"; ++ return nullptr; ++ } ++ ++ // Similarly, only one block per function is supported at the moment. ++ // At the StableHLO level, functions are expected to only have one block, ++ // so supporting more is out of scope for this pass. 
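For reference, the function-selection rule that the new getStablehloRefineShapesTarget helper spells out above (no functions: nothing to refine; exactly one function: refine it; several functions: refine only `main`, otherwise report an error) can be summarized in a few lines. This is an illustrative, MLIR-free sketch with made-up names, not code from the patch:

#include <optional>
#include <string>
#include <vector>

// Picks the function that shape refinement should run on, or std::nullopt if
// the module is empty or ambiguous (the caller then emits the error).
std::optional<std::string> pickRefinementTarget(
    const std::vector<std::string>& funcNames) {
  if (funcNames.empty()) return std::nullopt;      // nothing to refine
  if (funcNames.size() == 1) return funcNames[0];  // the only function wins
  for (const std::string& name : funcNames)        // otherwise require "main"
    if (name == "main") return name;
  return std::nullopt;  // ambiguous module
}

int main() {
  bool ok = pickRefinementTarget({"f"}).value_or("") == "f" &&
            pickRefinementTarget({"main", "helper"}).value_or("") == "main" &&
            !pickRefinementTarget({"a", "b"}).has_value();
  return ok ? 0 : 1;
}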
++ if (!result.getRegion().hasOneBlock()) { ++ result.emitOpError() << "must have exactly one block"; ++ return nullptr; ++ } ++ ++ return result; ++} ++ ++void populateStablehloRefineShapesPatterns(RewritePatternSet* patterns, ++ MLIRContext* context) { ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++ patterns->add(context); ++} ++ + } // namespace stablehlo + } // namespace mlir +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.h b/stablehlo/stablehlo/transforms/StablehloRefineShapes.h +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.h ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.h +@@ -0,0 +1,102 @@ ++/* Copyright 2022 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ ++#ifndef STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H ++#define STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H ++ ++#include "llvm/ADT/SmallVector.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/BuiltinOps.h" ++#include "mlir/IR/Operation.h" ++#include "mlir/IR/PatternMatch.h" ++#include "mlir/IR/Types.h" ++#include "mlir/IR/Value.h" ++#include "mlir/Interfaces/InferTypeOpInterface.h" ++#include "mlir/Support/LogicalResult.h" ++#include "stablehlo/dialect/Base.h" ++ ++namespace mlir { ++namespace stablehlo { ++ ++// Gets a FuncOp that --stablehlo-refine-shapes will run on. ++// Returns a nullptr and emits appropriate errors if such a function cannot ++// be obtained from the module. ++func::FuncOp getStablehloRefineShapesTarget(ModuleOp module); ++ ++// Refines the values using the given types. ++// Tricky implementation details: ++// 1) Need to support partial shape refinements, e.g. if just a single ++// dimension size out of an entire tensor type got refined. This is done ++// via inferMostSpecificType. ++// 2) Need to signal propagation of the refined shapes across the ++// StableHLO program. 
Different callers of this function have different ++// propagation needs, so this function doesn't signal anything on its own ++// and leaves that to the callers. ++LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, ++ ValueRange values, TypeRange types); ++ ++// Refines the return types of the given operation using the given types. ++// This function also signals PatternRewriter that it needs to visit all the ++// users of this op if any updates to its results have happened during execution ++// of the function. ++LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, ++ ArrayRef types); ++ ++// Refines the return types of the given operation using the given types. ++// Tricky implementation details: ++// 1) `types` can include non-shaped types. If there are tuple types, ++// then they are first flattened into non-tuple types using in-order ++// traversal, and only then we apply the refinements. If there are other ++// types, then the corresponding refinements must be completely empty. ++// 2) Encodings are not supported. In principle, TypeExtensions should be ++// supportable, but this needs careful thinking through. Given that no one ++// asked for support for bounded dynamism in this pass yet, this is left ++// for future work. ++// This function also signals PatternRewriter that it needs to visit all the ++// users of this op if any updates to its results have happened during execution ++// of the function. ++LogicalResult refineReturnTypes(PatternRewriter& rewriter, Operation* op, ++ ArrayRef refinements); ++ ++// Refines the return type of the given operation using the given shape. ++// This function also signals PatternRewriter that it needs to visit all the ++// users of this op if any updates to its results have happened during execution ++// of the function. ++template ++LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, ++ ArrayRef shape) { ++ return refineReturnTypes(rewriter, op, ShapedTypeComponents(shape)); ++} ++ ++// Refines the return type of the given operation using the given shape. ++// This function also signals PatternRewriter that it needs to visit all the ++// users of this op if any updates to its results have happened during execution ++// of the function. ++template ++LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op, ++ Value shapeValue) { ++ // At the moment, we only support refining return types using fully static ++ // shape values which serves the current use cases well. ++ // Support for partially static shape values is left for future work. 
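As the comment directly above notes, the Value-based refineReturnShape overload only fires when the shape operand can be matched to fully known integer constants (via hlo::matchInts). A small self-contained sketch of that contract, using std::optional in place of SSA values and hypothetical names throughout (this is an illustration, not the StableHLO API):

#include <cstdint>
#include <optional>
#include <vector>

using MaybeDim = std::optional<int64_t>;

// Returns the fully static shape, or std::nullopt if any dimension is not a
// known constant ("expected constant output shape" in the real pass).
std::optional<std::vector<int64_t>> matchConstantShape(
    const std::vector<MaybeDim>& dims) {
  std::vector<int64_t> shape;
  shape.reserve(dims.size());
  for (const MaybeDim& d : dims) {
    if (!d) return std::nullopt;  // one unknown dimension aborts the match
    shape.push_back(*d);
  }
  return shape;
}

int main() {
  // {2, 3} is fully constant, so the refinement can proceed.
  auto ok = matchConstantShape({2, 3});
  // {2, ?} has an unknown dimension, so the pattern would not fire.
  auto notOk = matchConstantShape({2, std::nullopt});
  return (ok && !notOk) ? 0 : 1;
}

In other words, a single unknown dimension leaves the op's return type untouched rather than producing a partially refined shape.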
++ SmallVector shape; ++ if (failed(hlo::matchInts(shapeValue, shape))) ++ return rewriter.notifyMatchFailure(op, "expected constant output shape"); ++ return refineReturnShape(rewriter, op, shape); ++} ++ ++} // namespace stablehlo ++} // namespace mlir ++ ++#endif // STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H +diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +--- stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp ++++ stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +@@ -430,9 +430,20 @@ + SmallVector& stablehloAttrs) { + auto tensorAttr = dyn_cast(vhloAttr); + if (!tensorAttr) return specialFailure(); +- ArrayRef data( +- reinterpret_cast(tensorAttr.getData().data()), +- tensorAttr.getData().size() / sizeof(int64_t)); ++ ++ auto data = ArrayRef( ++ reinterpret_cast(tensorAttr.getData().data()), ++ tensorAttr.getData().size() / sizeof(int64_t)) ++ .vec(); ++ ++ // Handle splats ++ if (data.size() == 1) { ++ auto tensorType = tensorAttr.getType().dyn_cast(); ++ if (!tensorType || (tensorType.getShape().size() != 1)) ++ return specialFailure(); ++ auto size = tensorType.getShape()[0]; ++ data.resize(size, data[0]); ++ } + + stablehloAttrs.emplace_back( + vhloName, DenseI64ArrayAttr::get(vhloAttr.getContext(), data)); diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index 8d7054dda8b2c0..f175093e925b74 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "04291aea6b50d9573e6f4de184938d83b9564cd0" - STABLEHLO_SHA256 = "2f57b2cb8eeadebe8430e294f88919b392cf472c62fdd40d4713680b283d64e5" + STABLEHLO_COMMIT = "ab709fe48de88c67717abfbd7ef17425eb95ddaf" + STABLEHLO_SHA256 = "a469ecc3d6747f9effdc1c7813568953dd1dc30070ca8f4f6f8a4d405e8c687e" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/triton/cl577369732.patch b/third_party/xla/third_party/triton/cl577369732.patch deleted file mode 100644 index e63b9f3804974b..00000000000000 --- a/third_party/xla/third_party/triton/cl577369732.patch +++ /dev/null @@ -1,116 +0,0 @@ -==== triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp#19 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp 2023-10-19 14:55:11.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -759,7 +759,7 @@ - OpBuilder builder(forOp); - // Get init operands for loop carried values - for (BlockArgument &arg : forOp.getRegionIterArgs()) { -- OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); -+ OpOperand &operand = *forOp.getTiedLoopInit(arg); - setValueMapping(arg, operand.get(), 0); - } - -==== triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp#10 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp 2023-10-19 14:55:11.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -188,7 +188,7 @@ - 
auto getIncomingOp = [this](Value v) -> Value { - if (auto arg = v.dyn_cast()) - if (arg.getOwner()->getParentOp() == forOp.getOperation()) -- return forOp.getOpOperandForRegionIterArg(arg).get(); -+ return forOp.getTiedLoopInit(arg)->get(); - return Value(); - }; - -@@ -298,10 +298,10 @@ - Operation *firstDot = builder.clone(*dot, mapping); - if (Value a = operand2headPrefetch.lookup(dot.getA())) - firstDot->setOperand( -- 0, newForOp.getRegionIterArgForOpOperand(*a.use_begin())); -+ 0, newForOp.getTiedLoopRegionIterArg(&*a.use_begin())); - if (Value b = operand2headPrefetch.lookup(dot.getB())) - firstDot->setOperand( -- 1, newForOp.getRegionIterArgForOpOperand(*b.use_begin())); -+ 1, newForOp.getTiedLoopRegionIterArg(&*b.use_begin())); - - // remaining part - int64_t kOff = prefetchWidth; -==== triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp#18 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp 2023-10-24 18:31:01.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -245,7 +245,7 @@ - for (OpOperand &use : value.getUses()) { - Operation *user = use.getOwner(); - if (auto forOp = dyn_cast(user)) { -- Value arg = forOp.getRegionIterArgForOpOperand(use); -+ Value arg = forOp.getTiedLoopRegionIterArg(&use); - Value result = forOp.getResultForOpOperand(use); - setEncoding({arg, result}, info, changed, user); - continue; -@@ -767,7 +767,7 @@ - SmallVector newOperands; - for (auto arg : forOp.getRegionIterArgs()) { - if (slice.count(arg)) { -- OpOperand &initVal = forOp.getOpOperandForRegionIterArg(arg); -+ OpOperand &initVal = *forOp.getTiedLoopInit(arg); - argMapping.push_back(std::make_pair( - forOp.getResultForOpOperand(initVal).getResultNumber(), - forOp.getInitArgs().size() + newOperands.size())); -==== triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp#16 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2023-10-24 18:31:01.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -430,10 +430,10 @@ - Block *block = blockArg.getOwner(); - Operation *parentOp = block->getParentOp(); - if (auto forOp = dyn_cast(parentOp)) { -- OpOperand &initOperand = forOp.getOpOperandForRegionIterArg(blockArg); -+ OpOperand *initOperand = forOp.getTiedLoopInit(blockArg); - Value yieldOperand = forOp.getBody()->getTerminator()->getOperand( - blockArg.getArgNumber() - forOp.getNumInductionVars()); -- queue.push_back({initOperand.get(), encoding}); -+ queue.push_back({initOperand->get(), encoding}); - queue.push_back({yieldOperand, encoding}); - continue; - } -==== triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp#1 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp 2023-10-12 01:35:16.000000000 -0700 -+++ triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -88,9 +88,8 @@ - auto parentOp = blockArg.getOwner()->getParentOp(); - if (auto forOp = 
dyn_cast(parentOp)) { - if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { -- if (failed(getDependentPointers( -- forOp.getOpOperandForRegionIterArg(blockArg).get(), -- dependentSet, processedSet))) -+ if (failed(getDependentPointers(forOp.getTiedLoopInit(blockArg)->get(), -+ dependentSet, processedSet))) - return failure(); - - unsigned operandIdx = -@@ -383,7 +382,7 @@ - if (failed(addControlOperandsForForOp(forOp))) - return failure(); - if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { -- Value operand = forOp.getOpOperandForRegionIterArg(blockArg).get(); -+ Value operand = forOp.getTiedLoopInit(blockArg)->get(); - if (failed(tryInsertAndPropagate(operand))) - return failure(); - -==== triton/test/lib/Analysis/TestAlias.cpp#5 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/test/lib/Analysis/TestAlias.cpp ==== -# action=edit type=text ---- triton/test/lib/Analysis/TestAlias.cpp 2023-10-19 14:55:11.000000000 -0700 -+++ triton/test/lib/Analysis/TestAlias.cpp 2023-10-27 20:17:47.000000000 -0700 -@@ -87,7 +87,7 @@ - } - if (auto forOp = dyn_cast(op)) { - for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) { -- auto operand = forOp.getOpOperandForRegionIterArg(arg.value()).get(); -+ auto operand = forOp.getTiedLoopInit(arg.value())->get(); - auto opNames = getAllocOpNames(operand); - auto argName = getValueOperandName(arg.value(), state); - print(argName, opNames, os); diff --git a/third_party/xla/third_party/triton/cl577379396.patch b/third_party/xla/third_party/triton/cl577379396.patch deleted file mode 100644 index ee569f9b8f55c3..00000000000000 --- a/third_party/xla/third_party/triton/cl577379396.patch +++ /dev/null @@ -1,33 +0,0 @@ -diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ---- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp -+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp -@@ -246,7 +246,7 @@ SmallVector LayoutPropagation::pr - Operation *user = use.getOwner(); - if (auto forOp = dyn_cast(user)) { - Value arg = forOp.getTiedLoopRegionIterArg(&use); -- Value result = forOp.getResultForOpOperand(use); -+ Value result = forOp.getTiedLoopResult(&use); - setEncoding({arg, result}, info, changed, user); - continue; - } -@@ -769,7 +769,7 @@ static void rewriteSlice(SetVector()) { - auto result = value.cast(); -- OpOperand &forOperand = nestedFor.getOpOperandForResult(result); -+ OpOperand &forOperand = *nestedFor.getTiedLoopInit(result); - markLive(forOperand.get()); - auto nestedYieldOp = - cast(nestedFor.getBody()->getTerminator()); diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index c0c6207f85da73..b864617b503f3e 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -5,8 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl578837341" - TRITON_SHA256 = "0d8112bb31d48b5beadbfc2e13c52770a95d3759b312b15cf26dd72e71410568" + TRITON_COMMIT = "cl588045313" + TRITON_SHA256 = "14cb6ddccc3139b2e8d77af08bb232eb06536d5c715c4bbc720a752af40ba2dc" tf_http_archive( name = "triton", @@ -15,7 +15,7 @@ def repo(): urls = tf_mirror_urls("https://github.com/openxla/triton/archive/{commit}.tar.gz".format(commit = TRITON_COMMIT)), # For temporary changes which haven't landed upstream yet. 
patch_file = [ - "//third_party/triton:cl568176943.patch", "//third_party/triton:b304456327.patch", + "//third_party/triton:cl568176943.patch", ], ) diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index e9fc2d4eb20a55..9de6b6e0c2bd54 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -526,34 +526,9 @@ build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.16-clang_c build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_config_nccl" test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda_nvcc --config=cuda +build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 -build:rbe_linux_cuda_nvcc --@local_xla//xla/python:enable_gpu=true -build:rbe_linux_cuda_nvcc --@local_xla//xla/python:jax_cuda_pip_rpaths=true -build:rbe_linux_cuda_nvcc --define=xla_python_enable_gpu=true -build:rbe_linux_cuda_nvcc --config=tensorrt -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_75,compute_80" -build:rbe_linux_cuda_nvcc --action_env=TF_CUDA_VERSION="12" -build:rbe_linux_cuda_nvcc --action_env=TF_CUDNN_VERSION="8" -build:rbe_linux_cuda_nvcc --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2" -build:rbe_linux_cuda_nvcc --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" -build:rbe_linux_cuda_nvcc --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc --config=rbe_linux -build:rbe_linux_cuda_nvcc --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_platform//:platform" -build:rbe_linux_cuda_nvcc --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_python3.9" -build:rbe_linux_cuda_nvcc --python_path="/usr/bin/python3" -# These you may need to change for your own GCP project. 
-common:rbe_linux_cuda_nvcc --remote_instance_name=projects/tensorflow-testing/instances/default_instance -build:rbe_linux_cuda_nvcc --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_cuda" -build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_tensorrt" -build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda12.2-cudnn8.9_config_nccl" -test:rbe_linux_cuda_nvcc --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed build:rbe_win --config=rbe_base @@ -692,19 +667,39 @@ build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda build:release_cpu_macos --config=avx_linux test:release_cpu_macos --config=release_base -# Build configs for macOS ARM CPUs +# Base build configs for macOS +build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer +build:release_macos_base --define=no_nccl_support=true --output_filter=^$ + +# Build configs for macOS x86 +build:release_macos_x86 --config=release_macos_base +# Build with the AVX instruction set when on macOS x86 +build:release_macos_x86 --config=avx_linux +build:release_macos_x86 --cpu=darwin +# Target Catalina as the minimum compatible OS version +build:release_macos_x86 --macos_minimum_os=10.15 +build:release_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 + +# Build configs for macOS Arm64 +build:release_macos_arm64 --config=release_macos_base build:release_macos_arm64 --cpu=darwin_arm64 -# Set DEVELOPER_DIR to select a version of Xcode. -build:release_macos_arm64 --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer -build:release_macos_arm64 --define=no_nccl_support=true -# Suppress all warning messages -build:release_macos_arm64 --output_filter=^$ -# Disable MKL build:release_macos_arm64 --define=tensorflow_mkldnn_contraction_kernel=0 # Target Moneterey as the minimum compatible OS version build:release_macos_arm64 --macos_minimum_os=12.0 build:release_macos_arm64 --action_env MACOSX_DEPLOYMENT_TARGET=12.0 +# Base test configs for macOS +test:release_macos_base --verbose_failures=true --local_test_jobs=HOST_CPUS +test:release_macos_base --test_timeout=300,450,1200,3600 --test_output=errors +test:release_macos_base --build_tests_only --keep_going +test:release_macos_base --flaky_test_attempts=3 + +# Test configs for macOS x86 +test:release_macos_x86 --config=release_macos_base + +# Test configs for macOS Arm64 +test:release_macos_arm64 --config=release_macos_base + # TODO(kanglan): Update windows configs after b/289091160 is fixed build:release_cpu_windows --config=avx_win build:release_cpu_windows --define=no_tensorflow_py_deps=true @@ -723,10 +718,14 @@ build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compil # Use --config=tf_public_cache to try and use the TensorFlow public build cache # to build TensorFlow. Look at ci/official/envs to find which types of jobs -# push to the cache. +# push to the cache. For macOS, use --config=tf_public_macos_cache build:tf_public_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false # Cache pushes are limited to TF's CI system. 
build:tf_public_cache_push --config=tf_public_cache --remote_upload_local_results=true --google_default_credentials +# Public cache for macOS builds +build:tf_public_macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false +# Cache pushes are limited to TF's CI system. +build:tf_public_macos_cache_push --config=tf_public_macos_cache --remote_upload_local_results=true --google_default_credentials # END TF CACHE HELPER OPTIONS # BEGIN TF TEST SUITE OPTIONS @@ -743,22 +742,27 @@ build:linux_libtensorflow_build -- //tensorflow/tools/lib_package:libtensorflow. test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL -test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... 
-//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_wheel_test_filters --test_lang_filters=py -test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... -//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test +test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +# MACOS X86 WHEEL +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. 
These are usually run continuously or upon presubmit. @@ -766,21 +770,53 @@ test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorf test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA PYCPP: test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 PYCPP test:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... 
-//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test +# CROSS-COMPILE ARM64 PYCPP +test:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test +# Tests that fail only when cross-compiled +test:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py -test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... -//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test +test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium +test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test # END TF TEST SUITE OPTIONS + +# START LINUX AARCH64 CROSS-COMPILE CONFIGS +# Set execution platform to Linux x86 +# Note: Lot of the "host_" flags such as "host_cpu" and "host_crosstool_top" +# flags seem to be actually used to specify the execution platform details. It +# seems it is this way because these flags are old and predate the distinction +# between host and execution platform. 
+build:cross_compile_linux_arm64 --host_cpu=k8 +build:cross_compile_linux_arm64 --host_crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_linux_arm64 --extra_execution_platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_x86_64 + +# Set the target CPU to Aarch64 +build:cross_compile_linux_arm64 --platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_aarch64 +build:cross_compile_linux_arm64 --cpu=aarch64 +build:cross_compile_linux_arm64 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite + +# RBE configs +build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 +build:rbe_cross_compile_linux_arm64 --config=rbe_base +build:rbe_cross_compile_linux_arm64 --remote_instance_name=projects/tensorflow-testing/instances/default_instance + +# Test-related settings below this point +# We cannot run cross-compiled tests on the remote Linux x86 VMs so we need to +# force all tests to run locally on the Aarch64 host. +test:rbe_cross_compile_linux_arm64 --strategy=TestRunner=local +test:rbe_cross_compile_linux_arm64 --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors +test:rbe_cross_compile_linux_arm64 --flaky_test_attempts=3 --build_tests_only +# END LINUX AARCH64 CROSS-COMPILE CONFIGS diff --git a/third_party/xla/third_party/tsl/.kokoro/windows/windows_build.sh b/third_party/xla/third_party/tsl/.kokoro/windows/windows_build.sh index 331efa186fb87e..4f4b0a0fdf9d31 100644 --- a/third_party/xla/third_party/tsl/.kokoro/windows/windows_build.sh +++ b/third_party/xla/third_party/tsl/.kokoro/windows/windows_build.sh @@ -50,7 +50,7 @@ export PATH="$PATH:/c/Python38" -- //tsl/... \ || { echo "Bazel Build Failed" && exit 1; } -# Test TSL TODO(ddunleavy) enable all tests +# Test TSL /c/tools/bazel.exe test \ --output_filter="" \ --flaky_test_attempts=3 \ @@ -60,7 +60,7 @@ export PATH="$PATH:/c/Python38" --build_tag_filters=$TAGS_FILTER \ --test_tag_filters=$TAGS_FILTER \ --keep_going \ - -- //tsl/... -//tsl/platform:subprocess_test -//tsl/platform/cloud:google_auth_provider_test -//tsl/platform/cloud:oauth_client_test \ + -- //tsl/... 
\ || { echo "Bazel Test Failed" && exit 1; } exit 0 diff --git a/third_party/xla/third_party/tsl/opensource_only.files b/third_party/xla/third_party/tsl/opensource_only.files index e4974e79805725..fa84f35768a5d2 100644 --- a/third_party/xla/third_party/tsl/opensource_only.files +++ b/third_party/xla/third_party/tsl/opensource_only.files @@ -29,7 +29,9 @@ third_party/gpus/crosstool/BUILD: third_party/gpus/crosstool/LICENSE: third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl: third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl: +third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl: third_party/gpus/cuda/BUILD.tpl: +third_party/gpus/cuda/BUILD.windows.tpl: third_party/gpus/cuda/BUILD: third_party/gpus/cuda/LICENSE: third_party/gpus/cuda/build_defs.bzl.tpl: @@ -129,6 +131,8 @@ tools/toolchains/BUILD: tools/toolchains/clang6/BUILD: tools/toolchains/cpus/py/BUILD: tools/toolchains/cpus/py3/BUILD: +tools/toolchains/cross_compile/cc/BUILD: +tools/toolchains/cross_compile/config/BUILD: tools/toolchains/embedded/arm-linux/BUILD: tools/toolchains/java/BUILD: tools/toolchains/python/BUILD: diff --git a/third_party/xla/third_party/tsl/third_party/gemmlowp/workspace.bzl b/third_party/xla/third_party/tsl/third_party/gemmlowp/workspace.bzl index b98035569852e2..884f707719a623 100644 --- a/third_party/xla/third_party/tsl/third_party/gemmlowp/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/gemmlowp/workspace.bzl @@ -7,8 +7,8 @@ def repo(): # Attention: tools parse and update these lines. # LINT.IfChange - GEMMLOWP_COMMIT = "e844ffd17118c1e17d94e1ba4354c075a4577b88" - GEMMLOWP_SHA256 = "522b7a82d920ebd0c4408a5365866a40b81d1c0d60b2369011d315cca03c6476" + GEMMLOWP_COMMIT = "16e8662c34917be0065110bfcd9cc27d30f52fdf" + GEMMLOWP_SHA256 = "7dc418717c8456473fac4ff2288b71057e3dcb72894524c734a4362cdb51fa8b" # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/gemmlowp.cmake) tf_http_archive( diff --git a/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py b/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py index b7d98ef2581157..afd6380b0ac203 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py +++ b/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py @@ -23,6 +23,7 @@ """ import os import os.path +import platform import subprocess import sys @@ -38,6 +39,10 @@ class ConfigError(Exception): pass +def _is_windows(): + return platform.system() == "Windows" + + def check_cuda_lib(path, check_soname=True): """Tests if a library exists on disk and whether its soname matches the filename. 
@@ -52,7 +57,7 @@ def check_cuda_lib(path, check_soname=True): if not os.path.isfile(path): raise ConfigError("No library found under: " + path) objdump = which("objdump") - if check_soname and objdump is not None: + if check_soname and objdump is not None and not _is_windows(): # Decode is necessary as in py3 the return type changed from str to bytes output = subprocess.check_output([objdump, "-p", path]).decode("utf-8") output = [line for line in output.splitlines() if "SONAME" in line] diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl index 81e54ad431fccf..0da1d7b58f4bb0 100755 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl @@ -45,10 +45,11 @@ import pipes # Template values set by cuda_autoconf. CPU_COMPILER = ('%{cpu_compiler}') -GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}') +HOST_COMPILER_PATH = ('%{host_compiler_path}') NVCC_PATH = '%{nvcc_path}' -PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH) +PREFIX_DIR = os.path.dirname(HOST_COMPILER_PATH) +USE_CLANG_COMPILER = '%{use_clang_compiler}' NVCC_VERSION = '%{cuda_version}' def Log(s): @@ -253,13 +254,23 @@ def InvokeNvcc(argv, log=False): # Force C++17 dialect (note, everything in just one string!) nvccopts += ' --std c++17 ' nvccopts += fatbin_options + # The option `-allow-unsupported-compiler` is required for the combination of + # NVCC+clang compilers. + # The following message appears if this option is not provided: + # unsupported clang version! clang version must be less than 16 and greater + # than 3.2 . The nvcc flag '-allow-unsupported-compiler' can be used + # to override this version check; however, using an unsupported host compiler + # may cause compilation failure or incorrect run time execution. + # Use at your own risk. + if USE_CLANG_COMPILER: + nvccopts += ' -allow-unsupported-compiler --expt-extended-lambda --expt-relaxed-constexpr ' if depfiles: # Generate the dependency file depfile = depfiles[0] cmd = (NVCC_PATH + ' ' + nvccopts + ' --compiler-options "' + host_compiler_options + '"' + - ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH + + ' --compiler-bindir=' + HOST_COMPILER_PATH + ' -I .' + ' -x cu ' + opt + includes + ' ' + srcs + ' -M -o ' + depfile) if log: Log(cmd) @@ -269,7 +280,7 @@ def InvokeNvcc(argv, log=False): cmd = (NVCC_PATH + ' ' + nvccopts + ' --compiler-options "' + host_compiler_options + ' -fPIC"' + - ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH + + ' --compiler-bindir=' + HOST_COMPILER_PATH + ' -I .' 
+ ' -x cu ' + opt + includes + ' -c ' + srcs + out) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index 8fb22313010a45..77ec948af32c6e 100755 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -86,8 +86,8 @@ def GetHostCompilerOptions(argv): opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, [])) if args.g: opts += ' -g' + ' -g'.join(sum(args.g, [])) - #if args.fno_canonical_system_headers: - # opts += ' -fno-canonical-system-headers' + if args.fno_canonical_system_headers: + opts += ' -no-canonical-prefixes' if args.sysroot: opts += ' --sysroot ' + args.sysroot[0] diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl new file mode 100644 index 00000000000000..c46e09484fdfad --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl @@ -0,0 +1,256 @@ +#!/usr/bin/env python +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Crosstool wrapper for compiling CUDA programs with nvcc on Windows. + +DESCRIPTION: + This script is the Windows version of //third_party/gpus/crosstool/crosstool_wrapper_is_not_gcc +""" + +from argparse import ArgumentParser +import os +import subprocess +import re +import sys +import tempfile + +# Template values set by cuda_autoconf. +CPU_COMPILER = ('%{cpu_compiler}') +GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}') + +NVCC_PATH = '%{nvcc_path}' +NVCC_VERSION = '%{cuda_version}' +NVCC_TEMP_DIR = "%{nvcc_tmp_dir}" + +def Log(s): + print('gpus/crosstool: {0}'.format(s)) + + +def GetOptionValue(argv, option): + """Extract the list of values for option from options. + + Args: + option: The option whose value to extract. + + Returns: + 1. A list of values, either directly following the option, + (eg., /opt val1 val2) or values collected from multiple occurrences of + the option (eg., /opt val1 /opt val2). + 2. The leftover options. 
+ """ + + parser = ArgumentParser(prefix_chars='-/') + parser.add_argument(option, nargs='*', action='append') + option = option.lstrip('-/').replace('-', '_') + args, leftover = parser.parse_known_args(argv) + if args and vars(args)[option]: + return (sum(vars(args)[option], []), leftover) + return ([], leftover) + +def _update_options(nvcc_options): + if NVCC_VERSION in ("7.0",): + return nvcc_options + + update_options = { "relaxed-constexpr" : "expt-relaxed-constexpr" } + return [ update_options[opt] if opt in update_options else opt + for opt in nvcc_options ] + +def GetNvccOptions(argv): + """Collect the -nvcc_options values from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + + Returns: + 1. The string that can be passed directly to nvcc. + 2. The leftover options. + """ + + parser = ArgumentParser() + parser.add_argument('-nvcc_options', nargs='*', action='append') + + args, leftover = parser.parse_known_args(argv) + + if args.nvcc_options: + options = _update_options(sum(args.nvcc_options, [])) + return (['--' + a for a in options], leftover) + return ([], leftover) + + +def InvokeNvcc(argv, log=False): + """Call nvcc with arguments assembled from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + log: True if logging is requested. + + Returns: + The return value of calling os.system('nvcc ' + args) + """ + + src_files = [f for f in argv if + re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] + if len(src_files) == 0: + raise Error('No source files found for cuda compilation.') + + out_file = [ f for f in argv if f.startswith('/Fo') ] + if len(out_file) != 1: + raise Error('Please specify exactly one output file for cuda compilation.') + out = ['-o', out_file[0][len('/Fo'):]] + + nvcc_compiler_options, argv = GetNvccOptions(argv) + + opt_option, argv = GetOptionValue(argv, '/O') + opt = ['-g'] + if (len(opt_option) > 0 and opt_option[0] != 'd'): + opt = ['-O2'] + + include_options, argv = GetOptionValue(argv, '/I') + includes = ["-I " + include for include in include_options] + + defines, argv = GetOptionValue(argv, '/D') + defines = [ + '-D' + define + for define in defines + if 'BAZEL_CURRENT_REPOSITORY' not in define + ] + + undefines, argv = GetOptionValue(argv, '/U') + undefines = ['-U' + define for define in undefines] + + fatbin_options, argv = GetOptionValue(argv, '-Xcuda-fatbinary') + fatbin_options = ['--fatbin-options=' + option for option in fatbin_options] + + # The rest of the unrecognized options should be passed to host compiler + host_compiler_options = [option for option in argv if option not in (src_files + out_file)] + + m_options = ["-m64"] + + nvccopts = ['-D_FORCE_INLINES'] + compute_capabilities, argv = GetOptionValue(argv, "--cuda-gpu-arch") + for capability in compute_capabilities: + capability = capability[len('sm_'):] + nvccopts += [ + r'-gencode=arch=compute_%s,"code=sm_%s"' % (capability, capability) + ] + compute_capabilities, argv = GetOptionValue(argv, '--cuda-include-ptx') + for capability in compute_capabilities: + capability = capability[len('sm_'):] + nvccopts += [ + r'-gencode=arch=compute_%s,"code=compute_%s"' % (capability, capability) + ] + _, argv = GetOptionValue(argv, '--no-cuda-include-ptx') + + # nvcc doesn't respect the INCLUDE and LIB env vars from MSVC, + # so we explicity specify the system include paths and library search paths. 
+ if 'INCLUDE' in os.environ: + nvccopts += [('--system-include="%s"' % p) for p in os.environ['INCLUDE'].split(";")] + if 'LIB' in os.environ: + nvccopts += [('--library-path="%s"' % p) for p in os.environ['LIB'].split(";")] + + nvccopts += nvcc_compiler_options + nvccopts += undefines + nvccopts += defines + nvccopts += m_options + nvccopts += fatbin_options + nvccopts += ['--compiler-options=' + ",".join(host_compiler_options)] + nvccopts += ['-x', 'cu'] + opt + includes + out + ['-c'] + src_files + # Specify a unique temp directory for nvcc to generate intermediate files, + # then Bazel can ignore files under NVCC_TEMP_DIR during dependency check + # http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver + # Different actions are sharing NVCC_TEMP_DIR, so we cannot remove it if the directory already exists. + if os.path.isfile(NVCC_TEMP_DIR): + os.remove(NVCC_TEMP_DIR) + if not os.path.exists(NVCC_TEMP_DIR): + os.makedirs(NVCC_TEMP_DIR) + # Provide a unique dir for each compiling action to avoid conflicts. + tempdir = tempfile.mkdtemp(dir = NVCC_TEMP_DIR) + nvccopts += ['--keep', '--keep-dir', tempdir] + # Force C++17 dialect (note, everything in just one string!) + nvccopts += ['--std c++17'] + if log: + Log([NVCC_PATH] + nvccopts) + + # Store command line options in a file to avoid hitting the character limit. + optsfile = tempfile.NamedTemporaryFile(mode='w', dir=tempdir, delete=False) + optsfile.write("\n".join(nvccopts)) + optsfile.close() + + proc = subprocess.Popen([NVCC_PATH, "--options-file", optsfile.name], + stdout=sys.stdout, + stderr=sys.stderr, + env=os.environ.copy(), + shell=True) + proc.wait() + return proc.returncode + +def ExpandParamsFileForArgv(): + new_argv = [] + for arg in sys.argv: + if arg.startswith("@"): + with open(arg.strip("@")) as f: + new_argv.extend([l.strip() for l in f.readlines()]) + else: + new_argv.append(arg) + + sys.argv = new_argv + +def ProcessFlagForCommandFile(flag): + if flag.startswith("/D") or flag.startswith("-D"): + # We need to re-escape /DFOO="BAR" as /DFOO=\"BAR\", so that we get + # `#define FOO "BAR"` after expansion as a string literal define + if flag.endswith('"') and not flag.endswith('\\"'): + flag = '\\"'.join(flag.split('"', 1)) + flag = '\\"'.join(flag.rsplit('"', 1)) + return flag + return flag + +def main(): + ExpandParamsFileForArgv() + parser = ArgumentParser() + parser.add_argument('-x', nargs=1) + parser.add_argument('--cuda_log', action='store_true') + args, leftover = parser.parse_known_args(sys.argv[1:]) + + if args.x and args.x[0] == 'cuda': + if args.cuda_log: Log('-x cuda') + if args.cuda_log: Log('using nvcc') + return InvokeNvcc(leftover, log=args.cuda_log) + + # Strip our flags before passing through to the CPU compiler for files which + # are not -x cuda. We can't just pass 'leftover' because it also strips -x. + # We not only want to pass -x to the CPU compiler, but also keep it in its + # relative location in the argv list (the compiler is actually sensitive to + # this). + cpu_compiler_flags = [flag for flag in sys.argv[1:] + if not flag.startswith(('--cuda_log')) + and not flag.startswith(('-nvcc_options'))] + output = [flag for flag in cpu_compiler_flags if flag.startswith("/Fo")] + + # Store command line options in a file to avoid hitting the character limit. 
+ if len(output) == 1: + commandfile_path = output[0][3:] + ".msvc_params" + commandfile = open(commandfile_path, "w") + cpu_compiler_flags = [ProcessFlagForCommandFile(flag) for flag in cpu_compiler_flags] + commandfile.write("\n".join(cpu_compiler_flags)) + commandfile.close() + return subprocess.call([CPU_COMPILER, "@" + commandfile_path]) + else: + return subprocess.call([CPU_COMPILER] + cpu_compiler_flags) + return subprocess.call([CPU_COMPILER] + cpu_compiler_flags) + +if __name__ == '__main__': + sys.exit(main()) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl index 700e040a88eeca..90a18b90de048c 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl @@ -61,23 +61,23 @@ cuda_header_library( cc_library( name = "cudart_static", - srcs = ["cuda/lib/libcudart_static.a"], + srcs = ["cuda/lib/%{cudart_static_lib}"], linkopts = [ "-ldl", - "-lrt", "-lpthread", + %{cudart_static_linkopt} ], ) cc_library( name = "cuda_driver", - srcs = ["cuda/lib/libcuda.so"], + srcs = ["cuda/lib/%{cuda_driver_lib}"], ) cc_library( name = "cudart", - srcs = glob(["cuda/lib/libcudart.so.*"]), - data = glob(["cuda/lib/libcudart.so.*"]), + srcs = ["cuda/lib/%{cudart_lib}"], + data = ["cuda/lib/%{cudart_lib}"], linkstatic = 1, ) @@ -128,30 +128,30 @@ cuda_header_library( cc_library( name = "cublas", - srcs = glob(["cuda/lib/libcublas.so.*"]), - data = glob(["cuda/lib/libcublas.so.*"]), + srcs = ["cuda/lib/%{cublas_lib}"], + data = ["cuda/lib/%{cublas_lib}"], linkstatic = 1, ) cc_library( name = "cublasLt", - srcs = glob(["cuda/lib/libcublasLt.so.*"]), - data = glob(["cuda/lib/libcublasLt.so.*"]), + srcs = ["cuda/lib/%{cublasLt_lib}"], + data = ["cuda/lib/%{cublasLt_lib}"], linkstatic = 1, ) cc_library( name = "cusolver", - srcs = glob(["cuda/lib/libcusolver.so.*"]), - data = glob(["cuda/lib/libcusolver.so.*"]), + srcs = ["cuda/lib/%{cusolver_lib}"], + data = ["cuda/lib/%{cusolver_lib}"], linkopts = ["-lgomp"], linkstatic = 1, ) cc_library( name = "cudnn", - srcs = glob(["cuda/lib/libcudnn.so.*"]), - data = glob(["cuda/lib/libcudnn.so.*"]), + srcs = ["cuda/lib/%{cudnn_lib}"], + data = ["cuda/lib/%{cudnn_lib}"], linkstatic = 1, ) @@ -165,15 +165,15 @@ cc_library( cc_library( name = "cufft", - srcs = glob(["cuda/lib/libcufft.so.*"]), - data = glob(["cuda/lib/libcufft.so.*"]), + srcs = ["cuda/lib/%{cufft_lib}"], + data = ["cuda/lib/%{cufft_lib}"], linkstatic = 1, ) cc_library( name = "curand", - srcs = glob(["cuda/lib/libcurand.so.*"]), - data = glob(["cuda/lib/libcurand.so.*"]), + srcs = ["cuda/lib/%{curand_lib}"], + data = ["cuda/lib/%{curand_lib}"], linkstatic = 1, ) @@ -192,7 +192,7 @@ cc_library( alias( name = "cub_headers", - actual = ":cuda_headers", + actual = "%{cub_actual}", ) cuda_header_library( @@ -213,13 +213,13 @@ cuda_header_library( cc_library( name = "cupti_dsos", - data = glob(["cuda/lib/libcupti.so.*"]), + data = ["cuda/lib/%{cupti_lib}"], ) cc_library( name = "cusparse", - srcs = glob(["cuda/lib/libcusparse.so.*"]), - data = glob(["cuda/lib/libcusparse.so.*"]), + srcs = ["cuda/lib/%{cusparse_lib}"], + data = ["cuda/lib/%{cusparse_lib}"], linkopts = ["-lgomp"], linkstatic = 1, ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl new file mode 100644 index 00000000000000..dee0e898d9ae7a --- /dev/null +++ 
b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl @@ -0,0 +1,238 @@ +load(":build_defs.bzl", "cuda_header_library") +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//lib:selects.bzl", "selects") + +licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +# Config setting whether TensorFlow is built with CUDA support using clang. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_clang. +selects.config_setting_group( + name = "using_clang", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_clang", + ], +) + +# Config setting whether TensorFlow is built with CUDA support using nvcc. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_nvcc. +selects.config_setting_group( + name = "using_nvcc", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_nvcc", + ], +) + +# Equivalent to using_clang && -c opt. +selects.config_setting_group( + name = "using_clang_opt", + match_all = [ + ":using_clang", + ":_opt", + ], +) + +config_setting( + name = "_opt", + values = {"compilation_mode": "opt"}, +) + +# Provides CUDA headers for '#include "third_party/gpus/cuda/include/cuda.h"' +# All clients including TensorFlow should use these directives. +cuda_header_library( + name = "cuda_headers", + hdrs = [ + "cuda/cuda_config.h", + ":cuda-include", + ], + include_prefix = "third_party/gpus", + includes = [ + ".", # required to include cuda/cuda/cuda_config.h as cuda/config.h + "cuda/include", + ], +) + +cc_import( + name = "cudart_static", + # /WHOLEARCHIVE:cudart_static.lib will cause a + # "Internal error during CImplib::EmitThunk" error. + # Treat this library as interface library to avoid being whole archived when + # linking a DLL that depends on this. + # TODO(pcloudy): Remove this rule after b/111278841 is resolved. 
+ interface_library = "cuda/lib/%{cudart_static_lib}", + system_provided = 1, +) + +cc_import( + name = "cuda_driver", + interface_library = "cuda/lib/%{cuda_driver_lib}", + system_provided = 1, +) + +cc_import( + name = "cudart", + interface_library = "cuda/lib/%{cudart_lib}", + system_provided = 1, +) + +cuda_header_library( + name = "cublas_headers", + hdrs = [":cublas-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cublas/include"], + strip_include_prefix = "cublas/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "cusolver_headers", + hdrs = [":cusolver-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cusolver/include"], + strip_include_prefix = "cusolver/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "cufft_headers", + hdrs = [":cufft-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cufft/include"], + strip_include_prefix = "cufft/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "cusparse_headers", + hdrs = [":cusparse-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cusparse/include"], + strip_include_prefix = "cusparse/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "curand_headers", + hdrs = [":curand-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["curand/include"], + strip_include_prefix = "curand/include", + deps = [":cuda_headers"], +) + +cc_import( + name = "cublas", + interface_library = "cuda/lib/%{cublas_lib}", + system_provided = 1, +) + +cc_import( + name = "cublasLt", + interface_library = "cuda/lib/%{cublasLt_lib}", + system_provided = 1, +) + +cc_import( + name = "cusolver", + interface_library = "cuda/lib/%{cusolver_lib}", + system_provided = 1, +) + +cc_import( + name = "cudnn", + interface_library = "cuda/lib/%{cudnn_lib}", + system_provided = 1, +) + +cc_library( + name = "cudnn_header", + hdrs = [":cudnn-include"], + include_prefix = "third_party/gpus/cudnn", + strip_include_prefix = "cudnn/include", + deps = [":cuda_headers"], +) + +cc_import( + name = "cufft", + interface_library = "cuda/lib/%{cufft_lib}", + system_provided = 1, +) + +cc_import( + name = "curand", + interface_library = "cuda/lib/%{curand_lib}", + system_provided = 1, +) + +cc_library( + name = "cuda", + deps = [ + ":cublas", + ":cublasLt", + ":cuda_headers", + ":cudart", + ":cudnn", + ":cufft", + ":curand", + ], +) + +alias( + name = "cub_headers", + actual = "%{cub_actual}", +) + +cuda_header_library( + name = "cupti_headers", + hdrs = [":cuda-extras"], + include_prefix = "third_party/gpus", + includes = ["cuda/extras/CUPTI/include/"], + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "nvml_headers", + hdrs = [":nvml"], + include_prefix = "third_party/gpus", + includes = ["cuda/nvml/include/"], + deps = [":cuda_headers"], +) + +cc_import( + name = "cupti_dsos", + interface_library = "cuda/lib/%{cupti_lib}", + system_provided = 1, +) + +cc_import( + name = "cusparse", + interface_library = "cuda/lib/%{cusparse_lib}", + system_provided = 1, +) + +cc_library( + name = "libdevice_root", + data = [":cuda-nvvm"], +) + +bzl_library( + name = "build_defs_bzl", + srcs = ["build_defs.bzl"], + deps = [ + "@bazel_skylib//lib:selects", + ], +) + +py_library( + name = "cuda_config_py", + srcs = ["cuda/cuda_config.py"], +) + +%{copy_rules} diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl 
b/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl index 8a0d9eb0872911..ff2f2f41091fe8 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl @@ -4,7 +4,8 @@ * `TF_NEED_CUDA`: Whether to enable building with CUDA. * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path - * `TF_CUDA_CLANG`: Whether to use clang as a cuda compiler. + * `TF_CUDA_CLANG`: Whether to use clang for C++ and Cuda compilation. + * `TF_NVCC_CLANG`: Whether to use clang for C++ and NVCC for Cuda compilation. * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for both host and device code compilation if TF_CUDA_CLANG is 1. * `TF_SYSROOT`: The sysroot to use when compiling. @@ -26,14 +27,27 @@ """ load("//third_party/clang_toolchain:download_clang.bzl", "download_clang") +load( + "@bazel_tools//tools/cpp:lib_cc_configure.bzl", + "escape_string", + "get_env_var", +) +load( + "@bazel_tools//tools/cpp:windows_cc_configure.bzl", + "find_msvc_tool", + "find_vc_path", + "setup_vc_env_vars", +) load( "//third_party/remote_config:common.bzl", "config_repo_label", "err_out", "execute", "get_bash_bin", + "get_cpu_value", "get_host_environ", "get_python_bin", + "is_windows", "raw_exec", "read_dir", "realpath", @@ -82,7 +96,16 @@ def verify_build_defines(params): "host_compiler_warnings", "linker_bin_path", "compiler_deps", + "msvc_cl_path", + "msvc_env_include", + "msvc_env_lib", + "msvc_env_path", + "msvc_env_tmp", + "msvc_lib_path", + "msvc_link_path", + "msvc_ml_path", "unfiltered_compile_flags", + "win_compiler_deps", ]: if ("%{" + param + "}") not in params: missing.append(param) @@ -96,13 +119,104 @@ def verify_build_defines(params): ".", ) +def _get_nvcc_tmp_dir_for_windows(repository_ctx): + """Return the Windows tmp directory for nvcc to generate intermediate source files.""" + escaped_tmp_dir = escape_string( + get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace( + "\\", + "\\\\", + ), + ) + return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir" + +def _get_msvc_compiler(repository_ctx): + vc_path = find_vc_path(repository_ctx) + return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/") + +def _get_win_cuda_defines(repository_ctx): + """Return CROSSTOOL defines for Windows""" + + # If we are not on Windows, return fake vaules for Windows specific fields. + # This ensures the CROSSTOOL file parser is happy. + if not is_windows(repository_ctx): + return { + "%{msvc_env_tmp}": "msvc_not_used", + "%{msvc_env_path}": "msvc_not_used", + "%{msvc_env_include}": "msvc_not_used", + "%{msvc_env_lib}": "msvc_not_used", + "%{msvc_cl_path}": "msvc_not_used", + "%{msvc_ml_path}": "msvc_not_used", + "%{msvc_link_path}": "msvc_not_used", + "%{msvc_lib_path}": "msvc_not_used", + } + + vc_path = find_vc_path(repository_ctx) + if not vc_path: + auto_configure_fail( + "Visual C++ build tools not found on your machine." 
+ + "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using", + ) + return {} + + env = setup_vc_env_vars(repository_ctx, vc_path) + escaped_paths = escape_string(env["PATH"]) + escaped_include_paths = escape_string(env["INCLUDE"]) + escaped_lib_paths = escape_string(env["LIB"]) + escaped_tmp_dir = escape_string( + get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace( + "\\", + "\\\\", + ), + ) + + msvc_cl_path = "windows/msvc_wrapper_for_nvcc.bat" + msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace( + "\\", + "/", + ) + msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace( + "\\", + "/", + ) + msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace( + "\\", + "/", + ) + + # nvcc will generate some temporary source files under %{nvcc_tmp_dir} + # The generated files are guaranteed to have unique name, so they can share + # the same tmp directory + escaped_cxx_include_directories = [ + _get_nvcc_tmp_dir_for_windows(repository_ctx), + "C:\\\\botcode\\\\w", + ] + for path in escaped_include_paths.split(";"): + if path: + escaped_cxx_include_directories.append(path) + + return { + "%{msvc_env_tmp}": escaped_tmp_dir, + "%{msvc_env_path}": escaped_paths, + "%{msvc_env_include}": escaped_include_paths, + "%{msvc_env_lib}": escaped_lib_paths, + "%{msvc_cl_path}": msvc_cl_path, + "%{msvc_ml_path}": msvc_ml_path, + "%{msvc_link_path}": msvc_link_path, + "%{msvc_lib_path}": msvc_lib_path, + "%{cxx_builtin_include_directories}": to_list_of_strings( + escaped_cxx_include_directories, + ), + } + # TODO(dzc): Once these functions have been factored out of Bazel's # cc_configure.bzl, load them from @bazel_tools instead. # BEGIN cc_configure common functions. -def find_cc(repository_ctx): +def find_cc(repository_ctx, use_cuda_clang): """Find the C++ compiler.""" + if is_windows(repository_ctx): + return _get_msvc_compiler(repository_ctx) - if _use_cuda_clang(repository_ctx): + if use_cuda_clang: target_cc_name = "clang" cc_path_envvar = _CLANG_CUDA_COMPILER_PATH if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG): @@ -251,9 +365,10 @@ def _cuda_include_path(repository_ctx, cuda_config): Returns: A list of the gcc host compiler include directories. """ - nvcc_path = repository_ctx.path( - "%s/bin/nvcc" % cuda_config.cuda_toolkit_path, - ) + nvcc_path = repository_ctx.path("%s/bin/nvcc%s" % ( + cuda_config.cuda_toolkit_path, + ".exe" if cuda_config.cpu_value == "Windows" else "", + )) # The expected exit code of this command is non-zero. Bazel remote execution # only caches commands with zero exit code. So force a zero exit code. @@ -314,6 +429,10 @@ def matches_version(environ_version, detected_version): return False return True +_NVCC_VERSION_PREFIX = "Cuda compilation tools, release " + +_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR" + def compute_capabilities(repository_ctx): """Returns a list of strings representing cuda compute capabilities. @@ -356,11 +475,12 @@ def compute_capabilities(repository_ctx): return capabilities -def lib_name(base_name, version = None, static = False): +def lib_name(base_name, cpu_value, version = None, static = False): """Constructs the platform-specific name of a library. Args: base_name: The name of the library, such as "cudart" + cpu_value: The name of the host operating system. version: The version of the library. static: True the library is static or False if it is a shared object. 
@@ -368,20 +488,29 @@ def lib_name(base_name, version = None, static = False): The platform-specific name of the library. """ version = "" if not version else "." + version - if static: - return "lib%s.a" % base_name - return "lib%s.so%s" % (base_name, version) + if cpu_value in ("Linux", "FreeBSD"): + if static: + return "lib%s.a" % base_name + return "lib%s.so%s" % (base_name, version) + elif cpu_value == "Windows": + return "%s.lib" % base_name + elif cpu_value == "Darwin": + if static: + return "lib%s.a" % base_name + return "lib%s%s.dylib" % (base_name, version) + else: + auto_configure_fail("Invalid cpu_value: %s" % cpu_value) -def _lib_path(lib, basedir, version, static): - file_name = lib_name(lib, version, static) +def _lib_path(lib, cpu_value, basedir, version, static): + file_name = lib_name(lib, cpu_value, version, static) return "%s/%s" % (basedir, file_name) def _should_check_soname(version, static): return version and not static -def _check_cuda_lib_params(lib, basedir, version, static = False): +def _check_cuda_lib_params(lib, cpu_value, basedir, version, static = False): return ( - _lib_path(lib, basedir, version, static), + _lib_path(lib, cpu_value, basedir, version, static), _should_check_soname(version, static), ) @@ -401,6 +530,8 @@ def _check_cuda_libs(repository_ctx, script_path, libs): all_paths = [path for path, _ in libs] checked_paths = execute(repository_ctx, [python_bin, "-c", cmd]).stdout.splitlines() + # Filter out empty lines from splitting on '\r\n' on Windows + checked_paths = [path for path in checked_paths if len(path) > 0] if all_paths != checked_paths: auto_configure_fail("Error with installed CUDA libs. Expected '%s'. Actual '%s'." % (all_paths, checked_paths)) @@ -418,62 +549,86 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config): Returns: Map of library names to structs of filename and path. 
""" + cpu_value = cuda_config.cpu_value + stub_dir = "" if is_windows(repository_ctx) else "/stubs" + check_cuda_libs_params = { "cuda": _check_cuda_lib_params( "cuda", - cuda_config.config["cuda_library_dir"] + "/stubs", + cpu_value, + cuda_config.config["cuda_library_dir"] + stub_dir, version = None, + static = False, ), "cudart": _check_cuda_lib_params( "cudart", + cpu_value, cuda_config.config["cuda_library_dir"], cuda_config.cudart_version, + static = False, ), "cudart_static": _check_cuda_lib_params( "cudart_static", + cpu_value, cuda_config.config["cuda_library_dir"], cuda_config.cudart_version, static = True, ), "cublas": _check_cuda_lib_params( "cublas", + cpu_value, cuda_config.config["cublas_library_dir"], cuda_config.cublas_version, + static = False, ), "cublasLt": _check_cuda_lib_params( "cublasLt", + cpu_value, cuda_config.config["cublas_library_dir"], cuda_config.cublas_version, + static = False, ), "cusolver": _check_cuda_lib_params( "cusolver", + cpu_value, cuda_config.config["cusolver_library_dir"], cuda_config.cusolver_version, + static = False, ), "curand": _check_cuda_lib_params( "curand", + cpu_value, cuda_config.config["curand_library_dir"], cuda_config.curand_version, + static = False, ), "cufft": _check_cuda_lib_params( "cufft", + cpu_value, cuda_config.config["cufft_library_dir"], cuda_config.cufft_version, + static = False, ), "cudnn": _check_cuda_lib_params( "cudnn", + cpu_value, cuda_config.config["cudnn_library_dir"], cuda_config.cudnn_version, + static = False, ), "cupti": _check_cuda_lib_params( "cupti", + cpu_value, cuda_config.config["cupti_library_dir"], cuda_config.cupti_version, + static = False, ), "cusparse": _check_cuda_lib_params( "cusparse", + cpu_value, cuda_config.config["cusparse_library_dir"], cuda_config.cusparse_version, + static = False, ), } @@ -483,6 +638,10 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config): paths = {filename: v[0] for (filename, v) in check_cuda_libs_params.items()} return paths +def _cudart_static_linkopt(cpu_value): + """Returns additional platform-specific linkopts for cudart.""" + return "" if cpu_value == "Darwin" else "\"-lrt\"," + # TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl, # and nccl_configure.bzl. def find_cuda_config(repository_ctx, cuda_libraries): @@ -509,34 +668,37 @@ def _get_cuda_config(repository_ctx): cudart_version: The CUDA runtime version on the system. cudnn_version: The version of cuDNN on the system. compute_capabilities: A list of the system's CUDA compute capabilities. + cpu_value: The name of the host operating system. """ config = find_cuda_config(repository_ctx, ["cuda", "cudnn"]) + cpu_value = get_cpu_value(repository_ctx) toolkit_path = config["cuda_toolkit_path"] + is_windows = cpu_value == "Windows" cuda_version = config["cuda_version"].split(".") cuda_major = cuda_version[0] cuda_minor = cuda_version[1] - cuda_version = "%s.%s" % (cuda_major, cuda_minor) - cudnn_version = "%s" % config["cudnn_version"] + cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor) + cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"] if int(cuda_major) >= 11: # The libcudart soname in CUDA 11.x is versioned as 11.0 for backward compatability. 
if int(cuda_major) == 11: - cudart_version = "11.0" + cudart_version = "64_110" if is_windows else "11.0" cupti_version = cuda_version else: - cudart_version = "%s" % cuda_major + cudart_version = ("64_%s" if is_windows else "%s") % cuda_major cupti_version = cudart_version - cublas_version = "%s" % config["cublas_version"].split(".")[0] - cusolver_version = "%s" % config["cusolver_version"].split(".")[0] - curand_version = "%s" % config["curand_version"].split(".")[0] - cufft_version = "%s" % config["cufft_version"].split(".")[0] - cusparse_version = "%s" % config["cusparse_version"].split(".")[0] + cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0] + cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0] + curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0] + cufft_version = ("64_%s" if is_windows else "%s") % config["cufft_version"].split(".")[0] + cusparse_version = ("64_%s" if is_windows else "%s") % config["cusparse_version"].split(".")[0] elif (int(cuda_major), int(cuda_minor)) >= (10, 1): # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc. # It changed from 'x.y' to just 'x' in CUDA 10.1. - cuda_lib_version = "%s" % cuda_major + cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major cudart_version = cuda_version cupti_version = cuda_version cublas_version = cuda_lib_version @@ -566,6 +728,7 @@ def _get_cuda_config(repository_ctx): cusparse_version = cusparse_version, cudnn_version = cudnn_version, compute_capabilities = compute_capabilities(repository_ctx), + cpu_value = cpu_value, config = config, ) @@ -611,6 +774,8 @@ error_gpu_disabled() """ def _create_dummy_repository(repository_ctx): + cpu_value = get_cpu_value(repository_ctx) + # Set up BUILD file for cuda/. 
_tpl( repository_ctx, @@ -625,6 +790,23 @@ def _create_dummy_repository(repository_ctx): repository_ctx, "cuda:BUILD", { + "%{cuda_driver_lib}": lib_name("cuda", cpu_value), + "%{cudart_static_lib}": lib_name( + "cudart_static", + cpu_value, + static = True, + ), + "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), + "%{cudart_lib}": lib_name("cudart", cpu_value), + "%{cublas_lib}": lib_name("cublas", cpu_value), + "%{cublasLt_lib}": lib_name("cublasLt", cpu_value), + "%{cusolver_lib}": lib_name("cusolver", cpu_value), + "%{cudnn_lib}": lib_name("cudnn", cpu_value), + "%{cufft_lib}": lib_name("cufft", cpu_value), + "%{curand_lib}": lib_name("curand", cpu_value), + "%{cupti_lib}": lib_name("cupti", cpu_value), + "%{cusparse_lib}": lib_name("cusparse", cpu_value), + "%{cub_actual}": ":cuda_headers", "%{copy_rules}": """ filegroup(name="cuda-include") filegroup(name="cublas-include") @@ -643,9 +825,20 @@ filegroup(name="cudnn-include") repository_ctx.file("cuda/cuda/include/cublas.h") repository_ctx.file("cuda/cuda/include/cudnn.h") repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h") - repository_ctx.file("cuda/cuda/lib/libcuda.so") - repository_ctx.file("cuda/cuda/lib/libcudart_static.a") repository_ctx.file("cuda/cuda/nvml/include/nvml.h") + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cuda", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudart", cpu_value)) + repository_ctx.file( + "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value), + ) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cufft", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cupti", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusparse", cpu_value)) # Set up cuda_config.h, which is used by # tensorflow/compiler/xla/stream_executor/dso_loader.cc. @@ -709,7 +902,7 @@ def make_copy_files_rule(repository_ctx, name, srcs, outs): cmd = \"""%s \""", )""" % (name, "\n".join(outs), " && \\\n".join(cmds)) -def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir): +def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir, exceptions = None): """Returns a rule to recursively copy a directory. If exceptions is not None, it must be a list of files or directories in 'src_dir'; these will be excluded from copying. @@ -717,25 +910,39 @@ def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir): src_dir = _norm_path(src_dir) out_dir = _norm_path(out_dir) outs = read_dir(repository_ctx, src_dir) + post_cmd = "" + if exceptions != None: + outs = [x for x in outs if not any([ + x.startswith(src_dir + "/" + y) + for y in exceptions + ])] outs = [(' "%s",' % out.replace(src_dir, out_dir)) for out in outs] # '@D' already contains the relative path for a single file, see # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)" + if exceptions != None: + for x in exceptions: + post_cmd += " ; rm -fR " + out_dir + "/" + x return """genrule( name = "%s", outs = [ %s ], - cmd = \"""cp -rLf "%s/." 
"%s/" \""", -)""" % (name, "\n".join(outs), src_dir, out_dir) + cmd = \"""cp -rLf "%s/." "%s/" %s\""", +)""" % (name, "\n".join(outs), src_dir, out_dir, post_cmd) def _flag_enabled(repository_ctx, flag_name): return get_host_environ(repository_ctx, flag_name) == "1" def _use_cuda_clang(repository_ctx): + # Returns the flag if we need to use clang both for C++ and Cuda. return _flag_enabled(repository_ctx, "TF_CUDA_CLANG") +def _use_nvcc_and_clang(repository_ctx): + # Returns the flag if we need to use clang for C++ and NVCC for Cuda. + return _flag_enabled(repository_ctx, "TF_NVCC_CLANG") + def _tf_sysroot(repository_ctx): return get_host_environ(repository_ctx, _TF_SYSROOT, "") @@ -752,6 +959,22 @@ def _compute_cuda_extra_copts(repository_ctx, compute_capabilities): def _tpl_path(repository_ctx, filename): return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % filename)) +def _basename(repository_ctx, path_str): + """Returns the basename of a path of type string. + + This method is different from path.basename in that it also works if + the host platform is different from the execution platform + i.e. linux -> windows. + """ + + num_chars = len(path_str) + is_win = is_windows(repository_ctx) + for i in range(num_chars): + r_i = num_chars - 1 - i + if (is_win and path_str[r_i] == "\\") or path_str[r_i] == "/": + return path_str[r_i + 1:] + return path_str + def _create_local_cuda_repository(repository_ctx): """Creates the repository containing files set up to build with CUDA.""" @@ -760,14 +983,15 @@ def _create_local_cuda_repository(repository_ctx): # can easily lead to a O(n^2) runtime in the number of labels. # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778 tpl_paths = {filename: _tpl_path(repository_ctx, filename) for filename in [ - "cuda:BUILD", "cuda:build_defs.bzl", "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc", + "crosstool:windows/msvc_wrapper_for_nvcc.py", "crosstool:BUILD", "crosstool:cc_toolchain_config.bzl", "cuda:cuda_config.h", "cuda:cuda_config.py", ]} + tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD") cuda_config = _get_cuda_config(repository_ctx) @@ -879,7 +1103,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_lib_outs = [] for path in cuda_libs.values(): cuda_lib_srcs.append(path) - cuda_lib_outs.append("cuda/lib/" + path.rpartition("/")[-1]) + cuda_lib_outs.append("cuda/lib/" + _basename(repository_ctx, path)) copy_rules.append(make_copy_files_rule( repository_ctx, name = "cuda-lib", @@ -888,7 +1112,11 @@ def _create_local_cuda_repository(repository_ctx): )) # copy files mentioned in third_party/nccl/build_defs.bzl.tpl - bin_files = ["crt/link.stub", "bin2c", "fatbinary", "nvlink", "nvprune"] + file_ext = ".exe" if is_windows(repository_ctx) else "" + bin_files = ( + ["crt/link.stub"] + + [f + file_ext for f in ["bin2c", "fatbinary", "nvlink", "nvprune"]] + ) copy_rules.append(make_copy_files_rule( repository_ctx, name = "cuda-bin", @@ -896,7 +1124,7 @@ def _create_local_cuda_repository(repository_ctx): outs = ["cuda/bin/" + f for f in bin_files], )) - # Select the headers based on the cuDNN version. + # Select the headers based on the cuDNN version (strip '64_' for Windows). 
cudnn_headers = ["cudnn.h"] if cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "8": cudnn_headers += [ @@ -937,15 +1165,33 @@ def _create_local_cuda_repository(repository_ctx): }, ) + cub_actual = "@cub_archive//:cub" + if int(cuda_config.cuda_version_major) >= 11: + cub_actual = ":cuda_headers" + repository_ctx.template( "cuda/BUILD", tpl_paths["cuda:BUILD"], { + "%{cuda_driver_lib}": _basename(repository_ctx, cuda_libs["cuda"]), + "%{cudart_static_lib}": _basename(repository_ctx, cuda_libs["cudart_static"]), + "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value), + "%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]), + "%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]), + "%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]), + "%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]), + "%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]), + "%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]), + "%{curand_lib}": _basename(repository_ctx, cuda_libs["curand"]), + "%{cupti_lib}": _basename(repository_ctx, cuda_libs["cupti"]), + "%{cusparse_lib}": _basename(repository_ctx, cuda_libs["cusparse"]), + "%{cub_actual}": cub_actual, "%{copy_rules}": "\n".join(copy_rules), }, ) is_cuda_clang = _use_cuda_clang(repository_ctx) + is_nvcc_and_clang = _use_nvcc_and_clang(repository_ctx) tf_sysroot = _tf_sysroot(repository_ctx) should_download_clang = is_cuda_clang and _flag_enabled( @@ -956,7 +1202,7 @@ def _create_local_cuda_repository(repository_ctx): download_clang(repository_ctx, "crosstool/extra_tools") # Set up crosstool/ - cc = find_cc(repository_ctx) + cc = find_cc(repository_ctx, is_cuda_clang) cc_fullpath = cc if not should_download_clang else "crosstool/" + cc host_compiler_includes = get_cxx_inc_directories( @@ -993,7 +1239,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" cuda_defines["%{unfiltered_compile_flags}"] = "" - if is_cuda_clang: + if is_cuda_clang and not is_nvcc_and_clang: cuda_defines["%{host_compiler_path}"] = str(cc) cuda_defines["%{host_compiler_warnings}"] = """ # Some parts of the codebase set -Werror and hit this warning, so @@ -1002,10 +1248,12 @@ def _create_local_cuda_repository(repository_ctx): """ cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(host_compiler_includes) cuda_defines["%{compiler_deps}"] = ":empty" + cuda_defines["%{win_compiler_deps}"] = ":empty" repository_ctx.file( "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", "", ) + repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "") else: cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" cuda_defines["%{host_compiler_warnings}"] = "" @@ -1025,22 +1273,40 @@ def _create_local_cuda_repository(repository_ctx): # .d file - given that includes that are prefixed with "../" multiple # time quickly grow longer than the root of the tree, this can lead to # bazel's header check failing. 
- cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\"" + if not is_cuda_clang: + cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\"" - nvcc_path = "%s/nvcc" % cuda_config.config["cuda_binary_dir"] + file_ext = ".exe" if is_windows(repository_ctx) else "" + nvcc_path = "%s/nvcc%s" % (cuda_config.config["cuda_binary_dir"], file_ext) cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc" + cuda_defines["%{win_compiler_deps}"] = ":windows_msvc_wrapper_files" wrapper_defines = { "%{cpu_compiler}": str(cc), "%{cuda_version}": cuda_config.cuda_version, "%{nvcc_path}": nvcc_path, - "%{gcc_host_compiler_path}": str(cc), + "%{host_compiler_path}": str(cc), + "%{use_clang_compiler}": str(is_nvcc_and_clang), + "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx), } repository_ctx.template( "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"], wrapper_defines, ) + repository_ctx.file( + "crosstool/windows/msvc_wrapper_for_nvcc.bat", + content = "@echo OFF\n{} -B external/local_config_cuda/crosstool/windows/msvc_wrapper_for_nvcc.py %*".format( + get_python_bin(repository_ctx), + ), + ) + repository_ctx.template( + "crosstool/windows/msvc_wrapper_for_nvcc.py", + tpl_paths["crosstool:windows/msvc_wrapper_for_nvcc.py"], + wrapper_defines, + ) + + cuda_defines.update(_get_win_cuda_defines(repository_ctx)) verify_build_defines(cuda_defines) @@ -1171,12 +1437,28 @@ def _cuda_autoconf_impl(repository_ctx): repository_ctx.symlink(build_file, "BUILD") +# For @bazel_tools//tools/cpp:windows_cc_configure.bzl +_MSVC_ENVVARS = [ + "BAZEL_VC", + "BAZEL_VC_FULL_VERSION", + "BAZEL_VS", + "BAZEL_WINSDK_FULL_VERSION", + "VS90COMNTOOLS", + "VS100COMNTOOLS", + "VS110COMNTOOLS", + "VS120COMNTOOLS", + "VS140COMNTOOLS", + "VS150COMNTOOLS", + "VS160COMNTOOLS", +] + _ENVIRONS = [ _GCC_HOST_COMPILER_PATH, _GCC_HOST_COMPILER_PREFIX, _CLANG_CUDA_COMPILER_PATH, "TF_NEED_CUDA", "TF_CUDA_CLANG", + "TF_NVCC_CLANG", _TF_DOWNLOAD_CLANG, _CUDA_TOOLKIT_PATH, _CUDNN_INSTALL_PATH, @@ -1188,7 +1470,7 @@ _ENVIRONS = [ "TMP", "TMPDIR", "TF_CUDA_PATHS", -] +] + _MSVC_ENVVARS remote_cuda_configure = repository_rule( implementation = _create_local_cuda_repository, diff --git a/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py b/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py index 78292c7b40237a..b88694af5c014d 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py +++ b/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py @@ -29,6 +29,8 @@ If TF_CUDA_PATHS is not specified, a OS specific default is used: Linux: /usr/local/cuda, /usr, and paths from 'ldconfig -p'. + Windows: CUDA_PATH environment variable, or + C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\* For backwards compatibility, some libraries also use alternative base directories from other environment variables if they are specified. List of @@ -54,6 +56,7 @@ import io import os import glob +import platform import re import subprocess import sys @@ -70,6 +73,18 @@ class ConfigError(Exception): pass +def _is_linux(): + return platform.system() == "Linux" + + +def _is_windows(): + return platform.system() == "Windows" + + +def _is_macos(): + return platform.system() == "Darwin" + + def _matches_version(actual_version, required_version): """Checks whether some version meets the requirements. 
@@ -119,6 +134,8 @@ def _cartesian_product(first, second): def _get_ld_config_paths(): """Returns all directories from 'ldconfig -p'.""" + if not _is_linux(): + return [] ldconfig_path = which("ldconfig") or "/sbin/ldconfig" output = subprocess.check_output([ldconfig_path, "-p"]) pattern = re.compile(".* => (.*)") @@ -139,6 +156,13 @@ def _get_default_cuda_paths(cuda_version): elif not "." in cuda_version: cuda_version = cuda_version + ".*" + if _is_windows(): + return [ + os.environ.get( + "CUDA_PATH", + "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v%s\\" % + cuda_version) + ] return ["/usr/local/cuda-%s" % cuda_version, "/usr/local/cuda", "/usr", "/usr/local/cudnn"] + _get_ld_config_paths() @@ -188,8 +212,14 @@ def _find_file(base_paths, relative_paths, filepattern): def _find_library(base_paths, library_name, required_version): """Returns first valid path to the requested library.""" - filepattern = ".".join(["lib" + library_name, "so"] + - required_version.split(".")[:1]) + "*" + if _is_windows(): + filepattern = library_name + ".lib" + elif _is_macos(): + filepattern = "%s*.dylib" % (".".join(["lib" + library_name] + + required_version.split(".")[:1])) + else: + filepattern = ".".join(["lib" + library_name, "so"] + + required_version.split(".")[:1]) + "*" return _find_file(base_paths, _library_paths(), filepattern) @@ -238,7 +268,7 @@ def get_nvcc_version(path): return match.group(1) return None - nvcc_name = "nvcc" + nvcc_name = "nvcc.exe" if _is_windows() else "nvcc" nvcc_path, nvcc_version = _find_versioned_file(base_paths, [ "", "bin", @@ -528,6 +558,14 @@ def _get_legacy_path(env_name, default=[]): return _list_from_env(env_name, default) +def _normalize_path(path): + """Returns normalized path, with forward slashes on Windows.""" + path = os.path.realpath(path) + if _is_windows(): + path = path.replace("\\", "/") + return path + + def find_cuda_config(): """Returns a dictionary of CUDA library and header file paths.""" libraries = [argv.lower() for argv in sys.argv[1:]] @@ -596,7 +634,7 @@ def find_cuda_config(): for k, v in result.items(): if k.endswith("_dir") or k.endswith("_path"): - result[k] = os.path.realpath(v) + result[k] = _normalize_path(v) return result diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl index 0bbbc09832db13..5c1195bada43f8 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl @@ -198,6 +198,8 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/15.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/16.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/18/include") # Support hcc based off clang 10.0.0 (for ROCm 3.3) inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/") @@ -345,14 +347,14 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_ libs_paths = [ (name, _rocm_lib_paths(repository_ctx, name, path)) for name, path in [ - ("amdhip64", rocm_config.rocm_toolkit_path + "/hip"), + ("amdhip64", rocm_config.rocm_toolkit_path), ("rocblas", rocm_config.rocm_toolkit_path), (hipfft_or_rocfft, rocm_config.rocm_toolkit_path), ("hiprand", 
rocm_config.rocm_toolkit_path), ("MIOpen", miopen_path), ("rccl", rccl_path), ("hipsparse", rocm_config.rocm_toolkit_path), - ("roctracer64", rocm_config.rocm_toolkit_path + "/roctracer"), + ("roctracer64", rocm_config.rocm_toolkit_path), ("rocsolver", rocm_config.rocm_toolkit_path), ] ] @@ -694,7 +696,7 @@ def _create_local_rocm_repository(repository_ctx): rocm_defines["%{unfiltered_compile_flags}"] = to_list_of_strings([ "-DTENSORFLOW_USE_ROCM=1", - "-D__HIP_PLATFORM_HCC__", + "-D__HIP_PLATFORM_AMD__", "-DEIGEN_USE_HIP", ]) @@ -729,7 +731,7 @@ def _create_local_rocm_repository(repository_ctx): "%{hipcc_env}": _hipcc_env(repository_ctx), "%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", "%{rocr_runtime_library}": "hsa-runtime64", - "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/hip/lib", + "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", "%{hip_runtime_library}": "amdhip64", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 6dd0e178ec09b7..9fca8c020bf276 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "e45cd275068c87cbd1d42d0dc89475d72798a9e8" - TFRT_SHA256 = "dd4a1440fdc8bf142c5ac00bd6227e41999a0912b2f847e932b57307f97138dd" + TFRT_COMMIT = "dbd8da33ab49ed8aa5f08ebe85bacb91341f5d61" + TFRT_SHA256 = "b95b1d17eb2e28ee0f00ae672c7377767a17e7dadde169b335aa481bb07883c7" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD new file mode 100644 index 00000000000000..dc621893ac9675 --- /dev/null +++ b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD @@ -0,0 +1,191 @@ +"""Toolchain configs for cross-compiling TensorFlow""" + +load("@bazel_tools//tools/cpp:unix_cc_toolchain_config.bzl", "cc_toolchain_config") + +package(default_visibility = ["//visibility:public"]) + +licenses(["restricted"]) + +cc_toolchain_suite( + name = "cross_compile_toolchain_suite", + toolchains = { + "aarch64": ":linux_aarch64_toolchain", + "k8": ":linux_x86_toolchain", + }, +) + +filegroup( + name = "empty", + visibility = ["//visibility:public"], +) + +cc_toolchain( + name = "linux_x86_toolchain", + all_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":linux_x86_toolchain_config", + toolchain_identifier = "linux_x86_toolchain", +) + +cc_toolchain_config( + name = "linux_x86_toolchain_config", + abi_libc_version = "local", + abi_version = "local", + builtin_sysroot = "/dt9", + compile_flags = [ + "--target=x86_64-unknown-linux-gnu", + "-fstack-protector", + "-Wall", + "-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-mavx", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "k8", + cxx_builtin_include_directories = [ + "/dt9/", + "/usr/lib/llvm-17/include/", + 
"/usr/lib/llvm-17/lib/clang/17/include", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=x86_64-unknown-linux-gnu", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "-Wl,--undefined-version", + ], + link_libs = [ + "-lstdc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,--gc-sections"], + supports_start_end_lib = True, + target_libc = "", + target_system_name = "x86_64-unknown-linux-gnu", + tool_paths = { + "gcc": "/usr/lib/llvm-17/bin/clang", + "ld": "/usr/lib/llvm-17/bin/ld.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-ar", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + "strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "linux_x86_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) + +cc_toolchain( + name = "linux_aarch64_toolchain", + all_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":linux_aarch64_toolchain_config", + toolchain_identifier = "linux_aarch64_toolchain", +) + +cc_toolchain_config( + name = "linux_aarch64_toolchain_config", + abi_libc_version = "local", + abi_version = "local", + builtin_sysroot = "/dt10/", + compile_flags = [ + "--target=aarch64-unknown-linux-gnu", + "-fstack-protector", + "-Wall", + "-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-mtune=generic", + "-march=armv8-a", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "aarch64", + cxx_builtin_include_directories = [ + "/dt10/", + "/usr/lib/llvm-17/include/", + "/usr/lib/llvm-17/lib/clang/17/include", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=aarch64-unknown-linux-gnu", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "-Wl,--undefined-version", + ], + link_libs = [ + "-lstdc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,--gc-sections"], + supports_start_end_lib = True, + target_libc = "", + target_system_name = "aarch64-unknown-linux-gnu", + tool_paths = { + "gcc": "/usr/lib/llvm-17/bin/clang", + "ld": "/usr/lib/llvm-17/bin/ld.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-ar", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + "strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "linux_aarch64_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) diff --git 
a/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/config/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/config/BUILD new file mode 100644 index 00000000000000..b6a504ba1449d6 --- /dev/null +++ b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/config/BUILD @@ -0,0 +1,23 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["restricted"]) + +platform( + name = "linux_x86_64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], + exec_properties = { + "container-image": "docker://gcr.io/tensorflow-testing/ml-devinfra-linux-aarch64-cross-compile@sha256:11c5ac3b9b4e01cfa82b39b90826a9bfc5b806ccc92cd3d272e6bf861de43be1", + "OSFamily": "Linux", + }, +) + +platform( + name = "linux_aarch64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:aarch64", + ], +) diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl index 4554463cb90675..4b07fb5c18670d 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl @@ -200,6 +200,28 @@ def initialize_rbe_configs(): python_install_path = "/usr/local", ) + tensorflow_rbe_config( + name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", + compiler = "/usr/lib/llvm-17/bin/clang", + cuda_version = "12.3", + cudnn_version = "8.9", + os = "ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + sysroot = "/dt9", + python_install_path = "/usr/local", + ) + + tensorflow_rbe_config( + name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", + compiler = "/dt9/usr/bin/gcc", + compiler_prefix = "/usr/bin", + cuda_version = "12.3", + cudnn_version = "8.9", + os = "ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + python_install_path = "/usr/local", + ) + tensorflow_rbe_win_config( name = "windows_py37", python_bin_path = "C:/Python37/python.exe", diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/containers.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/containers.bzl index bfb4634e810328..cd346c2816def1 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/containers.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/containers.bzl @@ -5,8 +5,9 @@ container_digests = { # TF now uses only this container "cuda11.2-cudnn8.1-ubuntu20.04-manylinux2014-multipython": "sha256:48612bd85709cd014711d0b0f87e0806f3567d06d2e81c6e860516b87498b821", # JAX manylinux2014 configs. 
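Editor's note — returning to the find_cuda_config.py changes earlier in this patch: the probing script now understands Windows and macOS. `ldconfig -p` is only consulted on Linux, the default CUDA root on Windows falls back to `CUDA_PATH` or the standard "NVIDIA GPU Computing Toolkit" directory, library file patterns switch between `.so*`, `.dylib`, and `.lib`, `nvcc` becomes `nvcc.exe`, and discovered paths are normalized to forward slashes so Starlark can consume them. The sketch below re-implements that per-OS selection logic purely for illustration; the helper names (`LibraryPattern`, `NormalizePath`) are invented here and are not part of the change.

```cpp
// Illustrative re-implementation of the per-OS logic added to
// find_cuda_config.py. Hypothetical helper names; not part of the patch.
#include <algorithm>
#include <iostream>
#include <string>

enum class Os { kLinux, kMacOs, kWindows };

// Mirrors _find_library(): pick the filename pattern used to glob for a
// CUDA library, given its base name and required major version.
std::string LibraryPattern(Os os, const std::string& name,
                           const std::string& major_version) {
  switch (os) {
    case Os::kWindows:
      return name + ".lib";  // e.g. cudart.lib
    case Os::kMacOs:
      return "lib" + name + "." + major_version + "*.dylib";
    case Os::kLinux:
    default:
      return "lib" + name + ".so." + major_version + "*";  // libcudart.so.12*
  }
}

// Mirrors _normalize_path(): keep forward slashes even on Windows so the
// result can be written into Bazel repository files without re-escaping.
std::string NormalizePath(Os os, std::string path) {
  if (os == Os::kWindows) std::replace(path.begin(), path.end(), '\\', '/');
  return path;
}

int main() {
  std::cout << LibraryPattern(Os::kWindows, "cudart", "12") << "\n";
  std::cout << LibraryPattern(Os::kLinux, "cudart", "12") << "\n";
  std::cout << NormalizePath(
                   Os::kWindows,
                   "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.3")
            << "\n";
}
```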
- "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:ab39410baf2fc1d31d50540acec7640d7f4814fa694e2421b696b6f0a058d645", - "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:b699d6ae235ac601dc3e62391ac7c4606cb10331f8141983858c1580f5e74ddb", + "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:b112c0c77d4172fc025420938f13ea83f3ad480c01778e743a201e5e3f4710e1", + "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:9fefda035b4a12b24cd5bae56c7dbb9527a5fd06a41ced0a22ac86fe5ed26428", + "cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:afe68c3448734cb07b16005fd9ed47d19533eb8bf5acd92863735ce24766b93b", # ROCM, probably not all of them still in use "rocm-ubuntu18.04-manylinux2010-multipython": "sha256:6e953a09b145df338bcb03e9e36f99b291140c29b72d0a048fb6c5905ccad5eb", "rocm-ubuntu20.04-manylinux2014-multipython": "sha256:906faec7765fe5dd067f2b092b5d5f220c1fedde725fb42c83d031b4d6f32204", @@ -98,6 +99,13 @@ containers = { "digest": container_digests["cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython"], }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython. + "cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": { + "registry": "gcr.io", + "repository": "tensorflow-testing/nosla-cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython", + "digest": container_digests["cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython"], + }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.rocm-ubuntu18.04-manylinux2010-multipython. "rocm-ubuntu18.04-manylinux2010-multipython": { "registry": "gcr.io", diff --git a/third_party/xla/third_party/tsl/tsl/cuda/BUILD.bazel b/third_party/xla/third_party/tsl/tsl/cuda/BUILD.bazel index 57597e207686ff..6ccfd7a019a3ce 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/BUILD.bazel +++ b/third_party/xla/third_party/tsl/tsl/cuda/BUILD.bazel @@ -2,10 +2,6 @@ # Stubs for dynamically loading CUDA. 
load("//tsl/cuda:stub.bzl", "cuda_stub") -load( - "//tsl/platform:build_config.bzl", - "tsl_cc_test", -) load( "//tsl/platform:rules_cc.bzl", "cc_library", @@ -44,7 +40,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@local_config_cuda//cuda:cuda_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -67,7 +64,8 @@ cc_library( deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -90,7 +88,8 @@ cc_library( deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -122,7 +121,8 @@ cc_library( "//tsl:is_cuda_enabled_and_oss": [ ":cuda", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:load_library", + "//tsl/platform:logging", "@com_google_absl//absl/container:flat_hash_set", "@local_config_cuda//cuda:cuda_headers", ], @@ -151,7 +151,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@local_config_cuda//cuda:cudnn_header", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -187,7 +188,8 @@ cc_library( deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -213,7 +215,8 @@ cc_library( "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cupti_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -237,7 +240,8 @@ cc_library( deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -261,7 +265,8 @@ cc_library( deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) @@ -287,6 +292,7 @@ cc_library( "@local_config_cuda//cuda:cuda_headers", "@local_config_nccl//:nccl_headers", "//tsl/platform:dso_loader", - "//tsl/platform:env", + "//tsl/platform:logging", + "//tsl/platform:load_library", ]), ) diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cublasLt.symbols b/third_party/xla/third_party/tsl/tsl/cuda/cublasLt.symbols index 7f93cfcb3ad49f..db6fa52731f784 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cublasLt.symbols +++ b/third_party/xla/third_party/tsl/tsl/cuda/cublasLt.symbols @@ -38,62 +38,119 @@ cublasLtDDDMatmulAlgoGetHeuristic cublasLtDDDMatmulAlgoGetIds cublasLtDDDMatmulAlgoInit cublasLtDestroy +cublasLtE4m3E4m3Fp32Bf16Bf16Matmul cublasLtE4m3E4m3Fp32Bf16Bf16MatmulAlgoCapGetAttribute cublasLtE4m3E4m3Fp32Bf16Bf16MatmulAlgoCheck +cublasLtE4m3E4m3Fp32Bf16Bf16MatmulAlgoGetHeuristic +cublasLtE4m3E4m3Fp32Bf16Bf16MatmulAlgoGetIds cublasLtE4m3E4m3Fp32Bf16Bf16MatmulAlgoInit +cublasLtE4m3E4m3Fp32Bf16E4m3Matmul cublasLtE4m3E4m3Fp32Bf16E4m3MatmulAlgoCapGetAttribute cublasLtE4m3E4m3Fp32Bf16E4m3MatmulAlgoCheck +cublasLtE4m3E4m3Fp32Bf16E4m3MatmulAlgoGetHeuristic +cublasLtE4m3E4m3Fp32Bf16E4m3MatmulAlgoGetIds cublasLtE4m3E4m3Fp32Bf16E4m3MatmulAlgoInit +cublasLtE4m3E4m3Fp32Fp16E4m3Matmul 
cublasLtE4m3E4m3Fp32Fp16E4m3MatmulAlgoCapGetAttribute cublasLtE4m3E4m3Fp32Fp16E4m3MatmulAlgoCheck +cublasLtE4m3E4m3Fp32Fp16E4m3MatmulAlgoGetHeuristic +cublasLtE4m3E4m3Fp32Fp16E4m3MatmulAlgoGetIds cublasLtE4m3E4m3Fp32Fp16E4m3MatmulAlgoInit +cublasLtE4m3E4m3Fp32Fp16Fp16Matmul cublasLtE4m3E4m3Fp32Fp16Fp16MatmulAlgoCapGetAttribute cublasLtE4m3E4m3Fp32Fp16Fp16MatmulAlgoCheck +cublasLtE4m3E4m3Fp32Fp16Fp16MatmulAlgoGetHeuristic +cublasLtE4m3E4m3Fp32Fp16Fp16MatmulAlgoGetIds cublasLtE4m3E4m3Fp32Fp16Fp16MatmulAlgoInit +cublasLtE4m3E4m3Fp32Fp32Fp32Matmul cublasLtE4m3E4m3Fp32Fp32Fp32MatmulAlgoCapGetAttribute cublasLtE4m3E4m3Fp32Fp32Fp32MatmulAlgoCheck +cublasLtE4m3E4m3Fp32Fp32Fp32MatmulAlgoGetHeuristic +cublasLtE4m3E4m3Fp32Fp32Fp32MatmulAlgoGetIds cublasLtE4m3E4m3Fp32Fp32Fp32MatmulAlgoInit +cublasLtE4m3E5m2Fp32Bf16Bf16Matmul cublasLtE4m3E5m2Fp32Bf16Bf16MatmulAlgoCapGetAttribute cublasLtE4m3E5m2Fp32Bf16Bf16MatmulAlgoCheck +cublasLtE4m3E5m2Fp32Bf16Bf16MatmulAlgoGetHeuristic +cublasLtE4m3E5m2Fp32Bf16Bf16MatmulAlgoGetIds cublasLtE4m3E5m2Fp32Bf16Bf16MatmulAlgoInit +cublasLtE4m3E5m2Fp32Bf16E4m3Matmul cublasLtE4m3E5m2Fp32Bf16E4m3MatmulAlgoCapGetAttribute cublasLtE4m3E5m2Fp32Bf16E4m3MatmulAlgoCheck +cublasLtE4m3E5m2Fp32Bf16E4m3MatmulAlgoGetHeuristic +cublasLtE4m3E5m2Fp32Bf16E4m3MatmulAlgoGetIds cublasLtE4m3E5m2Fp32Bf16E4m3MatmulAlgoInit +cublasLtE4m3E5m2Fp32Bf16E5m2Matmul cublasLtE4m3E5m2Fp32Bf16E5m2MatmulAlgoCapGetAttribute cublasLtE4m3E5m2Fp32Bf16E5m2MatmulAlgoCheck +cublasLtE4m3E5m2Fp32Bf16E5m2MatmulAlgoGetHeuristic +cublasLtE4m3E5m2Fp32Bf16E5m2MatmulAlgoGetIds cublasLtE4m3E5m2Fp32Bf16E5m2MatmulAlgoInit +cublasLtE4m3E5m2Fp32Fp16E4m3Matmul cublasLtE4m3E5m2Fp32Fp16E4m3MatmulAlgoCapGetAttribute cublasLtE4m3E5m2Fp32Fp16E4m3MatmulAlgoCheck +cublasLtE4m3E5m2Fp32Fp16E4m3MatmulAlgoGetHeuristic +cublasLtE4m3E5m2Fp32Fp16E4m3MatmulAlgoGetIds cublasLtE4m3E5m2Fp32Fp16E4m3MatmulAlgoInit +cublasLtE4m3E5m2Fp32Fp16E5m2Matmul cublasLtE4m3E5m2Fp32Fp16E5m2MatmulAlgoCapGetAttribute cublasLtE4m3E5m2Fp32Fp16E5m2MatmulAlgoCheck +cublasLtE4m3E5m2Fp32Fp16E5m2MatmulAlgoGetHeuristic +cublasLtE4m3E5m2Fp32Fp16E5m2MatmulAlgoGetIds cublasLtE4m3E5m2Fp32Fp16E5m2MatmulAlgoInit +cublasLtE4m3E5m2Fp32Fp16Fp16Matmul cublasLtE4m3E5m2Fp32Fp16Fp16MatmulAlgoCapGetAttribute cublasLtE4m3E5m2Fp32Fp16Fp16MatmulAlgoCheck +cublasLtE4m3E5m2Fp32Fp16Fp16MatmulAlgoGetHeuristic +cublasLtE4m3E5m2Fp32Fp16Fp16MatmulAlgoGetIds cublasLtE4m3E5m2Fp32Fp16Fp16MatmulAlgoInit +cublasLtE4m3E5m2Fp32Fp32Fp32Matmul cublasLtE4m3E5m2Fp32Fp32Fp32MatmulAlgoCapGetAttribute cublasLtE4m3E5m2Fp32Fp32Fp32MatmulAlgoCheck +cublasLtE4m3E5m2Fp32Fp32Fp32MatmulAlgoGetHeuristic +cublasLtE4m3E5m2Fp32Fp32Fp32MatmulAlgoGetIds cublasLtE4m3E5m2Fp32Fp32Fp32MatmulAlgoInit +cublasLtE5m2E4m3Fp32Bf16Bf16Matmul cublasLtE5m2E4m3Fp32Bf16Bf16MatmulAlgoCapGetAttribute cublasLtE5m2E4m3Fp32Bf16Bf16MatmulAlgoCheck +cublasLtE5m2E4m3Fp32Bf16Bf16MatmulAlgoGetHeuristic +cublasLtE5m2E4m3Fp32Bf16Bf16MatmulAlgoGetIds cublasLtE5m2E4m3Fp32Bf16Bf16MatmulAlgoInit +cublasLtE5m2E4m3Fp32Bf16E4m3Matmul cublasLtE5m2E4m3Fp32Bf16E4m3MatmulAlgoCapGetAttribute cublasLtE5m2E4m3Fp32Bf16E4m3MatmulAlgoCheck +cublasLtE5m2E4m3Fp32Bf16E4m3MatmulAlgoGetHeuristic +cublasLtE5m2E4m3Fp32Bf16E4m3MatmulAlgoGetIds cublasLtE5m2E4m3Fp32Bf16E4m3MatmulAlgoInit +cublasLtE5m2E4m3Fp32Bf16E5m2Matmul cublasLtE5m2E4m3Fp32Bf16E5m2MatmulAlgoCapGetAttribute cublasLtE5m2E4m3Fp32Bf16E5m2MatmulAlgoCheck +cublasLtE5m2E4m3Fp32Bf16E5m2MatmulAlgoGetHeuristic +cublasLtE5m2E4m3Fp32Bf16E5m2MatmulAlgoGetIds cublasLtE5m2E4m3Fp32Bf16E5m2MatmulAlgoInit 
+cublasLtE5m2E4m3Fp32Fp16E4m3Matmul cublasLtE5m2E4m3Fp32Fp16E4m3MatmulAlgoCapGetAttribute cublasLtE5m2E4m3Fp32Fp16E4m3MatmulAlgoCheck +cublasLtE5m2E4m3Fp32Fp16E4m3MatmulAlgoGetHeuristic +cublasLtE5m2E4m3Fp32Fp16E4m3MatmulAlgoGetIds cublasLtE5m2E4m3Fp32Fp16E4m3MatmulAlgoInit +cublasLtE5m2E4m3Fp32Fp16E5m2Matmul cublasLtE5m2E4m3Fp32Fp16E5m2MatmulAlgoCapGetAttribute cublasLtE5m2E4m3Fp32Fp16E5m2MatmulAlgoCheck +cublasLtE5m2E4m3Fp32Fp16E5m2MatmulAlgoGetHeuristic +cublasLtE5m2E4m3Fp32Fp16E5m2MatmulAlgoGetIds cublasLtE5m2E4m3Fp32Fp16E5m2MatmulAlgoInit +cublasLtE5m2E4m3Fp32Fp16Fp16Matmul cublasLtE5m2E4m3Fp32Fp16Fp16MatmulAlgoCapGetAttribute cublasLtE5m2E4m3Fp32Fp16Fp16MatmulAlgoCheck +cublasLtE5m2E4m3Fp32Fp16Fp16MatmulAlgoGetHeuristic +cublasLtE5m2E4m3Fp32Fp16Fp16MatmulAlgoGetIds cublasLtE5m2E4m3Fp32Fp16Fp16MatmulAlgoInit +cublasLtE5m2E4m3Fp32Fp32Fp32Matmul cublasLtE5m2E4m3Fp32Fp32Fp32MatmulAlgoCapGetAttribute cublasLtE5m2E4m3Fp32Fp32Fp32MatmulAlgoCheck +cublasLtE5m2E4m3Fp32Fp32Fp32MatmulAlgoGetHeuristic +cublasLtE5m2E4m3Fp32Fp32Fp32MatmulAlgoGetIds cublasLtE5m2E4m3Fp32Fp32Fp32MatmulAlgoInit cublasLtGetCudartVersion cublasLtGetProperty diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cublasLt_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cublasLt_stub.cc index df4e73bebc126c..d078aa2f2c55ee 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cublasLt_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cublasLt_stub.cc @@ -15,7 +15,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cublasLt.h" #include "third_party/gpus/cuda/include/cuda.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cuBLASLt API by forwarding to cuBLASLt loaded from the DSO. @@ -33,8 +34,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cublas_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cublas_stub.cc index 814d64d75d8d61..fe3cec911ca186 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cublas_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cublas_stub.cc @@ -24,7 +24,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "third_party/gpus/cuda/include/cuda.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cuBLAS API by forwarding to cuBLAS loaded from the DSO. // Note that it does not implement the v1 interface. 
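Editor's note — the stub refactor running through these files replaces `tsl::Env::Default()->GetSymbolFromLibrary(...)` with the lighter `tsl::internal::GetSymbolFromLibrary(...)`, and the BUILD rules swap the `//tsl/platform:env` dependency for `//tsl/platform:load_library` plus `//tsl/platform:logging`. The lazy-resolution pattern itself is unchanged: open the real library once, then resolve each symbol on demand. The POSIX sketch below shows that pattern in isolation; it is a standalone illustration, not the TSL implementation, and the cast to the CUDA function signature is simplified.

```cpp
// Minimal sketch of the lazy DSO-forwarding pattern used by the tsl/cuda
// stubs, using raw POSIX dlopen/dlsym instead of the TSL wrappers.
#include <dlfcn.h>
#include <cstdio>

namespace {

// Open the real library once and cache the handle (nullptr if unavailable).
void* GetDsoHandle() {
  static void* handle = dlopen("libcudart.so.12", RTLD_NOW | RTLD_LOCAL);
  return handle;
}

// Resolve a symbol from the cached handle; nullptr means "not available".
void* LoadSymbol(const char* symbol_name) {
  void* handle = GetDsoHandle();
  return handle ? dlsym(handle, symbol_name) : nullptr;
}

}  // namespace

int main() {
  using GetVersionFn = int (*)(int*);
  auto fn = reinterpret_cast<GetVersionFn>(LoadSymbol("cudaRuntimeGetVersion"));
  if (!fn) {
    std::printf("CUDA runtime not found; a stub would report an error here\n");
    return 0;
  }
  int version = 0;
  fn(&version);
  std::printf("cudaRuntimeGetVersion -> %d\n", version);
}
```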
@@ -43,8 +44,7 @@ void *GetDsoHandle() { void *LoadSymbol(const char *symbol_name) { void *symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cuda.symbols b/third_party/xla/third_party/tsl/tsl/cuda/cuda.symbols index 558d11cafdbc99..97e1d00ebd57ae 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cuda.symbols +++ b/third_party/xla/third_party/tsl/tsl/cuda/cuda.symbols @@ -10,6 +10,10 @@ cuArrayGetDescriptor_v2 cuArrayGetMemoryRequirements cuArrayGetPlane cuArrayGetSparseProperties +cuCoredumpGetAttribute +cuCoredumpGetAttributeGlobal +cuCoredumpSetAttribute +cuCoredumpSetAttributeGlobal cuCtxAttach cuCtxCreate cuCtxCreate_v2 @@ -36,6 +40,7 @@ cuCtxPushCurrent_v2 cuCtxResetPersistingL2Cache cuCtxSetCacheConfig cuCtxSetCurrent +cuCtxSetFlags cuCtxSetLimit cuCtxSetSharedMemConfig cuCtxSynchronize @@ -99,6 +104,7 @@ cuExternalMemoryGetMappedMipmappedArray cuFlushGPUDirectRDMAWrites cuFuncGetAttribute cuFuncGetModule +cuFuncGetName cuFuncSetAttribute cuFuncSetBlockShape cuFuncSetCacheConfig @@ -128,6 +134,7 @@ cuGetProcAddress_v2 cuGraphAddBatchMemOpNode cuGraphAddChildGraphNode cuGraphAddDependencies +cuGraphAddDependencies_v2 cuGraphAddEmptyNode cuGraphAddEventRecordNode cuGraphAddEventWaitNode @@ -140,10 +147,13 @@ cuGraphAddMemAllocNode cuGraphAddMemFreeNode cuGraphAddMemcpyNode cuGraphAddMemsetNode +cuGraphAddNode +cuGraphAddNode_v2 cuGraphBatchMemOpNodeGetParams cuGraphBatchMemOpNodeSetParams cuGraphChildGraphNodeGetGraph cuGraphClone +cuGraphConditionalHandleCreate cuGraphCreate cuGraphDebugDotPrint cuGraphDestroy @@ -165,6 +175,7 @@ cuGraphExecKernelNodeSetParams cuGraphExecKernelNodeSetParams_v2 cuGraphExecMemcpyNodeSetParams cuGraphExecMemsetNodeSetParams +cuGraphExecNodeSetParams cuGraphExecUpdate cuGraphExecUpdate_v2 cuGraphExternalSemaphoresSignalNodeGetParams @@ -172,6 +183,7 @@ cuGraphExternalSemaphoresSignalNodeSetParams cuGraphExternalSemaphoresWaitNodeGetParams cuGraphExternalSemaphoresWaitNodeSetParams cuGraphGetEdges +cuGraphGetEdges_v2 cuGraphGetNodes cuGraphGetRootNodes cuGraphHostNodeGetParams @@ -198,12 +210,16 @@ cuGraphMemsetNodeGetParams cuGraphMemsetNodeSetParams cuGraphNodeFindInClone cuGraphNodeGetDependencies +cuGraphNodeGetDependencies_v2 cuGraphNodeGetDependentNodes +cuGraphNodeGetDependentNodes_v2 cuGraphNodeGetEnabled cuGraphNodeGetType cuGraphNodeSetEnabled +cuGraphNodeSetParams cuGraphReleaseUserObject cuGraphRemoveDependencies +cuGraphRemoveDependencies_v2 cuGraphRetainUserObject cuGraphUpload cuGraphUpload_ptsz @@ -235,6 +251,7 @@ cuIpcOpenMemHandle cuIpcOpenMemHandle_v2 cuKernelGetAttribute cuKernelGetFunction +cuKernelGetName cuKernelSetAttribute cuKernelSetCacheConfig cuLaunch @@ -268,6 +285,7 @@ cuLinkDestroy cuMemAddressFree cuMemAddressReserve cuMemAdvise +cuMemAdvise_v2 cuMemAlloc cuMemAllocAsync cuMemAllocAsync_ptsz @@ -320,6 +338,8 @@ cuMemPoolSetAttribute cuMemPoolTrimTo cuMemPrefetchAsync cuMemPrefetchAsync_ptsz +cuMemPrefetchAsync_v2 +cuMemPrefetchAsync_v2_ptsz cuMemRangeGetAttribute cuMemRangeGetAttributes cuMemRelease @@ -438,6 +458,12 @@ cuModuleLoadData cuModuleLoadDataEx cuModuleLoadFatBinary cuModuleUnload +cuMulticastAddDevice +cuMulticastBindAddr +cuMulticastBindMem +cuMulticastCreate +cuMulticastGetGranularity +cuMulticastUnbind cuOccupancyAvailableDynamicSMemPerBlock 
cuOccupancyMaxActiveBlocksPerMultiprocessor cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags @@ -467,6 +493,8 @@ cuStreamBatchMemOp_ptsz cuStreamBatchMemOp_v2 cuStreamBatchMemOp_v2_ptsz cuStreamBeginCapture +cuStreamBeginCaptureToGraph +cuStreamBeginCaptureToGraph_ptsz cuStreamBeginCapture_ptsz cuStreamBeginCapture_v2 cuStreamBeginCapture_v2_ptsz @@ -484,6 +512,8 @@ cuStreamGetCaptureInfo cuStreamGetCaptureInfo_ptsz cuStreamGetCaptureInfo_v2 cuStreamGetCaptureInfo_v2_ptsz +cuStreamGetCaptureInfo_v3 +cuStreamGetCaptureInfo_v3_ptsz cuStreamGetCtx cuStreamGetCtx_ptsz cuStreamGetFlags @@ -502,6 +532,8 @@ cuStreamSynchronize cuStreamSynchronize_ptsz cuStreamUpdateCaptureDependencies cuStreamUpdateCaptureDependencies_ptsz +cuStreamUpdateCaptureDependencies_v2 +cuStreamUpdateCaptureDependencies_v2_ptsz cuStreamWaitEvent cuStreamWaitEvent_ptsz cuStreamWaitValue32 @@ -574,10 +606,30 @@ cuVDPAUGetDevice cuWaitExternalSemaphoresAsync cuWaitExternalSemaphoresAsync_ptsz cudbgApiAttach +cudbgApiClientPid +cudbgApiClientRevision cudbgApiDetach cudbgApiInit +cudbgAttachHandlerAvailable +cudbgDebuggerCapabilities +cudbgDebuggerInitialized +cudbgDetachSuspendedDevicesMask +cudbgEnableIntegratedMemcheck +cudbgEnableLaunchBlocking +cudbgEnablePreemptionDebugging cudbgGetAPI cudbgGetAPIVersion +cudbgInjectionPath +cudbgIpcFlag cudbgMain cudbgReportDriverApiError +cudbgReportDriverApiErrorFlags cudbgReportDriverInternalError +cudbgReportedDriverApiErrorCode +cudbgReportedDriverApiErrorFuncNameAddr +cudbgReportedDriverApiErrorFuncNameSize +cudbgReportedDriverInternalErrorCode +cudbgResumeForAttachDetach +cudbgRpcEnabled +cudbgSessionId +cudbgUseExternalDebugger diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cuda_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cuda_stub.cc index a199d4cc700442..298d493db97d15 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cuda_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cuda_stub.cc @@ -14,7 +14,8 @@ limitations under the License. ==============================================================================*/ #include "third_party/gpus/cuda/include/cuda.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the CUDA driver API by forwarding to CUDA loaded from the DSO. 
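Editor's note — the `.symbols` lists above grow to cover the CUDA 12.3 driver surface: the `cuCoredump*` attributes, the `_v2` graph-dependency entry points, `cuMulticast*`, stream capture v3, and the `cudbg` debugger hooks. Assuming, as the BUILD file suggests, that each list enumerates the symbols the corresponding stub re-exports, a quick diagnostic is to `dlsym` a few of the newly added names against the locally installed driver, as sketched below with a hard-coded sample taken from the list.

```cpp
// Check which of a few newly listed driver symbols the local libcuda
// actually exports. Purely a diagnostic sketch; the stub libraries are
// generated from the .symbols files, not from code like this.
#include <dlfcn.h>
#include <cstdio>

int main() {
  void* handle = dlopen("libcuda.so.1", RTLD_NOW | RTLD_LOCAL);
  if (!handle) {
    std::printf("libcuda.so.1 not found: %s\n", dlerror());
    return 1;
  }
  const char* kNewSymbols[] = {
      "cuCoredumpGetAttribute", "cuGraphAddNode_v2",
      "cuMulticastCreate",      "cuStreamBeginCaptureToGraph",
      "cuStreamGetCaptureInfo_v3",
  };
  for (const char* name : kNewSymbols) {
    std::printf("%-30s %s\n", name,
                dlsym(handle, name) ? "present" : "missing");
  }
  dlclose(handle);
  return 0;
}
```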
@@ -36,8 +37,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cudart.symbols b/third_party/xla/third_party/tsl/tsl/cuda/cudart.symbols index 69b990cb3879b5..443b8057e44f0e 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cudart.symbols +++ b/third_party/xla/third_party/tsl/tsl/cuda/cudart.symbols @@ -80,6 +80,7 @@ cudaFreeAsync_ptsz cudaFreeHost cudaFreeMipmappedArray cudaFuncGetAttributes +cudaFuncGetName cudaFuncSetAttribute cudaFuncSetCacheConfig cudaFuncSetSharedMemConfig @@ -115,6 +116,7 @@ cudaGetTextureObjectResourceViewDesc cudaGetTextureObjectTextureDesc cudaGraphAddChildGraphNode cudaGraphAddDependencies +cudaGraphAddDependencies_v2 cudaGraphAddEmptyNode cudaGraphAddEventRecordNode cudaGraphAddEventWaitNode @@ -130,8 +132,10 @@ cudaGraphAddMemcpyNodeFromSymbol cudaGraphAddMemcpyNodeToSymbol cudaGraphAddMemsetNode cudaGraphAddNode +cudaGraphAddNode_v2 cudaGraphChildGraphNodeGetGraph cudaGraphClone +cudaGraphConditionalHandleCreate cudaGraphCreate cudaGraphDebugDotPrint cudaGraphDestroy @@ -161,6 +165,7 @@ cudaGraphExternalSemaphoresSignalNodeSetParams cudaGraphExternalSemaphoresWaitNodeGetParams cudaGraphExternalSemaphoresWaitNodeSetParams cudaGraphGetEdges +cudaGraphGetEdges_v2 cudaGraphGetNodes cudaGraphGetRootNodes cudaGraphHostNodeGetParams @@ -187,13 +192,16 @@ cudaGraphMemsetNodeGetParams cudaGraphMemsetNodeSetParams cudaGraphNodeFindInClone cudaGraphNodeGetDependencies +cudaGraphNodeGetDependencies_v2 cudaGraphNodeGetDependentNodes +cudaGraphNodeGetDependentNodes_v2 cudaGraphNodeGetEnabled cudaGraphNodeGetType cudaGraphNodeSetEnabled cudaGraphNodeSetParams cudaGraphReleaseUserObject cudaGraphRemoveDependencies +cudaGraphRemoveDependencies_v2 cudaGraphRetainUserObject cudaGraphUpload cudaGraphUpload_ptsz @@ -348,6 +356,8 @@ cudaStreamAddCallback_ptsz cudaStreamAttachMemAsync cudaStreamAttachMemAsync_ptsz cudaStreamBeginCapture +cudaStreamBeginCaptureToGraph +cudaStreamBeginCaptureToGraph_ptsz cudaStreamBeginCapture_ptsz cudaStreamCopyAttributes cudaStreamCopyAttributes_ptsz @@ -363,6 +373,8 @@ cudaStreamGetCaptureInfo cudaStreamGetCaptureInfo_ptsz cudaStreamGetCaptureInfo_v2 cudaStreamGetCaptureInfo_v2_ptsz +cudaStreamGetCaptureInfo_v3 +cudaStreamGetCaptureInfo_v3_ptsz cudaStreamGetFlags cudaStreamGetFlags_ptsz cudaStreamGetId @@ -379,6 +391,8 @@ cudaStreamSynchronize cudaStreamSynchronize_ptsz cudaStreamUpdateCaptureDependencies cudaStreamUpdateCaptureDependencies_ptsz +cudaStreamUpdateCaptureDependencies_v2 +cudaStreamUpdateCaptureDependencies_v2_ptsz cudaStreamWaitEvent cudaStreamWaitEvent_ptsz cudaThreadExchangeStreamCaptureMode diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cudart_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cudart_stub.cc index a3797b5c751cd8..5ec2fabd84a712 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cudart_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cudart_stub.cc @@ -21,7 +21,8 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" namespace { void *GetDsoHandle() { @@ -39,8 +40,8 @@ void *GetDsoHandle() { void *LoadSymbol(const char *symbol_name) { void *symbol = nullptr; - auto env = tsl::Env::Default(); - env->GetSymbolFromLibrary(GetDsoHandle(), symbol_name, &symbol).IgnoreError(); + tsl::internal::GetSymbolFromLibrary(GetDsoHandle(), symbol_name, &symbol) + .IgnoreError(); return symbol; } diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cudnn_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cudnn_stub.cc index f3cab179eb0b71..1c85b1ea684a28 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cudnn_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cudnn_stub.cc @@ -16,7 +16,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "third_party/gpus/cudnn/cudnn.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cuDNN API by forwarding to cuDNN loaded from the DSO. @@ -38,8 +39,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cufft.symbols b/third_party/xla/third_party/tsl/tsl/cuda/cufft.symbols index 605815200bd90e..0f18127df42af5 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cufft.symbols +++ b/third_party/xla/third_party/tsl/tsl/cuda/cufft.symbols @@ -1,7 +1,6 @@ cufftCreate cufftDebug cufftDestroy -cufftEnterCS cufftEstimate1d cufftEstimate2d cufftEstimate3d @@ -20,11 +19,9 @@ cufftGetSize3d cufftGetSizeMany cufftGetSizeMany64 cufftGetVersion -cufftLeaveCS cufftMakePlan1d cufftMakePlan2d cufftMakePlan3d -cufftMakePlanGuru64 cufftMakePlanMany cufftMakePlanMany64 cufftPlan1d diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cufft_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cufft_stub.cc index 8f5c1b0d687337..275560027af19b 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cufft_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cufft_stub.cc @@ -15,7 +15,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cufft.h" #include "third_party/gpus/cuda/include/cufftXt.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cuFFT API by forwarding to cuFFT loaded from the DSO. @@ -37,8 +38,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cupti_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cupti_stub.cc index 9e632010d83a7a..aab8217aa3ebe5 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cupti_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cupti_stub.cc @@ -16,7 +16,8 @@ limitations under the License. 
#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" #include "third_party/gpus/cuda/include/cuda.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the CUPTI API by forwarding to CUPTI loaded from the DSO. @@ -38,8 +39,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cusolver_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cusolver_stub.cc index d11601b3bd4217..418ce47311d718 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cusolver_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cusolver_stub.cc @@ -16,7 +16,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cusolverDn.h" #include "third_party/gpus/cuda/include/cusolverSp.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cusolver API by forwarding to cusolver loaded from the DSO. @@ -38,8 +39,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/cusparse_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/cusparse_stub.cc index 16141e51e2613b..8b545cd0c1c1d8 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/cusparse_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/cusparse_stub.cc @@ -15,7 +15,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cusparse.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the cusparse API by forwarding to cusparse loaded from the DSO. @@ -37,8 +38,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/cuda/nccl_stub.cc b/third_party/xla/third_party/tsl/tsl/cuda/nccl_stub.cc index 0ebae2f3c2b2eb..462ab127ee446b 100644 --- a/third_party/xla/third_party/tsl/tsl/cuda/nccl_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/cuda/nccl_stub.cc @@ -18,7 +18,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/nccl/nccl.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" +#include "tsl/platform/load_library.h" +#include "tsl/platform/logging.h" // Implements the nccl API by forwarding to nccl loaded from a DSO. 
@@ -40,8 +41,7 @@ void* GetDsoHandle() { void* LoadSymbol(const char* symbol_name) { void* symbol = nullptr; if (auto handle = GetDsoHandle()) { - tsl::Env::Default() - ->GetSymbolFromLibrary(handle, symbol_name, &symbol) + tsl::internal::GetSymbolFromLibrary(handle, symbol_name, &symbol) .IgnoreError(); } return symbol; diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.cc index 0b916e65aaa208..9d92bdccceb2c9 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.cc @@ -62,6 +62,7 @@ constexpr int kServiceToClientTimeoutMs = 10 * 1000; // 10 seconds constexpr size_t kOngoingBarriersSoftLimit = 20; constexpr char kHealthCheckThread[] = "CoordinationServiceHealthCheck"; constexpr int kPendingTaskLogLimit = 20; +constexpr int kPendingStragglerLogLimit = 3; std::string GetTaskName(absl::string_view job_name, int task_id) { return strings::StrCat("/job:", job_name, "/replica:", 0, "/task:", task_id); @@ -104,6 +105,9 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { void SetDeviceAggregationFunction( std::function post_aggregate_device_fn) override; + + void LogConnectStatusLocked() const TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + Status RegisterTask(const CoordinatedTask& task, uint64_t incarnation) override; void WaitForAllTasks(const CoordinatedTask& task, const DeviceInfo& devices, @@ -519,6 +523,26 @@ void CoordinationServiceStandaloneImpl::Stop(bool shut_staleness_thread) { } } +// Helper to log progress to having waited for all tasks. +void CoordinationServiceStandaloneImpl::LogConnectStatusLocked() const { + const int num_tasks = cluster_state_.size(); + int pending_tasks = 0; + std::vector task_names; + for (const auto& [task_name, task_state] : cluster_state_) { + if (task_state->GetState() != CoordinatedTaskState::TASKSTATE_CONNECTED) { + pending_tasks++; + if (task_names.size() < kPendingStragglerLogLimit) { + task_names.push_back(task_name); + } + } + } + LOG(INFO) << "Waiting for " << pending_tasks << "/" << num_tasks + << " tasks to connect."; + if (!task_names.empty()) { + LOG(INFO) << "Example stragglers:\n" << absl::StrJoin(task_names, "\n"); + } +} + Status CoordinationServiceStandaloneImpl::RegisterTask( const CoordinatedTask& task, uint64_t incarnation) { const std::string& task_name = GetTaskName(task); @@ -553,6 +577,7 @@ Status CoordinationServiceStandaloneImpl::RegisterTask( LOG(INFO) << task_name << " has connected to coordination service. 
Incarnation: " << incarnation; + LogConnectStatusLocked(); return OkStatus(); } else if (task_state == CoordinatedTaskState::TASKSTATE_CONNECTED) { // This may happen if the service processes the initial RegisterTask(), @@ -565,6 +590,7 @@ Status CoordinationServiceStandaloneImpl::RegisterTask( LOG(INFO) << task_name << " has connected to coordination service with the same " << "incarnation again: " << incarnation; + LogConnectStatusLocked(); return OkStatus(); } else { error_message = diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc index a45213d1817624..79065f7a9118ab 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -92,29 +93,27 @@ class CoordinationServiceAgentImpl : public CoordinationServiceAgent { Status Shutdown() override; Status Reset() override; - StatusOr GetKeyValue(const std::string& key) override; + StatusOr GetKeyValue(std::string_view key) override; StatusOr GetKeyValue(const char* key, int64_t key_size) override; - StatusOr GetKeyValue(const std::string& key, + StatusOr GetKeyValue(std::string_view key, absl::Duration timeout) override; std::shared_ptr GetKeyValueAsync( - const std::string& key, StatusOrValueCallback done) override; - StatusOr TryGetKeyValue(const std::string& key) override; + std::string_view key, StatusOrValueCallback done) override; + StatusOr TryGetKeyValue(std::string_view key) override; StatusOr> GetKeyValueDir( - const std::string& key) override; - void GetKeyValueDirAsync(const std::string& key, + std::string_view key) override; + void GetKeyValueDirAsync(std::string_view key, StatusOrValueDirCallback done) override; - Status InsertKeyValue(const std::string& key, - const std::string& value) override; + Status InsertKeyValue(std::string_view key, std::string_view value) override; Status InsertKeyValue(const char* key, int64_t key_size, const char* value, int64_t value_size) override; - Status DeleteKeyValue(const std::string& key) override; + Status DeleteKeyValue(std::string_view key) override; Status DeleteKeyValue(const char* key, int64_t key_size) override; - Status UpdateKeyValue(const std::string& key, - const std::string& value) override; + Status UpdateKeyValue(std::string_view key, std::string_view value) override; - Status StartWatchKey(const std::string& key, + Status StartWatchKey(std::string_view key, ChangedKeyValuesCallback on_change) override; - Status StopWatchKey(const std::string& key) override; + Status StopWatchKey(std::string_view key) override; Status WaitAtBarrier(const std::string& barrier_id, absl::Duration timeout, const std::vector& tasks) override; void WaitAtBarrierAsync(const std::string& barrier_id, absl::Duration timeout, @@ -128,7 +127,7 @@ class CoordinationServiceAgentImpl : public CoordinationServiceAgent { protected: void SetError(const Status& error) override; - Status ActivateWatch(const std::string& key, + Status ActivateWatch(std::string_view key, const std::map&) override; // Returns an error if agent is not running. If `allow_disconnected` is true, // returns OK even if the agent is in DISCONNECTED state. 
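Editor's note — the coordination_service.cc change above adds `LogConnectStatusLocked()` (with `kPendingStragglerLogLimit = 3`): every successful `RegisterTask` now logs how many tasks are still pending and names a few example stragglers, which makes slow multi-host startups much easier to diagnose. Below is a standalone sketch of that counting-and-sampling logic, using plain standard-library types in place of the service's `cluster_state_` map.

```cpp
// Sketch of the straggler-logging logic added to the coordination service:
// count tasks that have not connected yet and print at most a few of their
// names as examples. Types are simplified stand-ins for the real ones.
#include <iostream>
#include <map>
#include <string>
#include <vector>

enum class TaskState { kDisconnected, kConnected };

constexpr int kPendingStragglerLogLimit = 3;

void LogConnectStatus(const std::map<std::string, TaskState>& cluster_state) {
  const int num_tasks = static_cast<int>(cluster_state.size());
  int pending_tasks = 0;
  std::vector<std::string> stragglers;
  for (const auto& [task_name, state] : cluster_state) {
    if (state != TaskState::kConnected) {
      ++pending_tasks;
      if (static_cast<int>(stragglers.size()) < kPendingStragglerLogLimit) {
        stragglers.push_back(task_name);
      }
    }
  }
  std::cout << "Waiting for " << pending_tasks << "/" << num_tasks
            << " tasks to connect.\n";
  if (!stragglers.empty()) {
    std::cout << "Example stragglers:\n";
    for (const auto& name : stragglers) std::cout << "  " << name << "\n";
  }
}

int main() {
  std::map<std::string, TaskState> cluster = {
      {"/job:worker/replica:0/task:0", TaskState::kConnected},
      {"/job:worker/replica:0/task:1", TaskState::kDisconnected},
      {"/job:worker/replica:0/task:2", TaskState::kDisconnected},
  };
  LogConnectStatus(cluster);  // Waiting for 2/3 tasks to connect.
}
```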
@@ -567,17 +566,17 @@ Status CoordinationServiceAgentImpl::Reset() { } StatusOr CoordinationServiceAgentImpl::GetKeyValue( - const std::string& key) { + std::string_view key) { return GetKeyValue(key, /*timeout=*/absl::InfiniteDuration()); } StatusOr CoordinationServiceAgentImpl::GetKeyValue( const char* key, int64_t key_size) { - return GetKeyValue(std::string(key, key_size)); + return GetKeyValue(std::string_view(key, key_size)); } StatusOr CoordinationServiceAgentImpl::GetKeyValue( - const std::string& key, absl::Duration timeout) { + std::string_view key, absl::Duration timeout) { auto n = std::make_shared(); auto result = std::make_shared>(); GetKeyValueAsync(key, @@ -597,9 +596,9 @@ StatusOr CoordinationServiceAgentImpl::GetKeyValue( } std::shared_ptr CoordinationServiceAgentImpl::GetKeyValueAsync( - const std::string& key, StatusOrValueCallback done) { + std::string_view key, StatusOrValueCallback done) { auto request = std::make_shared(); - request->set_key(key); + request->set_key(key.data(), key.size()); VLOG(3) << "GetKeyValueRequest: " << request->DebugString(); auto response = std::make_shared(); auto call_opts = std::make_shared(); @@ -633,33 +632,31 @@ std::shared_ptr CoordinationServiceAgentImpl::GetKeyValueAsync( } StatusOr CoordinationServiceAgentImpl::TryGetKeyValue( - const std::string& key) { + std::string_view key) { absl::Notification n; StatusOr result; TryGetKeyValueRequest request; - request.set_key(key); + request.set_key(key.data(), key.size()); VLOG(3) << "TryGetKeyValueRequest: " << request.DebugString(); TryGetKeyValueResponse response; - leader_client_->TryGetKeyValueAsync(&request, &response, - [&](const Status& s) { - if (s.ok()) { - result = response.kv().value(); - VLOG(3) << "TryGetKeyValueResponse: " - << result.value(); - } else { - result = s; - VLOG(3) << "TryGetKeyValueResponse: " - << s; - } - n.Notify(); - }); + leader_client_->TryGetKeyValueAsync( + &request, &response, [&](const Status& s) { + if (s.ok()) { + result = response.kv().value(); + VLOG(3) << "TryGetKeyValueResponse: " << result.value(); + } else { + result = s; + VLOG(3) << "TryGetKeyValueResponse: " << s; + } + n.Notify(); + }); n.WaitForNotification(); return result; } StatusOr> -CoordinationServiceAgentImpl::GetKeyValueDir(const std::string& key) { +CoordinationServiceAgentImpl::GetKeyValueDir(std::string_view key) { absl::Notification n; StatusOr> result; GetKeyValueDirAsync( @@ -673,9 +670,9 @@ CoordinationServiceAgentImpl::GetKeyValueDir(const std::string& key) { } void CoordinationServiceAgentImpl::GetKeyValueDirAsync( - const std::string& key, StatusOrValueDirCallback done) { + std::string_view key, StatusOrValueDirCallback done) { auto request = std::make_shared(); - request->set_directory_key(key); + request->set_directory_key(key.data(), key.size()); VLOG(3) << "GetKeyValueDirRequest: " << request->DebugString(); auto response = std::make_shared(); leader_client_->GetKeyValueDirAsync( @@ -694,8 +691,8 @@ void CoordinationServiceAgentImpl::GetKeyValueDirAsync( }); } -Status CoordinationServiceAgentImpl::InsertKeyValue(const std::string& key, - const std::string& value) { +Status CoordinationServiceAgentImpl::InsertKeyValue(std::string_view key, + std::string_view value) { InsertKeyValueRequest request; request.mutable_kv()->set_key(key.data(), key.size()); request.mutable_kv()->set_value(value.data(), value.size()); @@ -717,13 +714,13 @@ Status CoordinationServiceAgentImpl::InsertKeyValue(const char* key, int64_t key_size, const char* value, int64_t value_size) { - return 
InsertKeyValue(std::string(key, key_size), - std::string(value, value_size)); + return InsertKeyValue(std::string_view(key, key_size), + std::string_view(value, value_size)); } -Status CoordinationServiceAgentImpl::DeleteKeyValue(const std::string& key) { +Status CoordinationServiceAgentImpl::DeleteKeyValue(std::string_view key) { DeleteKeyValueRequest request; - request.set_key(key); + request.set_key(key.data(), key.size()); request.set_is_directory(true); VLOG(3) << "DeleteKeyValueRequest: " << request.DebugString(); DeleteKeyValueResponse response; @@ -741,23 +738,23 @@ Status CoordinationServiceAgentImpl::DeleteKeyValue(const std::string& key) { Status CoordinationServiceAgentImpl::DeleteKeyValue(const char* key, int64_t key_size) { - return DeleteKeyValue(std::string(key, key_size)); + return DeleteKeyValue(std::string_view(key, key_size)); } -Status CoordinationServiceAgentImpl::UpdateKeyValue(const std::string& key, - const std::string& value) { +Status CoordinationServiceAgentImpl::UpdateKeyValue(std::string_view key, + std::string_view value) { return MakeCoordinationError(errors::Unimplemented( "CoordinationServiceAgent::UpdateKeyValue is not implemented.")); } Status CoordinationServiceAgentImpl::StartWatchKey( - const std::string& key, + std::string_view key, CoordinationServiceAgentImpl::ChangedKeyValuesCallback on_change) { return MakeCoordinationError(errors::Unimplemented( "CoordinationServiceAgent::StartWatchKey is not implemented.")); } -Status CoordinationServiceAgentImpl::StopWatchKey(const std::string& key) { +Status CoordinationServiceAgentImpl::StopWatchKey(std::string_view key) { return MakeCoordinationError(errors::Unimplemented( "CoordinationServiceAgent::StopWatchKey is not implemented.")); } @@ -774,7 +771,7 @@ void CoordinationServiceAgentImpl::SetError(const Status& error) { } Status CoordinationServiceAgentImpl::ActivateWatch( - const std::string& key, const std::map& kvs) { + std::string_view key, const std::map& kvs) { return MakeCoordinationError(errors::Unimplemented( "CoordinationServiceAgent::ActivateWatch is not implemented.")); } diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.h b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.h index a567272f9d72ef..f94e6ac9dcb209 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.h +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -168,19 +169,19 @@ class CoordinationServiceAgent { // If the key-value is not inserted yet, this is a blocking call that waits // until the corresponding key is inserted. // - errors::DeadlineExceeded: timed out waiting for key. - virtual StatusOr GetKeyValue(const std::string& key) = 0; + virtual StatusOr GetKeyValue(std::string_view key) = 0; virtual StatusOr GetKeyValue(const char* key, int64_t key_size) = 0; - virtual StatusOr GetKeyValue(const std::string& key, + virtual StatusOr GetKeyValue(std::string_view key, absl::Duration timeout) = 0; // Note: Cancel the underlying RPC call with `call_opts->StartCancel()` and // `call_opts->ClearCancelCallback()`. virtual std::shared_ptr GetKeyValueAsync( - const std::string& key, StatusOrValueCallback done) = 0; + std::string_view, StatusOrValueCallback done) = 0; // Get config key-value from the service. 
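Editor's note — the agent's key-value interface migrates from `const std::string&` to `std::string_view`, so callers holding a literal or a (pointer, length) pair no longer pay for a temporary `std::string`, and the proto fields are populated via `set_key(key.data(), key.size())` because a `string_view` is not guaranteed to be NUL-terminated. A small self-contained illustration of that calling convention follows; the `KeyValueStore` class is invented for the example and only the parameter-passing pattern mirrors the patch.

```cpp
// Illustration of the std::string_view-based API style the agent moves to.
#include <iostream>
#include <string>
#include <string_view>
#include <unordered_map>

class KeyValueStore {
 public:
  // string_view accepts std::string, string literals, and (ptr, len) pairs
  // without forcing a copy at the call site.
  void Insert(std::string_view key, std::string_view value) {
    // Materialize explicitly; string_view may not be NUL-terminated, so we
    // always pass (data, size) rather than treating it as a C string.
    data_.emplace(std::string(key), std::string(value));
  }

  bool Get(std::string_view key, std::string* value) const {
    auto it = data_.find(std::string(key));
    if (it == data_.end()) return false;
    *value = it->second;
    return true;
  }

 private:
  std::unordered_map<std::string, std::string> data_;
};

int main() {
  KeyValueStore store;
  const char raw[] = "worker/0/addr";
  store.Insert(std::string_view(raw, sizeof(raw) - 1), "10.0.0.1:8470");
  std::string value;
  if (store.Get("worker/0/addr", &value)) std::cout << value << "\n";
}
```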
// - errors::NotFound: the requested key does not exist. - virtual StatusOr TryGetKeyValue(const std::string& key) = 0; + virtual StatusOr TryGetKeyValue(std::string_view key) = 0; // Get all values under a directory (key). // A value is considered to be in the directory if its key is prefixed with @@ -188,30 +189,30 @@ class CoordinationServiceAgent { // This is not a blocking call. If no keys are found, an empty vector is // returned immediately. virtual StatusOr> GetKeyValueDir( - const std::string& key) = 0; - virtual void GetKeyValueDirAsync(const std::string& key, + std::string_view key) = 0; + virtual void GetKeyValueDirAsync(std::string_view key, StatusOrValueDirCallback done) = 0; // Insert config key-value to the service. // - errors::AlreadyExists: key is already set. - virtual Status InsertKeyValue(const std::string& key, - const std::string& value) = 0; + virtual Status InsertKeyValue(std::string_view key, + std::string_view value) = 0; virtual Status InsertKeyValue(const char* key, int64_t key_size, const char* value, int64_t value_size) = 0; // Delete config keys in the coordination service. - virtual Status DeleteKeyValue(const std::string& key) = 0; + virtual Status DeleteKeyValue(std::string_view key) = 0; virtual Status DeleteKeyValue(const char* key, int64_t key_size) = 0; // Update the value of a config key. - virtual Status UpdateKeyValue(const std::string& key, - const std::string& value) = 0; + virtual Status UpdateKeyValue(std::string_view key, + std::string_view value) = 0; // Register a callback that will be invoked when the key or keys under the key // directory are changed (inserted, deleted, or updated). - virtual Status StartWatchKey(const std::string& key, + virtual Status StartWatchKey(std::string_view key, ChangedKeyValuesCallback on_change) = 0; - virtual Status StopWatchKey(const std::string& key) = 0; + virtual Status StopWatchKey(std::string_view key) = 0; // Blocks until all (or a subset of) tasks are at the barrier or the barrier // fails. @@ -273,7 +274,7 @@ class CoordinationServiceAgent { virtual void SetError(const Status& error) = 0; // Activate the key-value callback watch. 
- virtual Status ActivateWatch(const std::string& key, + virtual Status ActivateWatch(std::string_view, const std::map&) = 0; private: diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/collected_metrics.h b/third_party/xla/third_party/tsl/tsl/lib/monitoring/collected_metrics.h index 8582594922adf2..ba67299b57a952 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/collected_metrics.h +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/collected_metrics.h @@ -90,6 +90,7 @@ struct Point { int64_t int64_value; string string_value; bool bool_value; + double double_value; HistogramProto histogram_value; Percentiles percentiles_value; diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/collection_registry.h b/third_party/xla/third_party/tsl/tsl/lib/monitoring/collection_registry.h index d988d2f19f15ad..7af6c87e51f0bb 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/collection_registry.h +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/collection_registry.h @@ -352,6 +352,18 @@ inline void CollectValue(Percentiles value, Point* const point) { point->percentiles_value = std::move(value); } +template <> +inline void CollectValue(double value, Point* const point) { + point->value_type = ValueType::kDouble; + point->double_value = value; +} + +template <> +inline void CollectValue(std::function value_fn, Point* const point) { + point->value_type = ValueType::kDouble; + point->double_value = value_fn(); +} + // Used by the CollectionRegistry class to collect all the values of all the // metrics in the registry. This is an implementation detail of the // CollectionRegistry class, please do not depend on this. diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/gauge.h b/third_party/xla/third_party/tsl/tsl/lib/monitoring/gauge.h index 93cbe9aa928df0..0b69383b5f2d13 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/gauge.h +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/gauge.h @@ -65,8 +65,10 @@ class Gauge { std::is_same::value || std::is_same >::value || std::is_same >::value || - std::is_same >::value, - "Gauge only allows bool, int64, and string types."); + std::is_same >::value || + std::is_same >::value || + std::is_same::value, + "Gauge only allows bool, int64, double and string types."); return new Gauge(); } @@ -296,8 +298,10 @@ Gauge* Gauge::New( std::is_same::value || std::is_same >::value || std::is_same >::value || - std::is_same >::value, - "Gauge only allows bool, int64, and string types."); + std::is_same >::value || + std::is_same >::value || + std::is_same::value, + "Gauge only allows bool, int64, double, and string types."); return new Gauge( MetricDef( std::forward(metric_def_args)...)); diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/metric_def.h b/third_party/xla/third_party/tsl/tsl/lib/monitoring/metric_def.h index f8c21c360a2b09..ab454664691b1e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/metric_def.h +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/metric_def.h @@ -47,7 +47,8 @@ enum class ValueType : int { kHistogram, kString, kBool, - kPercentiles + kPercentiles, + kDouble }; // Everything in the internal namespace is implementation details. 
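Editor's note — the monitoring changes above extend `Gauge` beyond bool/int64/string: a cell may now hold a `double`, or a `std::function<double()>` that is evaluated at collection time, and `CollectValue`/`ValueType` gain a `kDouble` case so the collected `Point` carries `double_value`. Below is a minimal sketch of a callback-backed double gauge; it is a standalone toy, not the `tsl::monitoring` classes.

```cpp
// Sketch of a double-valued gauge whose value can either be set directly or
// be computed by a callback at collection time, mirroring the new
// double / std::function<double()> support added to the monitoring Gauge.
#include <functional>
#include <iostream>
#include <variant>

class DoubleGaugeCell {
 public:
  void Set(double value) { value_ = value; }
  void Set(std::function<double()> fn) { value_ = std::move(fn); }

  // "Collection": resolve the stored value, invoking the callback if needed.
  double Collect() const {
    if (auto* fn = std::get_if<std::function<double()>>(&value_)) return (*fn)();
    return std::get<double>(value_);
  }

 private:
  std::variant<double, std::function<double()>> value_{0.0};
};

int main() {
  DoubleGaugeCell memory_utilization;
  memory_utilization.Set(0.42);
  std::cout << memory_utilization.Collect() << "\n";  // 0.42

  DoubleGaugeCell queue_depth;
  queue_depth.Set([] { return 17.0; });        // evaluated lazily at collection
  std::cout << queue_depth.Collect() << "\n";  // 17
}
```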
Do not depend @@ -97,6 +98,16 @@ inline ValueType GetValueType>() { return ValueType::kBool; } +template <> +inline ValueType GetValueType() { + return ValueType::kDouble; +} + +template <> +inline ValueType GetValueType>() { + return ValueType::kDouble; +} + } // namespace internal // Abstract base class for a metric definition. diff --git a/third_party/xla/third_party/tsl/tsl/platform/cpu_info.cc b/third_party/xla/third_party/tsl/tsl/platform/cpu_info.cc index c25c354fd37cac..1de5eb8031623d 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cpu_info.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cpu_info.cc @@ -82,6 +82,7 @@ class CPUIDInfo { : have_adx_(0), have_aes_(0), have_amx_bf16_(0), + have_amx_fp16_(0), have_amx_int8_(0), have_amx_tile_(0), have_avx_(0), @@ -98,8 +99,11 @@ class CPUIDInfo { have_avx512_4vnniw_(0), have_avx512_4fmaps_(0), have_avx512_bf16_(0), + have_avx512_fp16_(0), have_avx512_vnni_(0), have_avx_vnni_(0), + have_avx_vnni_int8_(0), + have_avx_ne_convert_(0), have_bmi1_(0), have_bmi2_(0), have_cmov_(0), @@ -226,12 +230,19 @@ class CPUIDInfo { cpuid->have_amx_int8_ = (edx >> 25) & 0x1; cpuid->have_amx_bf16_ = (edx >> 22) & 0x1; + // Check for avx512_fp16 using information from Xbyak in oneDNN: + // https://github.com/oneapi-src/oneDNN/blob/acf8d214cedfe7e24c9446bacc1f9f648c9273f8/src/cpu/x64/xbyak/xbyak_util.h#L516 + cpuid->have_avx512_fp16_ = have_avx512 && ((edx >> 23) & 0x1); + // Get more Structured Extended Feature info by issuing CPUID with // sub-leaf = 1 (eax = 7, ecx = 1) if (kMaxNumSubLeaves >= 1) { GETCPUID(eax, ebx, ecx, edx, 7, 1); cpuid->have_avx_vnni_ = (eax >> 4) & 0x1; cpuid->have_avx512_bf16_ = have_avx512 && ((eax >> 5) & 0x1); + cpuid->have_amx_fp16_ = (eax >> 21) & 0x1; + cpuid->have_avx_vnni_int8_ = (edx >> 4) & 0x1; + cpuid->have_avx_ne_convert_ = (edx >> 5) & 0x1; } } @@ -242,6 +253,7 @@ class CPUIDInfo { case ADX: return cpuid->have_adx_; case AES: return cpuid->have_aes_; case AMX_BF16: return cpuid->have_amx_bf16_; + case AMX_FP16: return cpuid->have_amx_fp16_; case AMX_INT8: return cpuid->have_amx_int8_; case AMX_TILE: return cpuid->have_amx_tile_; case AVX2: return cpuid->have_avx2_; @@ -258,8 +270,11 @@ class CPUIDInfo { case AVX512_4VNNIW: return cpuid->have_avx512_4vnniw_; case AVX512_4FMAPS: return cpuid->have_avx512_4fmaps_; case AVX512_BF16: return cpuid->have_avx512_bf16_; + case AVX512_FP16: return cpuid->have_avx512_fp16_; case AVX512_VNNI: return cpuid->have_avx512_vnni_; case AVX_VNNI: return cpuid->have_avx_vnni_; + case AVX_VNNI_INT8: return cpuid->have_avx_vnni_int8_; + case AVX_NE_CONVERT: return cpuid->have_avx_ne_convert_; case BMI1: return cpuid->have_bmi1_; case BMI2: return cpuid->have_bmi2_; case CMOV: return cpuid->have_cmov_; @@ -297,6 +312,7 @@ class CPUIDInfo { int have_adx_ : 1; int have_aes_ : 1; int have_amx_bf16_ : 1; + int have_amx_fp16_ : 1; int have_amx_int8_ : 1; int have_amx_tile_ : 1; int have_avx_ : 1; @@ -313,8 +329,11 @@ class CPUIDInfo { int have_avx512_4vnniw_ : 1; int have_avx512_4fmaps_ : 1; int have_avx512_bf16_ : 1; + int have_avx512_fp16_ : 1; int have_avx512_vnni_ : 1; int have_avx_vnni_ : 1; + int have_avx_vnni_int8_ : 1; + int have_avx_ne_convert_ : 1; int have_bmi1_ : 1; int have_bmi2_ : 1; int have_cmov_ : 1; diff --git a/third_party/xla/third_party/tsl/tsl/platform/cpu_info.h b/third_party/xla/third_party/tsl/tsl/platform/cpu_info.h index e0b0d66bb11118..68506b1d34ae8e 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cpu_info.h +++ 
b/third_party/xla/third_party/tsl/tsl/platform/cpu_info.h @@ -132,6 +132,11 @@ enum CPUFeature { AMX_TILE = 41, // Tile configuration and load/store AMX_INT8 = 42, // Int8 tile matrix multiplication AMX_BF16 = 43, // Bfloat16 tile matrix multiplication + + AVX512_FP16 = 44, // Float16 neural network + AMX_FP16 = 45, // Float16 tile matrix multiplication + AVX_NE_CONVERT = 46, // Instructions for faster bfloat16, float16 convert. + AVX_VNNI_INT8 = 47, // VNNI instructions for combinations of u8, s8 dtypes. }; enum Aarch64CPU { diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD index e56abd66607093..aac69570b88c4f 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD @@ -82,12 +82,11 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - "//tsl/platform:env", - "//tsl/platform:errors", + "//tsl/platform:load_library", "//tsl/platform:logging", "//tsl/platform:path", - "//tsl/platform:status", - "//tsl/platform:statusor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", @@ -247,8 +246,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - "//tsl/platform:errors", - "//tsl/platform:status", + "@com_google_absl//absl/status", ], ) @@ -362,7 +360,6 @@ cc_library( "//tsl:with_numa_support": ["TENSORFLOW_USE_NUMA"], "//conditions:default": [], }), - features = ["-layering_check"], tags = [ "manual", "no_oss", diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/build_config/BUILD b/third_party/xla/third_party/tsl/tsl/platform/default/build_config/BUILD index 93f35c45c0569d..2d6dfda0028a1b 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/build_config/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/default/build_config/BUILD @@ -117,10 +117,16 @@ cc_library( data = [ "@local_config_cuda//cuda:cudart", ], - linkopts = [ - "-Wl,-rpath,../local_config_cuda/cuda/lib64", - "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib64", - ], + linkopts = select({ + "//tsl:macos": [ + "-Wl,-rpath,../local_config_cuda/cuda/lib", + "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib", + ], + "//conditions:default": [ + "-Wl,-rpath,../local_config_cuda/cuda/lib64", + "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib64", + ], + }), visibility = ["//visibility:public"], deps = [ "@local_config_cuda//cuda:cudart", diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker.cc b/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker.cc index 2d67789d8a0017..eb8fff80bfb6ac 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker.cc @@ -12,17 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
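A hedged sketch of how the newly wired CPU feature bits would typically be queried, assuming the existing tsl::port::TestCPUFeature entry point in cpu_info.h (the dispatch function itself is invented for illustration):

#include "tsl/platform/cpu_info.h"

// Sketch only: choose an fp16 code path when the new features are present.
bool PreferFp16Path() {
  return tsl::port::TestCPUFeature(tsl::port::AVX512_FP16) ||
         tsl::port::TestCPUFeature(tsl::port::AMX_FP16);
}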
==============================================================================*/ +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "tsl/platform/default/dso_loader.h" -#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace tsl { namespace internal { namespace DsoLoader { -Status TryDlopenCUDALibraries() { +absl::Status TryDlopenCUDALibraries() { namespace CachedLoader = ::tsl::internal::CachedDsoLoader; auto cudart_status = CachedLoader::GetCudaRuntimeDsoHandle(); auto cublas_status = CachedLoader::GetCublasDsoHandle(); @@ -36,14 +35,14 @@ Status TryDlopenCUDALibraries() { !cufft_status.status().ok() || !cusolver_status.status().ok() || !cusparse_status.status().ok() || !cudnn_status.status().ok() || !cublaslt_status.status().ok()) { - return Status(absl::StatusCode::kInternal, - absl::StrCat("Cannot dlopen all CUDA libraries.")); + return absl::Status(absl::StatusCode::kInternal, + absl::StrCat("Cannot dlopen all CUDA libraries.")); } else { - return tsl::OkStatus(); + return absl::OkStatus(); } } -Status TryDlopenROCmLibraries() { +absl::Status TryDlopenROCmLibraries() { auto rocblas_status = GetRocblasDsoHandle(); auto miopen_status = GetMiopenDsoHandle(); auto rocfft_status = GetHipfftDsoHandle(); @@ -57,32 +56,30 @@ Status TryDlopenROCmLibraries() { || !hipblaslt_status.status().ok() #endif ) { - return Status(absl::StatusCode::kInternal, - absl::StrCat("Cannot dlopen all ROCm libraries.")); + return absl::InternalError("Cannot dlopen all ROCm libraries."); } else { - return tsl::OkStatus(); + return absl::OkStatus(); } } -Status MaybeTryDlopenGPULibraries() { +absl::Status MaybeTryDlopenGPULibraries() { #if GOOGLE_CUDA return TryDlopenCUDALibraries(); #elif TENSORFLOW_USE_ROCM return TryDlopenROCmLibraries(); #else LOG(INFO) << "Not built with GPU enabled. Skip GPU library dlopen check."; - return tsl::OkStatus(); + return absl::OkStatus(); #endif } -Status TryDlopenTensorRTLibraries() { +absl::Status TryDlopenTensorRTLibraries() { auto nvinfer_status = GetNvInferDsoHandle(); auto nvinferplugin_status = GetNvInferPluginDsoHandle(); if (!nvinfer_status.status().ok() || !nvinferplugin_status.status().ok()) { - return Status(absl::StatusCode::kInternal, - absl::StrCat("Cannot dlopen all TensorRT libraries.")); + return absl::InternalError("Cannot dlopen all TensorRT libraries."); } else { - return tsl::OkStatus(); + return absl::OkStatus(); } } diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker_stub.cc b/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker_stub.cc index 1d4b213427b5a0..67f734302835d8 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker_stub.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/dlopen_checker_stub.cc @@ -12,18 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/status/status.h" #include "tsl/platform/default/dso_loader.h" #include "tsl/platform/logging.h" -#include "tsl/platform/status.h" namespace tsl { namespace internal { namespace DsoLoader { // Skip check when GPU libraries are statically linked. 
-Status MaybeTryDlopenGPULibraries() { +absl::Status MaybeTryDlopenGPULibraries() { LOG(INFO) << "GPU libraries are statically linked, skip dlopen check."; - return ::tsl::OkStatus(); + return absl::OkStatus(); } } // namespace DsoLoader } // namespace internal diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.cc b/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.cc index fd28f05590683c..a835a81489367a 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.cc @@ -16,17 +16,18 @@ limitations under the License. #include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "third_party/gpus/cuda/cuda_config.h" #include "third_party/nccl/nccl_config.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" +#include "tsl/platform/load_library.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/platform.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" #include "third_party/tensorrt/tensorrt_config.h" #if TENSORFLOW_USE_ROCM @@ -37,22 +38,23 @@ namespace tsl { namespace internal { namespace { -string GetCudaVersion() { return TF_CUDA_VERSION; } -string GetCudaRtVersion() { return TF_CUDART_VERSION; } -string GetCuptiVersion() { return TF_CUPTI_VERSION; } -string GetCudnnVersion() { return TF_CUDNN_VERSION; } -string GetCublasVersion() { return TF_CUBLAS_VERSION; } -string GetCusolverVersion() { return TF_CUSOLVER_VERSION; } -string GetCufftVersion() { return TF_CUFFT_VERSION; } -string GetCusparseVersion() { return TF_CUSPARSE_VERSION; } -string GetNcclVersion() { return TF_NCCL_VERSION; } -string GetTensorRTVersion() { return TF_TENSORRT_VERSION; } - -StatusOr GetDsoHandle(const string& name, const string& version) { - auto filename = Env::Default()->FormatLibraryFileName(name, version); +std::string GetCudaVersion() { return TF_CUDA_VERSION; } +std::string GetCudaRtVersion() { return TF_CUDART_VERSION; } +std::string GetCuptiVersion() { return TF_CUPTI_VERSION; } +std::string GetCudnnVersion() { return TF_CUDNN_VERSION; } +std::string GetCublasVersion() { return TF_CUBLAS_VERSION; } +std::string GetCusolverVersion() { return TF_CUSOLVER_VERSION; } +std::string GetCufftVersion() { return TF_CUFFT_VERSION; } +std::string GetCusparseVersion() { return TF_CUSPARSE_VERSION; } +std::string GetNcclVersion() { return TF_NCCL_VERSION; } +std::string GetTensorRTVersion() { return TF_TENSORRT_VERSION; } + +absl::StatusOr GetDsoHandle(const std::string& name, + const std::string& version) { + auto filename = tsl::internal::FormatLibraryFileName(name, version); void* dso_handle; - Status status = - Env::Default()->LoadDynamicLibrary(filename.c_str(), &dso_handle); + absl::Status status = + tsl::internal::LoadDynamicLibrary(filename.c_str(), &dso_handle); if (status.ok()) { VLOG(1) << "Successfully opened dynamic library " << filename; return dso_handle; @@ -60,41 +62,56 @@ StatusOr GetDsoHandle(const string& name, const string& version) { auto message = absl::StrCat("Could not load dynamic library '", filename, "'; dlerror: ", status.message()); +#if !defined(PLATFORM_WINDOWS) + if (const char* ld_library_path = getenv("LD_LIBRARY_PATH")) { + message += absl::StrCat("; LD_LIBRARY_PATH: ", ld_library_path); + } +#endif VLOG(1) << message; - return Status(absl::StatusCode::kFailedPrecondition, 
message); + return absl::Status(absl::StatusCode::kFailedPrecondition, message); } } // namespace namespace DsoLoader { -StatusOr GetCudaDriverDsoHandle() { +absl::StatusOr GetCudaDriverDsoHandle() { +#if defined(PLATFORM_WINDOWS) + return GetDsoHandle("nvcuda", ""); +#elif defined(__APPLE__) + // On Mac OS X, CUDA sometimes installs libcuda.dylib instead of + // libcuda.1.dylib. + auto handle_or = GetDsoHandle("cuda", ""); + if (handle_or.ok()) { + return handle_or; + } +#endif return GetDsoHandle("cuda", "1"); } -StatusOr GetCudaRuntimeDsoHandle() { +absl::StatusOr GetCudaRuntimeDsoHandle() { return GetDsoHandle("cudart", GetCudaRtVersion()); } -StatusOr GetCublasDsoHandle() { +absl::StatusOr GetCublasDsoHandle() { return GetDsoHandle("cublas", GetCublasVersion()); } -StatusOr GetCublasLtDsoHandle() { +absl::StatusOr GetCublasLtDsoHandle() { return GetDsoHandle("cublasLt", GetCublasVersion()); } -StatusOr GetCufftDsoHandle() { +absl::StatusOr GetCufftDsoHandle() { return GetDsoHandle("cufft", GetCufftVersion()); } -StatusOr GetCusolverDsoHandle() { +absl::StatusOr GetCusolverDsoHandle() { return GetDsoHandle("cusolver", GetCusolverVersion()); } -StatusOr GetCusparseDsoHandle() { +absl::StatusOr GetCusparseDsoHandle() { return GetDsoHandle("cusparse", GetCusparseVersion()); } -StatusOr GetCuptiDsoHandle() { +absl::StatusOr GetCuptiDsoHandle() { // Load specific version of CUPTI this is built. auto status_or_handle = GetDsoHandle("cupti", GetCuptiVersion()); if (status_or_handle.ok()) return status_or_handle; @@ -102,150 +119,166 @@ StatusOr GetCuptiDsoHandle() { return GetDsoHandle("cupti", ""); } -StatusOr GetCudnnDsoHandle() { +absl::StatusOr GetCudnnDsoHandle() { return GetDsoHandle("cudnn", GetCudnnVersion()); } -StatusOr GetNcclDsoHandle() { +absl::StatusOr GetNcclDsoHandle() { return GetDsoHandle("nccl", GetNcclVersion()); } -StatusOr GetNvInferDsoHandle() { +absl::StatusOr GetNvInferDsoHandle() { +#if defined(PLATFORM_WINDOWS) + return GetDsoHandle("nvinfer", ""); +#else return GetDsoHandle("nvinfer", GetTensorRTVersion()); +#endif } -StatusOr GetNvInferPluginDsoHandle() { +absl::StatusOr GetNvInferPluginDsoHandle() { +#if defined(PLATFORM_WINDOWS) + return GetDsoHandle("nvinfer_plugin", ""); +#else return GetDsoHandle("nvinfer_plugin", GetTensorRTVersion()); +#endif } -StatusOr GetRocblasDsoHandle() { return GetDsoHandle("rocblas", ""); } +absl::StatusOr GetRocblasDsoHandle() { + return GetDsoHandle("rocblas", ""); +} -StatusOr GetMiopenDsoHandle() { return GetDsoHandle("MIOpen", ""); } +absl::StatusOr GetMiopenDsoHandle() { + return GetDsoHandle("MIOpen", ""); +} -StatusOr GetHipfftDsoHandle() { return GetDsoHandle("hipfft", ""); } +absl::StatusOr GetHipfftDsoHandle() { + return GetDsoHandle("hipfft", ""); +} -StatusOr GetRocrandDsoHandle() { return GetDsoHandle("rocrand", ""); } +absl::StatusOr GetRocrandDsoHandle() { + return GetDsoHandle("rocrand", ""); +} -StatusOr GetRocsolverDsoHandle() { +absl::StatusOr GetRocsolverDsoHandle() { return GetDsoHandle("rocsolver", ""); } #if TF_ROCM_VERSION >= 40500 -StatusOr GetHipsolverDsoHandle() { +absl::StatusOr GetHipsolverDsoHandle() { return GetDsoHandle("hipsolver", ""); } #endif -StatusOr GetRoctracerDsoHandle() { +absl::StatusOr GetRoctracerDsoHandle() { return GetDsoHandle("roctracer64", ""); } -StatusOr GetHipsparseDsoHandle() { +absl::StatusOr GetHipsparseDsoHandle() { return GetDsoHandle("hipsparse", ""); } -StatusOr GetHipblasltDsoHandle() { +absl::StatusOr GetHipblasltDsoHandle() { return GetDsoHandle("hipblaslt", ""); } 
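For reference, a call site for the migrated loaders might look like the following; the choice of library is arbitrary, and only the absl::StatusOr<void*> shape comes from this change:

// Sketch only: consuming the cached loader after the Status -> absl::Status move.
absl::StatusOr<void*> cudart_or =
    tsl::internal::CachedDsoLoader::GetCudaRuntimeDsoHandle();
if (!cudart_or.ok()) {
  LOG(WARNING) << "CUDA runtime not available: " << cudart_or.status();
} else {
  void* cudart_handle = *cudart_or;  // dlopen handle, cached for later calls.
  (void)cudart_handle;
}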
-StatusOr GetHipDsoHandle() { return GetDsoHandle("amdhip64", ""); } +absl::StatusOr GetHipDsoHandle() { return GetDsoHandle("amdhip64", ""); } } // namespace DsoLoader namespace CachedDsoLoader { -StatusOr GetCudaDriverDsoHandle() { +absl::StatusOr GetCudaDriverDsoHandle() { static auto result = new auto(DsoLoader::GetCudaDriverDsoHandle()); return *result; } -StatusOr GetCudaRuntimeDsoHandle() { +absl::StatusOr GetCudaRuntimeDsoHandle() { static auto result = new auto(DsoLoader::GetCudaRuntimeDsoHandle()); return *result; } -StatusOr GetCublasDsoHandle() { +absl::StatusOr GetCublasDsoHandle() { static auto result = new auto(DsoLoader::GetCublasDsoHandle()); return *result; } -StatusOr GetCublasLtDsoHandle() { +absl::StatusOr GetCublasLtDsoHandle() { static auto result = new auto(DsoLoader::GetCublasLtDsoHandle()); return *result; } -StatusOr GetCufftDsoHandle() { +absl::StatusOr GetCufftDsoHandle() { static auto result = new auto(DsoLoader::GetCufftDsoHandle()); return *result; } -StatusOr GetCusolverDsoHandle() { +absl::StatusOr GetCusolverDsoHandle() { static auto result = new auto(DsoLoader::GetCusolverDsoHandle()); return *result; } -StatusOr GetCusparseDsoHandle() { +absl::StatusOr GetCusparseDsoHandle() { static auto result = new auto(DsoLoader::GetCusparseDsoHandle()); return *result; } -StatusOr GetCuptiDsoHandle() { +absl::StatusOr GetCuptiDsoHandle() { static auto result = new auto(DsoLoader::GetCuptiDsoHandle()); return *result; } -StatusOr GetCudnnDsoHandle() { +absl::StatusOr GetCudnnDsoHandle() { static auto result = new auto(DsoLoader::GetCudnnDsoHandle()); return *result; } -StatusOr GetRocblasDsoHandle() { +absl::StatusOr GetRocblasDsoHandle() { static auto result = new auto(DsoLoader::GetRocblasDsoHandle()); return *result; } -StatusOr GetMiopenDsoHandle() { +absl::StatusOr GetMiopenDsoHandle() { static auto result = new auto(DsoLoader::GetMiopenDsoHandle()); return *result; } -StatusOr GetHipfftDsoHandle() { +absl::StatusOr GetHipfftDsoHandle() { static auto result = new auto(DsoLoader::GetHipfftDsoHandle()); return *result; } -StatusOr GetRocrandDsoHandle() { +absl::StatusOr GetRocrandDsoHandle() { static auto result = new auto(DsoLoader::GetRocrandDsoHandle()); return *result; } -StatusOr GetRoctracerDsoHandle() { +absl::StatusOr GetRoctracerDsoHandle() { static auto result = new auto(DsoLoader::GetRoctracerDsoHandle()); return *result; } -StatusOr GetRocsolverDsoHandle() { +absl::StatusOr GetRocsolverDsoHandle() { static auto result = new auto(DsoLoader::GetRocsolverDsoHandle()); return *result; } #if TF_ROCM_VERSION >= 40500 -StatusOr GetHipsolverDsoHandle() { +absl::StatusOr GetHipsolverDsoHandle() { static auto result = new auto(DsoLoader::GetHipsolverDsoHandle()); return *result; } #endif -StatusOr GetHipsparseDsoHandle() { +absl::StatusOr GetHipsparseDsoHandle() { static auto result = new auto(DsoLoader::GetHipsparseDsoHandle()); return *result; } -StatusOr GetHipblasltDsoHandle() { +absl::StatusOr GetHipblasltDsoHandle() { static auto result = new auto(DsoLoader::GetHipblasltDsoHandle()); return *result; } -StatusOr GetHipDsoHandle() { +absl::StatusOr GetHipDsoHandle() { static auto result = new auto(DsoLoader::GetHipDsoHandle()); return *result; } diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.h b/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.h index ee5b2b28af3486..6f72484d504f53 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.h +++ 
b/third_party/xla/third_party/tsl/tsl/platform/default/dso_loader.h @@ -19,8 +19,8 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_DEFAULT_DSO_LOADER_H_ #define TENSORFLOW_TSL_PLATFORM_DEFAULT_DSO_LOADER_H_ -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" namespace tsl { namespace internal { @@ -28,65 +28,65 @@ namespace internal { namespace DsoLoader { // The following methods either load the DSO of interest and return a dlopen // handle or error status. -StatusOr GetCudaDriverDsoHandle(); -StatusOr GetCudaRuntimeDsoHandle(); -StatusOr GetCublasDsoHandle(); -StatusOr GetCublasLtDsoHandle(); -StatusOr GetCufftDsoHandle(); -StatusOr GetCusolverDsoHandle(); -StatusOr GetCusparseDsoHandle(); -StatusOr GetCuptiDsoHandle(); -StatusOr GetCudnnDsoHandle(); -StatusOr GetNcclDsoHandle(); -StatusOr GetNvInferDsoHandle(); -StatusOr GetNvInferPluginDsoHandle(); +absl::StatusOr GetCudaDriverDsoHandle(); +absl::StatusOr GetCudaRuntimeDsoHandle(); +absl::StatusOr GetCublasDsoHandle(); +absl::StatusOr GetCublasLtDsoHandle(); +absl::StatusOr GetCufftDsoHandle(); +absl::StatusOr GetCusolverDsoHandle(); +absl::StatusOr GetCusparseDsoHandle(); +absl::StatusOr GetCuptiDsoHandle(); +absl::StatusOr GetCudnnDsoHandle(); +absl::StatusOr GetNcclDsoHandle(); +absl::StatusOr GetNvInferDsoHandle(); +absl::StatusOr GetNvInferPluginDsoHandle(); -StatusOr GetRocblasDsoHandle(); -StatusOr GetMiopenDsoHandle(); -StatusOr GetHipfftDsoHandle(); -StatusOr GetRocrandDsoHandle(); -StatusOr GetRoctracerDsoHandle(); -StatusOr GetRocsolverDsoHandle(); -StatusOr GetHipsolverDsoHandle(); -StatusOr GetHipsparseDsoHandle(); -StatusOr GetHipDsoHandle(); +absl::StatusOr GetRocblasDsoHandle(); +absl::StatusOr GetMiopenDsoHandle(); +absl::StatusOr GetHipfftDsoHandle(); +absl::StatusOr GetRocrandDsoHandle(); +absl::StatusOr GetRoctracerDsoHandle(); +absl::StatusOr GetRocsolverDsoHandle(); +absl::StatusOr GetHipsolverDsoHandle(); +absl::StatusOr GetHipsparseDsoHandle(); +absl::StatusOr GetHipDsoHandle(); // The following method tries to dlopen all necessary GPU libraries for the GPU // platform TF is built with (CUDA or ROCm) only when these libraries should be // dynamically loaded. Error status is returned when any of the libraries cannot // be dlopened. -Status MaybeTryDlopenGPULibraries(); +absl::Status MaybeTryDlopenGPULibraries(); // The following method tries to dlopen all necessary TensorRT libraries when // these libraries should be dynamically loaded. Error status is returned when // any of the libraries cannot be dlopened. -Status TryDlopenTensorRTLibraries(); +absl::Status TryDlopenTensorRTLibraries(); } // namespace DsoLoader // Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs // more than once. namespace CachedDsoLoader { // Cached versions of the corresponding DsoLoader methods above. 
-StatusOr GetCudaDriverDsoHandle(); -StatusOr GetCudaRuntimeDsoHandle(); -StatusOr GetCublasDsoHandle(); -StatusOr GetCublasLtDsoHandle(); -StatusOr GetCufftDsoHandle(); -StatusOr GetCusolverDsoHandle(); -StatusOr GetCusparseDsoHandle(); -StatusOr GetCuptiDsoHandle(); -StatusOr GetCudnnDsoHandle(); +absl::StatusOr GetCudaDriverDsoHandle(); +absl::StatusOr GetCudaRuntimeDsoHandle(); +absl::StatusOr GetCublasDsoHandle(); +absl::StatusOr GetCublasLtDsoHandle(); +absl::StatusOr GetCufftDsoHandle(); +absl::StatusOr GetCusolverDsoHandle(); +absl::StatusOr GetCusparseDsoHandle(); +absl::StatusOr GetCuptiDsoHandle(); +absl::StatusOr GetCudnnDsoHandle(); -StatusOr GetRocblasDsoHandle(); -StatusOr GetMiopenDsoHandle(); -StatusOr GetHipfftDsoHandle(); -StatusOr GetRocrandDsoHandle(); -StatusOr GetRocsolverDsoHandle(); -StatusOr GetHipsolverDsoHandle(); -StatusOr GetRoctracerDsoHandle(); -StatusOr GetHipsparseDsoHandle(); -StatusOr GetHipblasltDsoHandle(); -StatusOr GetHipDsoHandle(); +absl::StatusOr GetRocblasDsoHandle(); +absl::StatusOr GetMiopenDsoHandle(); +absl::StatusOr GetHipfftDsoHandle(); +absl::StatusOr GetRocrandDsoHandle(); +absl::StatusOr GetRocsolverDsoHandle(); +absl::StatusOr GetHipsolverDsoHandle(); +absl::StatusOr GetRoctracerDsoHandle(); +absl::StatusOr GetHipsparseDsoHandle(); +absl::StatusOr GetHipblasltDsoHandle(); +absl::StatusOr GetHipDsoHandle(); } // namespace CachedDsoLoader } // namespace internal diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/load_library.cc b/third_party/xla/third_party/tsl/tsl/platform/default/load_library.cc index f49adf2f7f257d..70961c8dc990ef 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/load_library.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/load_library.cc @@ -17,26 +17,26 @@ limitations under the License. #include -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" +#include + +#include "absl/status/status.h" namespace tsl { namespace internal { -Status LoadDynamicLibrary(const char* library_filename, void** handle) { +absl::Status LoadDynamicLibrary(const char* library_filename, void** handle) { *handle = dlopen(library_filename, RTLD_NOW | RTLD_LOCAL); if (!*handle) { // Note that in C++17 std::string_view(nullptr) gives segfault! const char* error_msg = dlerror(); - return tsl::errors::NotFound(error_msg ? error_msg - : "(null error message)"); + return absl::NotFoundError(error_msg ? error_msg : "(null error message)"); } - return OkStatus(); + return absl::OkStatus(); } -Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol) { +absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) { // Check that the handle is not NULL to avoid dlsym's RTLD_DEFAULT behavior. if (!handle) { *symbol = nullptr; @@ -46,14 +46,14 @@ Status GetSymbolFromLibrary(void* handle, const char* symbol_name, if (!*symbol) { // Note that in C++17 std::string_view(nullptr) gives segfault! const char* error_msg = dlerror(); - return tsl::errors::NotFound(error_msg ? error_msg - : "(null error message)"); + return absl::NotFoundError(error_msg ? 
error_msg : "(null error message)"); } - return OkStatus(); + return absl::OkStatus(); } -string FormatLibraryFileName(const string& name, const string& version) { - string filename; +std::string FormatLibraryFileName(const std::string& name, + const std::string& version) { + std::string filename; #if defined(__APPLE__) if (version.size() == 0) { filename = "lib" + name + ".dylib"; diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/port.cc b/third_party/xla/third_party/tsl/tsl/platform/default/port.cc index c2151c78ec5330..868fb35f887dab 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/port.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/port.cc @@ -15,6 +15,7 @@ limitations under the License. #include "absl/base/internal/sysinfo.h" #include "tsl/platform/cpu_info.h" +#include "tsl/platform/host_info.h" #include "tsl/platform/logging.h" #include "tsl/platform/mem.h" #include "tsl/platform/numa.h" @@ -256,7 +257,6 @@ int NUMAGetThreadNodeAffinity() { return node_index; } - void* NUMAMalloc(int node, size_t size, int minimum_alignment) { #ifdef TENSORFLOW_USE_NUMA if (HaveHWLocTopology()) { @@ -307,7 +307,6 @@ int NUMAGetMemAffinity(const void* addr) { return node; } - bool Snappy_Compress(const char* input, size_t length, string* output) { #ifdef TF_USE_SNAPPY output->resize(snappy::MaxCompressedLength(length)); @@ -447,5 +446,8 @@ MemoryBandwidthInfo GetMemoryBandwidthInfo() { MemoryBandwidthInfo membw_info = {INT64_MAX}; return membw_info; } + +IOStatistics GetIOStatistics() { return IOStatistics(); } + } // namespace port } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/subprocess.cc b/third_party/xla/third_party/tsl/tsl/platform/default/subprocess.cc index d750328ebf38fd..c786295c08e0e9 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/subprocess.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/subprocess.cc @@ -30,7 +30,11 @@ limitations under the License. #include "tsl/platform/logging.h" // Android versions older than 28 do not have posix_spawn(). -#define USE_POSIX_SPAWN !defined(__ANDROID_API__) || __ANDROID_API__ >= 28 +#if !defined(__ANDROID_API__) || __ANDROID_API__ >= 28 +#define USE_POSIX_SPAWN 1 +#else // defined(__ANDROID_API__) && __ANDROID_API__ < 28 +#define USE_POSIX_SPAWN 0 +#endif // !defined(__ANDROID_API__) || __ANDROID_API__ >= 28 // 1) FYI from m3b@ about fork(): // A danger of calling fork() (as opposed to clone() or vfork()) is that if diff --git a/third_party/xla/third_party/tsl/tsl/platform/denormal.cc b/third_party/xla/third_party/tsl/tsl/platform/denormal.cc index 4f071109c32abd..9d65ddc68fda85 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/denormal.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/denormal.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tsl/platform/denormal.h" +#include + #include "tsl/platform/cpu_info.h" #include "tsl/platform/platform.h" diff --git a/third_party/xla/third_party/tsl/tsl/platform/file_system.h b/third_party/xla/third_party/tsl/tsl/platform/file_system.h index 76fab57f4b64b5..8f7bd875e35bc3 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/file_system.h +++ b/third_party/xla/third_party/tsl/tsl/platform/file_system.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -40,6 +41,7 @@ limitations under the License. 
namespace tsl { +class FileAcl; class RandomAccessFile; class ReadOnlyMemoryRegion; class WritableFile; @@ -531,6 +533,13 @@ class FileSystem { return errors::Unimplemented("SetOption"); } + /// \brief Set File System ACL checker. + /// + /// No checks are enforced if a FileAcl is never set. + virtual tsl::Status SetFileAcl(std::shared_ptr file_acl) { + return errors::Unimplemented("SetFileAcl"); + } + FileSystem() {} virtual ~FileSystem() = default; @@ -902,6 +911,13 @@ class FileSystemRegistry { std::vector* schemes) = 0; }; +/// \brief An abstraction for enforcing ACL checks in FileSystem. +class FileAcl { + public: + virtual absl::Status CheckAccess(std::string_view path) = 0; + virtual ~FileAcl() = default; +}; + } // namespace tsl #endif // TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/host_info.h b/third_party/xla/third_party/tsl/tsl/platform/host_info.h index 189f3be2934ce3..630f9424525e04 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/host_info.h +++ b/third_party/xla/third_party/tsl/tsl/platform/host_info.h @@ -16,11 +16,26 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_HOST_INFO_H_ #define TENSORFLOW_TSL_PLATFORM_HOST_INFO_H_ +#include + #include "tsl/platform/types.h" namespace tsl { namespace port { +// Statistical data of IO operations performed by the job. +struct IOStatistics { + struct Distribution { + uint64_t count = 0; + double mean = 0.0; + double std_dev = 0.0; + }; + // Distribution of round trip IO latency in microseconds. + Distribution roundtrip_latency_usec; + // Distribution of data received by IO reads in bytes. + Distribution response_bytes; +}; + // Return the hostname of the machine on which this process is running. string Hostname(); @@ -34,6 +49,9 @@ int64_t JobUid(); // Returns the Borg task ID as an int64_t if it exists. Otherwise return -1. int64_t TaskId(); +// Retrieves the host file read statistics. +IOStatistics GetIOStatistics(); + } // namespace port } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/load_library.h b/third_party/xla/third_party/tsl/tsl/platform/load_library.h index e46f85da0a7f9a..5a42f2a3439fd0 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/load_library.h +++ b/third_party/xla/third_party/tsl/tsl/platform/load_library.h @@ -16,16 +16,19 @@ limitations under the License. 
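To show the intended shape of the new FileAcl hook, here is a hypothetical checker; the class name and policy are invented for illustration and only the CheckAccess/SetFileAcl contract comes from this change:

// Sketch only: allow access under a single prefix, deny everything else.
class PrefixOnlyFileAcl : public tsl::FileAcl {
 public:
  explicit PrefixOnlyFileAcl(std::string prefix) : prefix_(std::move(prefix)) {}
  absl::Status CheckAccess(std::string_view path) override {
    if (absl::StartsWith(path, prefix_)) return absl::OkStatus();
    return absl::PermissionDeniedError("path outside allowed prefix");
  }

 private:
  std::string prefix_;
};
// file_system->SetFileAcl(std::make_shared<PrefixOnlyFileAcl>("/allowed/"));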
#ifndef TENSORFLOW_TSL_PLATFORM_LOAD_LIBRARY_H_ #define TENSORFLOW_TSL_PLATFORM_LOAD_LIBRARY_H_ -#include "tsl/platform/status.h" +#include + +#include "absl/status/status.h" namespace tsl { namespace internal { -Status LoadDynamicLibrary(const char* library_filename, void** handle); -Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol); -string FormatLibraryFileName(const string& name, const string& version); +absl::Status LoadDynamicLibrary(const char* library_filename, void** handle); +absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol); +std::string FormatLibraryFileName(const std::string& name, + const std::string& version); } // namespace internal diff --git a/third_party/xla/third_party/tsl/tsl/platform/profile_utils/BUILD b/third_party/xla/third_party/tsl/tsl/platform/profile_utils/BUILD index 1c04a4d2d4a1a7..1c5558dbb9faeb 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/profile_utils/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/profile_utils/BUILD @@ -44,10 +44,10 @@ cc_library( srcs = [ "android_armv7a_cpu_utils_helper.h", "cpu_utils.cc", - "i_cpu_utils_helper.h", ], hdrs = [ "cpu_utils.h", + "i_cpu_utils_helper.h", ], copts = tsl_copts(), visibility = ["//visibility:public"], diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h b/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h index bddf2529771f1e..ee2144dca8a698 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h +++ b/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h @@ -98,19 +98,12 @@ limitations under the License. // Status status = OkStatus(); // EXPECT_THAT(status, IsOk()); -namespace tensorflow { -namespace error { -// TODO(ddunleavy) Move this to TSL. This stays here until error_codes proto -// is moved to TSL due to an ADL issue +namespace tsl { + inline void PrintTo(const tsl::error::Code code, std::ostream* os) { *os << Code_Name(code); } -} // namespace error -} // namespace tensorflow - -namespace tsl { - template void PrintTo(const StatusOr& status_or, std::ostream* os) { *os << ::testing::PrintToString(status_or.status()); diff --git a/third_party/xla/third_party/tsl/tsl/platform/tensor_float_32_utils.h b/third_party/xla/third_party/tsl/tsl/platform/tensor_float_32_utils.h index 5d1db659c9f43c..d956340c303309 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/tensor_float_32_utils.h +++ b/third_party/xla/third_party/tsl/tsl/platform/tensor_float_32_utils.h @@ -18,6 +18,8 @@ limitations under the License. namespace tsl { +// NOTE: The usage of this function is only supported through the Tensorflow +// framework. void enable_tensor_float_32_execution(bool enabled); bool tensor_float_32_execution_enabled(); diff --git a/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD b/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD index 7ff0f110fe6722..2bde9eb95b3b73 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD @@ -1,10 +1,9 @@ -load("//tsl:tsl.default.bzl", "filegroup") - # Tensorflow windows-specific implementations of tensorflow/core/platform libraries. 
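A hedged usage sketch for the absl::Status-based load_library primitives declared above; the library and symbol names are placeholders:

// Sketch only: formats to e.g. "libcudart.so.12" on Linux; other platforms
// apply their own library naming convention.
std::string filename = tsl::internal::FormatLibraryFileName("cudart", "12");
void* handle = nullptr;
absl::Status status = tsl::internal::LoadDynamicLibrary(filename.c_str(), &handle);
void* symbol = nullptr;
if (status.ok()) {
  status = tsl::internal::GetSymbolFromLibrary(handle, "cudaGetDeviceCount", &symbol);
}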
load( "//tsl:tsl.bzl", "tsl_copts", ) +load("//tsl:tsl.default.bzl", "filegroup") load( "//tsl/platform:rules_cc.bzl", "cc_library", @@ -144,7 +143,7 @@ cc_library( deps = [ ":wide_char", "//tsl/platform:errors", - "//tsl/platform:status", + "@com_google_absl//absl/status", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/windows/load_library.cc b/third_party/xla/third_party/tsl/tsl/platform/windows/load_library.cc index 0c47532dc687a7..66d2d62cf6e130 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/windows/load_library.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/windows/load_library.cc @@ -28,7 +28,7 @@ limitations under the License. #include #include -#include "tsl/platform/errors.h" +#include "absl/status/status.h" #include "tsl/platform/windows/wide_char.h" #pragma comment(lib, "Shlwapi.lib") @@ -37,8 +37,8 @@ namespace tsl { namespace internal { -Status LoadDynamicLibrary(const char* library_filename, void** handle) { - string file_name = library_filename; +absl::Status LoadDynamicLibrary(const char* library_filename, void** handle) { + std::string file_name = library_filename; std::replace(file_name.begin(), file_name.end(), '/', '\\'); std::wstring ws_file_name(tsl::Utf8ToWideChar(file_name)); @@ -46,26 +46,27 @@ Status LoadDynamicLibrary(const char* library_filename, void** handle) { HMODULE hModule = LoadLibraryExW(ws_file_name.c_str(), NULL, LOAD_WITH_ALTERED_SEARCH_PATH); if (!hModule) { - return tsl::errors::NotFound(file_name + " not found"); + return absl::NotFoundError(file_name + " not found"); } *handle = hModule; - return OkStatus(); + return absl::OkStatus(); } -Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol) { +absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) { FARPROC found_symbol; found_symbol = GetProcAddress((HMODULE)handle, symbol_name); if (found_symbol == NULL) { - return tsl::errors::NotFound(std::string(symbol_name) + " not found"); + return absl::NotFoundError(std::string(symbol_name) + " not found"); } *symbol = (void**)found_symbol; - return OkStatus(); + return absl::OkStatus(); } -string FormatLibraryFileName(const string& name, const string& version) { - string filename; +std::string FormatLibraryFileName(const std::string& name, + const std::string& version) { + std::string filename; if (version.size() == 0) { filename = name + ".dll"; } else { diff --git a/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc b/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc index 9b5692650dbb5c..f8e19503edb305 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc @@ -61,6 +61,8 @@ int64_t JobUid() { return -1; } int64_t TaskId() { return -1; } +IOStatistics GetIOStatistics() { return IOStatistics(); } + int NumSchedulableCPUs() { SYSTEM_INFO system_info; GetSystemInfo(&system_info); @@ -122,7 +124,6 @@ void NUMAFree(void* ptr, size_t size) { tsl::port::Free(ptr); } int NUMAGetMemAffinity(const void* addr) { return kNUMANoAffinity; } - bool Snappy_Compress(const char* input, size_t length, string* output) { #ifdef TF_USE_SNAPPY output->resize(snappy::MaxCompressedLength(length)); @@ -183,7 +184,7 @@ string Demangle(const char* mangled) { return mangled; } double NominalCPUFrequency() { DWORD data; DWORD data_size = sizeof(data); - #pragma comment(lib, "shlwapi.lib") // For SHGetValue(). +#pragma comment(lib, "shlwapi.lib") // For SHGetValue(). 
if (SUCCEEDED( SHGetValueA(HKEY_LOCAL_MACHINE, "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD index 70fe322adcda52..c23d63f5f4eddd 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD @@ -3,6 +3,10 @@ load("//tsl/platform:build_config_root.bzl", "if_static") load("//tsl:tsl.default.bzl", "filegroup") load("//tsl:tsl.bzl", "if_not_android", "set_external_visibility") load("//tsl/platform:build_config.bzl", "tsl_cc_test") +load( + "//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) load( "//tsl/profiler/builds:build_config.bzl", "tf_profiler_copts", @@ -252,8 +256,8 @@ cc_library( "//tsl/platform:macros", "//tsl/platform:types", "@com_google_absl//absl/strings", - ] + if_not_android([ - "//tsl/profiler/backends/cpu:annotation_stack", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", # NVTX headers ]), ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h index 416d8293784551..e3eaaa08af79e8 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.h @@ -24,18 +24,17 @@ limitations under the License. #if GOOGLE_CUDA #include "nvtx3/nvToolsExt.h" +#else +// Some typedef to help build without NVTX. +typedef void* nvtxEventAttributes_t; +typedef void* nvtxDomainHandle_t; +typedef void* nvtxStringHandle_t; #endif namespace tsl { namespace profiler { namespace nvtx { -// Some typedef to help build without NVTX. -#if !GOOGLE_CUDA -typedef void* nvtxEventAttributes_t; -typedef void* nvtxDomainHandle_t; -#endif - // A helper function that return the domains to use if NVTX profiling // is enabled. inline std::optional GetNVTXDomain() { @@ -65,15 +64,38 @@ inline bool RangesEnabled() { #endif } -// Note: The memory backing msg must persist until the result of this function -// has been consumed by an NVTX API. -inline void MakeAttributes(const char* msg, nvtxEventAttributes_t* result) { - *result = {0}; +// Two types of NVTX range annotation are supported, the older/simpler option +// is to use std::string and have the NVTX implementation copy a C-style +// string every time. The other option is to pass a struct implementing two +// methods: +// +// std::string_view Title() const; +// nvtxStringHandle_t NvtxRegisteredTitle() const; +// +// in which case NvtxRegisteredTitle() will be used when starting NVTX ranges, +// avoiding this string copy. +// The Title() method is needed because AnnotationStack::PushAnnotation(...) is +// the backend for some annotations when NVTX is not enabled, and it does not +// recognise registered strings. has_annotation_api_v +// distinguishes between the two types of annotation. 
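The registered-string protocol described in the comment above can be satisfied by a small wrapper like the following; this is a sketch, not part of the change, and it assumes the NVTX3 registration call nvtxDomainRegisterStringA:

// Sketch only: an annotation type that RangePush() treats as "registered",
// so the title string is not copied on every push.
class RegisteredAnnotation {
 public:
  RegisteredAnnotation(nvtxDomainHandle_t domain, std::string title)
      : title_(std::move(title)) {
#if GOOGLE_CUDA
    handle_ = nvtxDomainRegisterStringA(domain, title_.c_str());
#endif
  }
  std::string_view Title() const { return title_; }
  nvtxStringHandle_t NvtxRegisteredTitle() const { return handle_; }

 private:
  std::string title_;
  nvtxStringHandle_t handle_{};  // Plain void* when building without NVTX.
};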
+template +inline constexpr bool has_annotation_api_v = + !std::is_same_v; + +template +void RangePush(nvtxDomainHandle_t domain, const AnnotationType& annotation) { #if GOOGLE_CUDA - result->version = NVTX_VERSION; - result->size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; - result->messageType = NVTX_MESSAGE_TYPE_ASCII; - result->message.ascii = msg; + nvtxEventAttributes_t attrs{}; + attrs.version = NVTX_VERSION; + attrs.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; + if constexpr (has_annotation_api_v>) { + attrs.messageType = NVTX_MESSAGE_TYPE_REGISTERED; + attrs.message.registered = annotation.NvtxRegisteredTitle(); + } else { + attrs.messageType = NVTX_MESSAGE_TYPE_ASCII; + attrs.message.ascii = annotation.c_str(); + } + ::nvtxDomainRangePushEx(domain, &attrs); #endif } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h index 643d7045428605..f047fafc4ebe3a 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h @@ -53,10 +53,7 @@ class ScopedAnnotationT { std::optional domain = tsl::profiler::nvtx::GetNVTXDomain(); if (TF_PREDICT_FALSE(domain.has_value())) { - nvtxEventAttributes_t attrs; - std::string name_str(name); - tsl::profiler::nvtx::MakeAttributes(name_str.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); + tsl::profiler::nvtx::RangePush(domain.value(), std::string{name}); } else // NOLINT #endif if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { @@ -74,9 +71,7 @@ class ScopedAnnotationT { std::optional domain = tsl::profiler::nvtx::GetNVTXDomain(); if (TF_PREDICT_FALSE(domain.has_value())) { - nvtxEventAttributes_t attrs; - tsl::profiler::nvtx::MakeAttributes(name.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); + tsl::profiler::nvtx::RangePush(domain.value(), name); } else // NOLINT #endif if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { @@ -91,9 +86,7 @@ class ScopedAnnotationT { std::optional domain = tsl::profiler::nvtx::GetNVTXDomain(); if (TF_PREDICT_FALSE(domain.has_value())) { - nvtxEventAttributes_t attrs; - tsl::profiler::nvtx::MakeAttributes(name.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); + tsl::profiler::nvtx::RangePush(domain.value(), name); } else // NOLINT #endif if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { @@ -109,15 +102,17 @@ class ScopedAnnotationT { std::optional domain = tsl::profiler::nvtx::GetNVTXDomain(); if (TF_PREDICT_FALSE(domain.has_value())) { - auto name = name_generator(); - nvtxEventAttributes_t attrs; - tsl::profiler::nvtx::MakeAttributes(name.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); + tsl::profiler::nvtx::RangePush(domain.value(), name_generator()); } else // NOLINT #endif if (always_annotate || TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - auto name = name_generator(); - old_length_ = AnnotationStack::PushAnnotation(name); + auto annotation = name_generator(); + if constexpr (tsl::profiler::nvtx::has_annotation_api_v< + std::decay_t>) { + old_length_ = AnnotationStack::PushAnnotation(annotation.Title()); + } else { + old_length_ = AnnotationStack::PushAnnotation(std::move(annotation)); + } } #endif } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h index 
f4e538f127c9bb..db46f7c99135e4 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_stack.h @@ -55,10 +55,7 @@ class ScopedAnnotationStack { std::optional domain = tsl::profiler::nvtx::GetNVTXDomain(); if (TF_PREDICT_FALSE(domain.has_value())) { - nvtxEventAttributes_t attrs; - std::string name_str(name); - tsl::profiler::nvtx::MakeAttributes(name_str.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); + tsl::profiler::nvtx::RangePush(domain.value(), name); } else // NOLINT #endif if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { @@ -83,15 +80,17 @@ class ScopedAnnotationStack { std::optional domain = tsl::profiler::nvtx::GetNVTXDomain(); if (TF_PREDICT_FALSE(domain.has_value())) { - auto name = name_generator(); - nvtxEventAttributes_t attrs; - std::string name_str(name); - tsl::profiler::nvtx::MakeAttributes(name_str.c_str(), &attrs); - ::nvtxDomainRangePushEx(domain.value(), &attrs); + tsl::profiler::nvtx::RangePush(domain.value(), name_generator()); } else // NOLINT #endif if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { - return AnnotationStack::PushAnnotation(name_generator()); + auto annotation = name_generator(); + if constexpr (tsl::profiler::nvtx::has_annotation_api_v< + std::decay_t>) { + return AnnotationStack::PushAnnotation(annotation.Title()); + } else { + return AnnotationStack::PushAnnotation(std::move(annotation)); + } } #endif return kInvalidActivity; diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.cc index 7dadfae46f7913..4129e2ae8fa7c7 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tsl/profiler/utils/tf_op_utils.h" +#include +#include #include #include @@ -61,6 +63,32 @@ absl::string_view DeriveOpType(absl::string_view full_op_name) { return op_type; } +// TODO(xprof-devs): Include the corresponding Ops on TPU. +std::optional GetMemcpyOp(absl::string_view tf_op_fullname) { + TfOp tf_op; + tf_op.name = tf_op_fullname; + if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) { + tf_op.category = Category::kMemcpyHToD; + tf_op.type = kMemcpyHToDOp; + return tf_op; + } + if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) { + tf_op.category = Category::kMemcpyDToH; + tf_op.type = kMemcpyDToHOp; + return tf_op; + } + if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToD")) { + tf_op.category = Category::kMemcpyDToD; + tf_op.type = kMemcpyDToDOp; + return tf_op; + } else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToH")) { + tf_op.category = Category::kMemcpyHToH; + tf_op.type = kMemcpyHToHOp; + return tf_op; + } + return std::nullopt; +} + } // namespace const absl::string_view kUnknownOp = ""; // op types are non-empty strings @@ -70,12 +98,14 @@ const absl::string_view kMemcpyDToHOp = "MemcpyDToH"; const absl::string_view kMemcpyDToDOp = "MemcpyDToD"; const absl::string_view kMemcpyHToHOp = "MemcpyHToH"; +// Example inputs: "MyOpName", "MyNamespace>MyOpName" bool IsTfOpName(absl::string_view op_name) { // TODO(b/177602927): Confirm the naming convention with the TF team. 
static const LazyRE2 kTfOpNameRegEx = {"[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*"}; return RE2::FullMatch(op_name, *kTfOpNameRegEx); } +// Example inputs: "MyType", "_MyInternalType" bool IsTfOpType(absl::string_view op_type) { static const LazyRE2 kTfOpTypeRegEx = {"[A-Z_][a-zA-Z0-9_]*"}; return RE2::FullMatch(op_type, *kTfOpTypeRegEx); @@ -97,52 +127,64 @@ bool IsJaxOpNameAndType(absl::string_view op_name, absl::string_view op_type) { } TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) { - // TF Op names have the format "name:type". + // For op types below, they all have the format ":", though + // op_type could be empty. TfOp tf_op = {Category::kUnknown, tf_op_fullname, kUnknownOp}; std::vector parts = absl::StrSplit(tf_op_fullname, absl::MaxSplits(':', 1)); + if (parts.size() != 2) { - // GPU-related Ops that need to be tracked. - if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) { - tf_op.category = Category::kMemcpyHToD; - tf_op.type = kMemcpyHToDOp; - } else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) { - tf_op.category = Category::kMemcpyDToH; - tf_op.type = kMemcpyDToHOp; - } else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToD")) { - tf_op.category = Category::kMemcpyDToD; - tf_op.type = kMemcpyDToDOp; - } else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToH")) { - tf_op.category = Category::kMemcpyHToH; - tf_op.type = kMemcpyHToHOp; + // Two possibilities here: GPU memcpy op or invalid op. + if (std::optional tfop = GetMemcpyOp(parts[0]); tfop.has_value()) { + return *tfop; } - // TODO(ckluk): Include the corresponding Ops on TPU. - } else if (parts[0] == kIterator) { + return tf_op; + } + + // Check for a Dataset op. + if (parts[0] == kIterator) { // Dataset Op names (e.g., Iterator::Batch::Map::TFRecord) do not follow the // format of TF Op names. But we still want to capture them for // input-pipeline analysis. tf_op.category = Category::kTfData; tf_op.type = kDatasetOp; - } else if (IsTfOpType(parts[1]) && IsTfOpName(parts[0])) { - tf_op = {Category::kTensorFlow, parts[0], parts[1]}; - } else { - absl::string_view op_type = - parts[1].empty() ? DeriveOpType(parts[0]) : parts[1]; - if (IsJaxOpType(op_type)) { - // JAX category introduces op_type with '[]' including unnecessary details - // to represent a group of ops. - // We need to striping the brackets and contents inside. Based on our - // analysis, all the op_type ends with a closing ']' if it contains - // brakets. It's safe to remove all the characters starting with the - // position of '['. - // Example: - // "transpose[permutation=(0, 3, 1, 2)]" => "transpose" - // See: go/xprof-jax-op-type - tf_op = {Category::kJax, parts[0], op_type.substr(0, op_type.find('['))}; - } else if (parts[1].empty()) { - tf_op = {Category::kTensorFlow, parts[0], op_type}; - } + return tf_op; + } + + // Check for Tensorflow Op. + if (IsTfOpName(parts[0]) && IsTfOpType(parts[1])) { + tf_op.category = Category::kTensorFlow; + tf_op.name = parts[0]; + tf_op.type = parts[1]; + return tf_op; + } + + // Check for JAX op. + absl::string_view op_type = + parts[1].empty() ? DeriveOpType(parts[0]) : parts[1]; + if (IsJaxOpType(op_type)) { + // JAX category introduces op_type with '[]' including unnecessary details + // to represent a group of ops. + // We need to striping the brackets and contents inside. Based on our + // analysis, all the op_type ends with a closing ']' if it contains + // brakets. It's safe to remove all the characters starting with the + // position of '['. 
+ // Example: + // "transpose[permutation=(0, 3, 1, 2)]" => "transpose" + // See: go/xprof-jax-op-type + tf_op.category = Category::kJax; + tf_op.name = parts[0]; + tf_op.type = op_type.substr(0, op_type.find('[')); + return tf_op; + } + + if (parts[1].empty()) { + tf_op.category = Category::kTensorFlow; + tf_op.name = parts[0]; + tf_op.type = op_type; + return tf_op; } + return tf_op; } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h index 90cee796fd95a7..6a7093b422c7d1 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h @@ -29,10 +29,15 @@ namespace profiler { // Support up to 500 accelerator devices. constexpr uint32 kFirstDeviceId = 1; constexpr uint32 kLastDeviceId = 500; -// Support Upto 200 custom planes. -constexpr uint32 kCustomPlaneDeviceId = kLastDeviceId + 1; +// Support Upto 200 custom planes as fake devices (i.e., planes with a +// "/custom:" prefix). See `::kCustomPlanePrefix` for more +// information +constexpr uint32 kFirstCustomPlaneDeviceId = kLastDeviceId + 1; +constexpr uint32 kMaxCustomPlaneDevicesPerHost = 200; +constexpr uint32 kLastCustomPlaneDeviceId = + kFirstCustomPlaneDeviceId + kMaxCustomPlaneDevicesPerHost - 1; // Host threads are shown as a single fake device. -constexpr uint32 kHostThreadsDeviceId = kCustomPlaneDeviceId + 200; +constexpr uint32 kHostThreadsDeviceId = kLastCustomPlaneDeviceId + 1; // Constants used as trace_viewer TID (resource_id in trace_events.proto). constexpr int kThreadIdDerivedMin = 0xdeadbeef; diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc index 2f7eb630aa324a..62b69f2910b334 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc @@ -272,6 +272,7 @@ const StatTypeMap& GetStatTypeMap() { {"model_version", kModelVersion}, {"bytes_transferred", kBytesTransferred}, {"queue", kDmaQueue}, + {"dcn_collective_info", kDcnCollectiveInfo}, // Performance counter related. {"Raw Value", kRawValue}, {"Scaled Value", kScaledValue}, diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h index 8fa320791f0ee5..7bbd052f815eb9 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h @@ -259,6 +259,7 @@ enum StatType { kModelVersion, kBytesTransferred, kDmaQueue, + kDcnCollectiveInfo, // Performance counter related. 
kRawValue, kScaledValue, diff --git a/third_party/xla/third_party/tsl/tsl/tsl.bzl b/third_party/xla/third_party/tsl/tsl/tsl.bzl index e1c2b364fbca78..caad0d8eeb5a20 100644 --- a/third_party/xla/third_party/tsl/tsl/tsl.bzl +++ b/third_party/xla/third_party/tsl/tsl/tsl.bzl @@ -37,6 +37,11 @@ load( "if_tensorrt", ) +# buildifier: disable=out-of-order-load +# Internally this loads a macro, but in OSS this is a function +def register_extension_info(**kwargs): + pass + two_gpu_tags = ["requires-gpu-nvidia:2", "notap", "manual", "no_pip"] def clean_dep(target): @@ -349,6 +354,8 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs **kwargs ) +register_extension_info(extension = tsl_gpu_library, label_regex_for_dep = "{extension_name}") + # Traverse the dependency graph along the "deps" attribute of the # target and return a struct with one field called 'tf_collected_deps'. # tf_collected_deps will be the union of the deps of the current target @@ -562,6 +569,7 @@ def tsl_pybind_extension_opensource( data = [], defines = [], deprecation = None, + enable_stub_generation = False, # @unused features = [], licenses = None, linkopts = [], @@ -754,9 +762,6 @@ def tsl_pybind_extension_opensource( compatible_with = compatible_with, ) -# Export open source version of pybind_extension under base name as well. -tsl_pybind_extension = tsl_pybind_extension_opensource - # Used for specifying external visibility constraints. In non-monorepo situations, this needs to be # public, but monorepos can have more precise constraints. def set_external_visibility(monorepo_paths): diff --git a/third_party/xla/third_party/tsl/tsl/tsl.default.bzl b/third_party/xla/third_party/tsl/tsl/tsl.default.bzl index c6bb4f3526e9b9..1759e5106320d5 100644 --- a/third_party/xla/third_party/tsl/tsl/tsl.default.bzl +++ b/third_party/xla/third_party/tsl/tsl/tsl.default.bzl @@ -7,7 +7,7 @@ load( _if_not_mobile_or_arm_or_lgpl_restricted = "if_not_mobile_or_arm_or_lgpl_restricted", _internal_hlo_deps = "internal_hlo_deps", _tsl_grpc_cc_dependencies = "tsl_grpc_cc_dependencies", - _tsl_pybind_extension = "tsl_pybind_extension", + _tsl_pybind_extension = "tsl_pybind_extension_opensource", ) get_compatible_with_portable = _get_compatible_with_portable diff --git a/third_party/xla/third_party/tsl/tsl/util/BUILD b/third_party/xla/third_party/tsl/tsl/util/BUILD index a913dbb77ac724..09b864ac264d20 100644 --- a/third_party/xla/third_party/tsl/tsl/util/BUILD +++ b/third_party/xla/third_party/tsl/tsl/util/BUILD @@ -286,6 +286,7 @@ cc_library( "//tsl/platform:stringpiece", "//tsl/platform:stringprintf", "//tsl/platform:types", + "@com_google_absl//absl/strings", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc b/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc index 520962fe410262..5e316e9ae9fc6a 100644 --- a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc +++ b/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc @@ -15,11 +15,13 @@ limitations under the License. 
#include "tsl/util/command_line_flags.h" +#include #include #include #include #include +#include "absl/strings/match.h" #include "tsl/platform/logging.h" #include "tsl/platform/str_util.h" #include "tsl/platform/stringpiece.h" @@ -96,10 +98,10 @@ bool ParseBoolFlag(StringPiece arg, StringPiece flag, if (!absl::ConsumePrefix(&arg, "=")) { return false; } - if (absl::EqualsIgnoreCase(arg, "true")) { + if (absl::EqualsIgnoreCase(arg, "true") || arg == "1") { *value_parsing_ok = hook(true); return true; - } else if (absl::EqualsIgnoreCase(arg, "false")) { + } else if (absl::EqualsIgnoreCase(arg, "false") || arg == "0") { *value_parsing_ok = hook(false); return true; } else { @@ -290,6 +292,29 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const { return result && (*argc < 2 || strcmp(argv[1], "--help") != 0); } +/*static*/ bool Flags::Parse(std::vector& flags, + const std::vector& flag_list) { + bool result = true; + std::vector unknown_flags; + for (auto& flag : flags) { + for (const Flag& flag_object : flag_list) { + bool value_parsing_ok; + bool was_found = flag_object.Parse(flag, &value_parsing_ok); + if (!value_parsing_ok) { + result = false; + } + // Clear parsed flags, these empty entries are removed later. + if (was_found) { + flag.clear(); + break; + } + } + } + auto IsEmpty = [](const std::string& flag) { return flag.empty(); }; + flags.erase(std::remove_if(flags.begin(), flags.end(), IsEmpty), flags.end()); + return result; +} + /*static*/ string Flags::Usage(const string& cmdline, const std::vector& flag_list) { string usage_text; diff --git a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.h b/third_party/xla/third_party/tsl/tsl/util/command_line_flags.h index 6553bc887c853e..2710de5753cd01 100644 --- a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.h +++ b/third_party/xla/third_party/tsl/tsl/util/command_line_flags.h @@ -132,6 +132,11 @@ class Flags { // first remaining argument is not "--help". static bool Parse(int* argc, char** argv, const std::vector& flag_list); + // Similar as above, but accepts a mutable vector of strings in place of + // argc and argv. Doesn't ignore the first flag, and return the unknown flags + // back in flags vector. + static bool Parse(std::vector& flags, + const std::vector& flag_list); // Return a usage message with command line cmdline, and the // usage_text strings in flag_list[]. 
static string Usage(const string& cmdline, diff --git a/third_party/xla/tools/toolchains/cross_compile/cc/BUILD b/third_party/xla/tools/toolchains/cross_compile/cc/BUILD new file mode 100644 index 00000000000000..dc621893ac9675 --- /dev/null +++ b/third_party/xla/tools/toolchains/cross_compile/cc/BUILD @@ -0,0 +1,191 @@ +"""Toolchain configs for cross-compiling TensorFlow""" + +load("@bazel_tools//tools/cpp:unix_cc_toolchain_config.bzl", "cc_toolchain_config") + +package(default_visibility = ["//visibility:public"]) + +licenses(["restricted"]) + +cc_toolchain_suite( + name = "cross_compile_toolchain_suite", + toolchains = { + "aarch64": ":linux_aarch64_toolchain", + "k8": ":linux_x86_toolchain", + }, +) + +filegroup( + name = "empty", + visibility = ["//visibility:public"], +) + +cc_toolchain( + name = "linux_x86_toolchain", + all_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":linux_x86_toolchain_config", + toolchain_identifier = "linux_x86_toolchain", +) + +cc_toolchain_config( + name = "linux_x86_toolchain_config", + abi_libc_version = "local", + abi_version = "local", + builtin_sysroot = "/dt9", + compile_flags = [ + "--target=x86_64-unknown-linux-gnu", + "-fstack-protector", + "-Wall", + "-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-mavx", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "k8", + cxx_builtin_include_directories = [ + "/dt9/", + "/usr/lib/llvm-17/include/", + "/usr/lib/llvm-17/lib/clang/17/include", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=x86_64-unknown-linux-gnu", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "-Wl,--undefined-version", + ], + link_libs = [ + "-lstdc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,--gc-sections"], + supports_start_end_lib = True, + target_libc = "", + target_system_name = "x86_64-unknown-linux-gnu", + tool_paths = { + "gcc": "/usr/lib/llvm-17/bin/clang", + "ld": "/usr/lib/llvm-17/bin/ld.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-ar", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + "strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "linux_x86_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) + +cc_toolchain( + name = "linux_aarch64_toolchain", + all_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":linux_aarch64_toolchain_config", + toolchain_identifier = "linux_aarch64_toolchain", +) + +cc_toolchain_config( + name = "linux_aarch64_toolchain_config", + abi_libc_version = "local", + abi_version = "local", + builtin_sysroot = "/dt10/", + compile_flags = [ + 
"--target=aarch64-unknown-linux-gnu", + "-fstack-protector", + "-Wall", + "-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-mtune=generic", + "-march=armv8-a", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "aarch64", + cxx_builtin_include_directories = [ + "/dt10/", + "/usr/lib/llvm-17/include/", + "/usr/lib/llvm-17/lib/clang/17/include", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=aarch64-unknown-linux-gnu", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "-Wl,--undefined-version", + ], + link_libs = [ + "-lstdc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,--gc-sections"], + supports_start_end_lib = True, + target_libc = "", + target_system_name = "aarch64-unknown-linux-gnu", + tool_paths = { + "gcc": "/usr/lib/llvm-17/bin/clang", + "ld": "/usr/lib/llvm-17/bin/ld.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-ar", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + "strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "linux_aarch64_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) diff --git a/third_party/xla/tools/toolchains/cross_compile/config/BUILD b/third_party/xla/tools/toolchains/cross_compile/config/BUILD new file mode 100644 index 00000000000000..b6a504ba1449d6 --- /dev/null +++ b/third_party/xla/tools/toolchains/cross_compile/config/BUILD @@ -0,0 +1,23 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["restricted"]) + +platform( + name = "linux_x86_64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], + exec_properties = { + "container-image": "docker://gcr.io/tensorflow-testing/ml-devinfra-linux-aarch64-cross-compile@sha256:11c5ac3b9b4e01cfa82b39b90826a9bfc5b806ccc92cd3d272e6bf861de43be1", + "OSFamily": "Linux", + }, +) + +platform( + name = "linux_aarch64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:aarch64", + ], +) diff --git a/third_party/xla/tools/toolchains/remote_config/configs.bzl b/third_party/xla/tools/toolchains/remote_config/configs.bzl index 4554463cb90675..4b07fb5c18670d 100644 --- a/third_party/xla/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/tools/toolchains/remote_config/configs.bzl @@ -200,6 +200,28 @@ def initialize_rbe_configs(): python_install_path = "/usr/local", ) + tensorflow_rbe_config( + name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", + compiler = "/usr/lib/llvm-17/bin/clang", + cuda_version = "12.3", + cudnn_version = "8.9", + os = "ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + sysroot = "/dt9", + python_install_path = "/usr/local", + ) + + tensorflow_rbe_config( + name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", + compiler = "/dt9/usr/bin/gcc", + compiler_prefix = "/usr/bin", + cuda_version = "12.3", + cudnn_version = "8.9", + os = 
"ubuntu20.04-manylinux2014-multipython", + python_versions = ["3.9", "3.10", "3.11", "3.12"], + python_install_path = "/usr/local", + ) + tensorflow_rbe_win_config( name = "windows_py37", python_bin_path = "C:/Python37/python.exe", diff --git a/third_party/xla/tools/toolchains/remote_config/containers.bzl b/third_party/xla/tools/toolchains/remote_config/containers.bzl index bfb4634e810328..cd346c2816def1 100644 --- a/third_party/xla/tools/toolchains/remote_config/containers.bzl +++ b/third_party/xla/tools/toolchains/remote_config/containers.bzl @@ -5,8 +5,9 @@ container_digests = { # TF now uses only this container "cuda11.2-cudnn8.1-ubuntu20.04-manylinux2014-multipython": "sha256:48612bd85709cd014711d0b0f87e0806f3567d06d2e81c6e860516b87498b821", # JAX manylinux2014 configs. - "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:ab39410baf2fc1d31d50540acec7640d7f4814fa694e2421b696b6f0a058d645", - "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:b699d6ae235ac601dc3e62391ac7c4606cb10331f8141983858c1580f5e74ddb", + "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:b112c0c77d4172fc025420938f13ea83f3ad480c01778e743a201e5e3f4710e1", + "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:9fefda035b4a12b24cd5bae56c7dbb9527a5fd06a41ced0a22ac86fe5ed26428", + "cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:afe68c3448734cb07b16005fd9ed47d19533eb8bf5acd92863735ce24766b93b", # ROCM, probably not all of them still in use "rocm-ubuntu18.04-manylinux2010-multipython": "sha256:6e953a09b145df338bcb03e9e36f99b291140c29b72d0a048fb6c5905ccad5eb", "rocm-ubuntu20.04-manylinux2014-multipython": "sha256:906faec7765fe5dd067f2b092b5d5f220c1fedde725fb42c83d031b4d6f32204", @@ -98,6 +99,13 @@ containers = { "digest": container_digests["cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython"], }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython. + "cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": { + "registry": "gcr.io", + "repository": "tensorflow-testing/nosla-cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython", + "digest": container_digests["cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython"], + }, + # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.rocm-ubuntu18.04-manylinux2010-multipython. "rocm-ubuntu18.04-manylinux2010-multipython": { "registry": "gcr.io", diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl index 31cb19540a7020..2221985b7bd3c1 100644 --- a/third_party/xla/workspace2.bzl +++ b/third_party/xla/workspace2.bzl @@ -9,12 +9,14 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # Import third party repository rules. See go/tfbr-thirdparty. load("//third_party/dlpack:workspace.bzl", dlpack = "repo") +load("//third_party/gloo:workspace.bzl", gloo = "repo") load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo") load("//third_party/triton:workspace.bzl", triton = "repo") def _initialize_third_party(): """ Load third party repositories. See above load() statements. 
""" dlpack() + gloo() stablehlo() triton() @@ -37,6 +39,14 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v0.9.zip"), ) + tf_http_archive( + name = "cutlass_archive", + build_file = "//third_party:cutlass.BUILD", + sha256 = "ea1b7f96919460a5d80b09c1b246652539a8605600b2be4cccc02c254bccbe50", + strip_prefix = "cutlass-5783d6dbd0c34032371cce2bd999fc76007520d7", + urls = tf_mirror_urls("https://github.com/chsigg/cutlass/archive/5783d6dbd0c34032371cce2bd999fc76007520d7.tar.gz"), + ) + tf_http_archive( name = "boringssl", sha256 = "9dc53f851107eaf87b391136d13b815df97ec8f76dadb487b58b2fc45e624d2c", diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 24100c8dbcdc5d..caf77930363679 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -470,6 +470,7 @@ xla_cc_test( ":shape_util", ":test", ":xla_data_proto_cc", + "//xla:status", "@com_google_absl//absl/hash:hash_testing", "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", diff --git a/third_party/xla/xla/backends/interpreter/compiler.cc b/third_party/xla/xla/backends/interpreter/compiler.cc index 864d98a8269082..3b89c3b6054de1 100644 --- a/third_party/xla/xla/backends/interpreter/compiler.cc +++ b/third_party/xla/xla/backends/interpreter/compiler.cc @@ -49,7 +49,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/statusor.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream_executor_pimpl.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/backends/interpreter/executor.cc b/third_party/xla/xla/backends/interpreter/executor.cc index 3766f7cb7af82c..1095c71a86b226 100644 --- a/third_party/xla/xla/backends/interpreter/executor.cc +++ b/third_party/xla/xla/backends/interpreter/executor.cc @@ -34,12 +34,6 @@ DeviceMemoryBase XlaInterpreterExecutor::Allocate(uint64_t size, return DeviceMemoryBase(new char[size], size); } -void *XlaInterpreterExecutor::GetSubBuffer(DeviceMemoryBase *parent, - uint64_t offset_bytes, - uint64_t /*size_bytes*/) { - return parent + offset_bytes; -} - void XlaInterpreterExecutor::Deallocate(DeviceMemoryBase *mem) { delete[] static_cast(mem->opaque()); } diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index 358b609f23020d..5d866462950072 100644 --- a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -48,22 +48,22 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { XlaInterpreterExecutor() = default; tsl::Status Init(int device_ordinal, DeviceOptions device_options) override { + device_ordinal_ = device_ordinal; return ::tsl::OkStatus(); } + int device_ordinal() const override { return device_ordinal_; }; tsl::Status GetKernel(const MultiKernelLoaderSpec &spec, - KernelBase *kernel) override { + Kernel *kernel) override { return tsl::errors::Unimplemented("Not Implemented"); } tsl::Status Launch(Stream *stream, const ThreadDim &thread_dims, - const BlockDim &block_dims, const KernelBase &kernel, - const KernelArgsArrayBase &args) override { + const BlockDim &block_dims, const Kernel &kernel, + const KernelArgs &args) override { return tsl::errors::Unimplemented("Not Implemented"); } DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; - void 
*GetSubBuffer(DeviceMemoryBase *parent, uint64_t offset_bytes, - uint64_t size_bytes) override; void Deallocate(DeviceMemoryBase *mem) override; void *HostMemoryAllocate(uint64_t size) override { return new char[size]; } @@ -182,6 +182,10 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { } private: + // The device ordinal value that this executor was initialized with; recorded + // for use in getting device metadata. Immutable post-initialization. + int device_ordinal_; + DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape); tsl::StatusOr AllocateOutputBuffer(const xla::Shape &shape); diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index ca9d8a952a26b9..50f033f79e6b3a 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -105,14 +105,19 @@ cc_library( "//xla:debug_options_flags", "//xla:execution_options_util", "//xla:shape_util", + "//xla:statusor", + "//xla:util", "//xla:xla_proto_cc", "//xla/pjrt:compile_options_proto_cc", "//xla/service:compilation_environments", "//xla/service:computation_placer", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ], ) @@ -301,13 +306,17 @@ xla_cc_test( ":xla_computation", "//xla:debug_options_flags", "//xla:shape_util", + "//xla:statusor", + "//xla:test", "//xla:test_helpers", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/client/executable_build_options.cc b/third_party/xla/xla/client/executable_build_options.cc index 3089a9820a1810..8227de75f19114 100644 --- a/third_party/xla/xla/client/executable_build_options.cc +++ b/third_party/xla/xla/client/executable_build_options.cc @@ -15,16 +15,26 @@ limitations under the License. 
#include "xla/client/executable_build_options.h" +#include #include #include #include #include +#include "absl/log/check.h" #include "absl/strings/str_format.h" #include "xla/debug_options_flags.h" #include "xla/execution_options_util.h" +#include "xla/layout_util.h" +#include "xla/service/compilation_environments.h" +#include "xla/service/computation_placer.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/statusor.h" +#include "xla/util.h" #include "xla/xla.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace xla { @@ -151,6 +161,10 @@ StatusOr ExecutableBuildOptions::ToProto() const { "Cannot serialize " "ExecutableBuildOptions::layout_canonicalization_callback"); } + if (compile_thread_pool() != nullptr) { + return InvalidArgument( + "Cannot serialize ExecutableBuildOptions::compile_thread_pool"); + } output.set_num_replicas(num_replicas()); output.set_num_partitions(num_partitions()); output.set_use_spmd_partitioning(use_spmd_partitioning()); @@ -170,6 +184,12 @@ StatusOr ExecutableBuildOptions::ToProto() const { } *output.mutable_fdo_profile() = fdo_profile(); output.set_device_memory_size(device_memory_size()); + for (int64_t s : auto_spmd_partitioning_mesh_shape()) { + output.mutable_auto_spmd_partitioning_mesh_shape()->Add(s); + } + for (int64_t s : auto_spmd_partitioning_mesh_ids()) { + output.mutable_auto_spmd_partitioning_mesh_ids()->Add(s); + } return output; } @@ -208,6 +228,12 @@ StatusOr ExecutableBuildOptionsFromProto( input.allow_spmd_sharding_propagation_to_output()); *output.mutable_fdo_profile() = input.fdo_profile(); output.set_device_memory_size(input.device_memory_size()); + output.set_auto_spmd_partitioning_mesh_shape( + std::vector(input.auto_spmd_partitioning_mesh_shape().begin(), + input.auto_spmd_partitioning_mesh_shape().end())); + output.set_auto_spmd_partitioning_mesh_ids( + std::vector(input.auto_spmd_partitioning_mesh_ids().begin(), + input.auto_spmd_partitioning_mesh_ids().end())); return output; } diff --git a/third_party/xla/xla/client/xla_builder.cc b/third_party/xla/xla/client/xla_builder.cc index 172ded25f552af..40666f7a0b6a69 100644 --- a/third_party/xla/xla/client/xla_builder.cc +++ b/third_party/xla/xla/client/xla_builder.cc @@ -937,9 +937,15 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, reshaped_dynamic_dimensions); // Eliminate the size one dimensions. - TF_ASSIGN_OR_RETURN( - XlaOp reshaped_operand, - ReshapeInternal(reshaped_shape, operand, /*inferred_dimension=*/-1)); + // The added reshape reduces the rank of the tensor. Hence we cannot directly + // apply the broadcast's sharding on reshape. + XlaOp reshaped_operand; + { + XlaScopedShardingAssignment scoped_sharding(this, std::nullopt); + TF_ASSIGN_OR_RETURN( + reshaped_operand, + ReshapeInternal(reshaped_shape, operand, /*inferred_dimension=*/-1)); + } // Broadcast 'reshape' up to the larger size. 
return InDimBroadcast(broadcast_shape, reshaped_operand, broadcast_dimensions); @@ -1002,15 +1008,18 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape, GetShapePtr(updated_lhs)); - if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) { - TF_ASSIGN_OR_RETURN(updated_lhs, - AddBroadcastSequence(shape, updated_lhs)); - } TF_ASSIGN_OR_RETURN(const Shape* updated_rhs_shape, GetShapePtr(updated_rhs)); - if (!ShapeUtil::SameDimensions(shape, *updated_rhs_shape)) { - TF_ASSIGN_OR_RETURN(updated_rhs, - AddBroadcastSequence(shape, updated_rhs)); + if (!updated_lhs_shape->is_unbounded_dynamic() && + !updated_rhs_shape->is_unbounded_dynamic()) { + if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_lhs, + AddBroadcastSequence(shape, updated_lhs)); + } + if (!ShapeUtil::SameDimensions(shape, *updated_rhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_rhs, + AddBroadcastSequence(shape, updated_rhs)); + } } if (binop == HloOpcode::kCompare) { @@ -2495,6 +2504,25 @@ StatusOr XlaBuilder::SortInternal(const Shape& shape, return AddInstruction(std::move(instr), HloOpcode::kSort, operands); } +XlaOp XlaBuilder::TopK(XlaOp operand, int64_t k, bool largest) { + return ReportErrorOrReturn([&]() -> StatusOr { + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferTopKShape(*operand_shape, k)); + return TopKInternal(shape, operand, k, largest); + }); +} + +StatusOr XlaBuilder::TopKInternal(const Shape& shape, XlaOp operand, + int64_t k, bool largest) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + instr.set_k(k); + instr.set_largest(largest); + return AddInstruction(std::move(instr), HloOpcode::kTopK, {operand}); +} + XlaOp XlaBuilder::ConvertElementType(XlaOp operand, PrimitiveType new_element_type) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -3910,7 +3938,6 @@ XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64_t dimension) { XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64_t dimension) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); Shape shape = *operand_shape; @@ -5210,6 +5237,10 @@ XlaOp Sort(absl::Span operands, const XlaComputation& comparator, is_stable); } +XlaOp TopK(XlaOp operand, int64_t k, bool largest) { + return operand.builder()->TopK(operand, k, largest); +} + XlaOp Clamp(const XlaOp min, const XlaOp operand, const XlaOp max) { return min.builder()->Clamp(min, operand, max); } diff --git a/third_party/xla/xla/client/xla_builder.h b/third_party/xla/xla/client/xla_builder.h index cbc0259bea7944..aca638833e7097 100644 --- a/third_party/xla/xla/client/xla_builder.h +++ b/third_party/xla/xla/client/xla_builder.h @@ -901,6 +901,10 @@ class XlaBuilder { const XlaComputation& comparator, int64_t dimension, bool is_stable); + XlaOp TopK(XlaOp operand, int64_t k, bool largest); + virtual StatusOr TopKInternal(const Shape& shape, XlaOp operand, + int64_t k, bool largest); + XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); XlaOp Map(absl::Span operands, const XlaComputation& computation, @@ -1532,6 +1536,7 @@ class XlaBuilder { friend XlaOp Sort(absl::Span operands, const XlaComputation& comparator, int64_t dimension, bool is_stable); + friend XlaOp TopK(XlaOp operand, int64_t k, bool largest); friend XlaOp Clamp(XlaOp min, XlaOp operand, 
XlaOp max); friend XlaOp Map(XlaBuilder* builder, absl::Span operands, const XlaComputation& computation, @@ -2674,6 +2679,26 @@ XlaOp Rev(XlaOp operand, absl::Span dimensions); XlaOp Sort(absl::Span operands, const XlaComputation& comparator, int64_t dimension = -1, bool is_stable = false); +// Enqueues a topk instruction onto the computation. TopK returns the largest +// 'k' values and their indices along the last dimension of the 'operand' if +// `lagest=true` or the smallest `k` values if `largest=false`. +// +// * If the operand is a rank-1 tensor (an array), the result is a tuple that +// consists of: +// * a sorted array with the top 'k' elements. +// * an array containing the indices of the k elements. +// For example, if the input is [0.1, 0.3, 0.2] and k == 2, the output tuple +// is ([0.3, 0.2], [1, 2]). +// * If the operand has higher rank, the result is a tuple that consists of: +// * a tensor equivalent to one produced by sorting the operand along the last +// dimension and slicing that dimension to only the top 'k' values. The last +// dimension is sorted as in the rank-1 case. +// * a tensor containing the indices of the top 'k' values along the last +// dimension. +// For example, if the input is [0.1, 0.3, 0.2][0.5, 0.4, 0.6] and k == 1, the +// output tuple is ([0.3][0.6], [1][2]). +XlaOp TopK(XlaOp operand, int64_t k, bool largest); + // Enqueues a clamp instruction onto the computation. XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/client/xla_builder_test.cc index 60503f890be013..98fafbea51d3c3 100644 --- a/third_party/xla/xla/client/xla_builder_test.cc +++ b/third_party/xla/xla/client/xla_builder_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "xla/client/sharding_builder.h" #include "xla/client/value_inference.h" #include "xla/client/xla_computation.h" @@ -31,9 +32,13 @@ limitations under the License. 
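The TopK builder API added above produces, along the last dimension of the operand, a tuple of the top (or, with largest=false, the bottom) k values and their indices. A minimal usage sketch follows; it is illustrative only and not part of the patch, the helper name is made up, and the element type of the index output is not stated in this diff.

// Illustrative only: build a computation that takes the per-row top-4 values
// of a f32[8, 16] parameter using the TopK op introduced in this change.
#include "xla/client/xla_builder.h"
#include "xla/shape_util.h"

xla::XlaComputation BuildTopKExample() {  // hypothetical helper name
  xla::XlaBuilder b("topk_example");
  xla::XlaOp x =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {8, 16}), "x");
  // Root is a tuple: element 0 holds the top-4 values per row (f32[8, 4]),
  // element 1 holds their indices along the last dimension.
  xla::TopK(x, /*k=*/4, /*largest=*/true);
  return b.Build().value();
}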
#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/statusor.h" +#include "xla/test.h" #include "xla/test_helpers.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -197,6 +202,16 @@ TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { GmockMatch(m::Add(m::Parameter(), m::Broadcast(m::Constant())))); } +TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcastReversed) { + XlaBuilder b(TestName()); + XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); + Add(ConstantR0(&b, 1.0), x); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + GmockMatch(m::Add(m::Broadcast(m::Constant()), m::Parameter()))); +} + TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { XlaBuilder b(TestName()); const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6}); @@ -1524,5 +1539,353 @@ TEST_F(XlaBuilderTest, InvalidSharding) { HasSubstr("Number of tile assignment dimensions (excluding " "subgroups) is different than the input rank")); } + +TEST_F(XlaBuilderTest, TopKDimensions) { + XlaBuilder b(TestName()); + int64_t k = 1; + int64_t largest = true; + TopK(Parameter(&b, 0, ShapeUtil::MakeShape(F32, {6, 8}), "p0"), k, largest); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_TRUE(root->opcode() == HloOpcode::kTopK); + EXPECT_TRUE(root->shape().IsTuple()); + EXPECT_EQ(root->shape().tuple_shapes_size(), 2); + EXPECT_EQ(root->shape().tuple_shapes(0).rank(), 2); + EXPECT_EQ(root->shape().tuple_shapes(1).rank(), 2); + EXPECT_EQ(root->shape().tuple_shapes(0).dimensions(0), 6); + EXPECT_EQ(root->shape().tuple_shapes(0).dimensions(1), k); + EXPECT_EQ(root->shape().tuple_shapes(1).dimensions(0), 6); + EXPECT_EQ(root->shape().tuple_shapes(1).dimensions(1), k); +} + +TEST_F(XlaBuilderTest, UnboundedAbs) { + XlaBuilder b(TestName()); + StatusOr operand = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr expected = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + ASSERT_IS_OK(operand.status()); + ASSERT_IS_OK(expected.status()); + Abs(Parameter(&b, 0, operand.value(), "operand")); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedAdd) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); + StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + ASSERT_IS_OK(expected.status()); + Add(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << 
ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedAddUnsupportedImplicitBroadcast) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, 10]"); + StatusOr rhs = ParseShape("f32[1]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + Add(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); + StatusOr> build_status = BuildHloModule(&b); + EXPECT_FALSE(build_status.ok()); + EXPECT_THAT(build_status.status().message(), + HasSubstr("Unbounded dynamic shapes not supported")); +} + +TEST_F(XlaBuilderTest, UnboundedDiv) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); + StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + ASSERT_IS_OK(expected.status()); + Div(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedDivUnsupportedImplicitBroadcast) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, 10]"); + StatusOr rhs = ParseShape("f32[1]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + Div(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); + StatusOr> build_status = BuildHloModule(&b); + EXPECT_FALSE(build_status.ok()); + EXPECT_THAT(build_status.status().message(), + HasSubstr("Unbounded dynamic shapes not supported")); +} + +TEST_F(XlaBuilderTest, UnboundedDot) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, 10]"); + StatusOr rhs = ParseShape("f32[?, 10]"); + StatusOr expected = ParseShape("f32[?, 10]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + ASSERT_IS_OK(expected.status()); + + Dot(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs")); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + ASSERT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedDotGeneral) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, <=3, ?]"); + StatusOr rhs = ParseShape("f32[2, 4, 5]"); + StatusOr expected = ParseShape("f32[?, <=3, 5]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + ASSERT_IS_OK(expected.status()); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + + DotGeneral(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), dnums); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + ASSERT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + 
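The Unbounded* tests in this file lean on the textual shape syntax handled by ParseShape: a plain integer is a static dimension, `<=N` is a dynamic dimension with upper bound N, and `?` is an unbounded dynamic dimension. Judging from the expected shapes above, elementwise ops keep the more constrained of the two operand dimensions: a concrete size or a `<=N` bound wins over `?`, while a size-1 dimension paired with `?` stays unbounded, since the 1 may still need to broadcast. A small self-contained sketch of the notation (not part of the patch; the function name is made up):

// Illustrative only: parse a shape mixing static, bounded-dynamic, and
// unbounded-dynamic dimensions and query its dynamism.
#include "xla/service/hlo_parser.h"
#include "xla/shape.h"

void ShapeNotationExample() {  // hypothetical helper name
  // 3 is static, <=2 is dynamic with an upper bound of 2, ? is unbounded.
  xla::Shape s = xla::ParseShape("f32[3, <=2, ?]").value();
  bool bounded_dynamic = s.is_dynamic_dimension(1);  // true, bound is 2
  bool unbounded = s.is_unbounded_dynamic();         // true, last dim is '?'
  (void)bounded_dynamic;
  (void)unbounded;
}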
+TEST_F(XlaBuilderTest, UnboundedExp) { + XlaBuilder b(TestName()); + StatusOr operand = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr expected = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + ASSERT_IS_OK(operand.status()); + ASSERT_IS_OK(expected.status()); + Exp(Parameter(&b, 0, operand.value(), "operand")); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedMax) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); + StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + ASSERT_IS_OK(expected.status()); + Max(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedMaxUnsupportedImplicitBroadcast) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, 10]"); + StatusOr rhs = ParseShape("f32[1]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + Max(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); + StatusOr> build_status = BuildHloModule(&b); + EXPECT_FALSE(build_status.ok()); + EXPECT_THAT(build_status.status().message(), + HasSubstr("Unbounded dynamic shapes not supported")); +} + +TEST_F(XlaBuilderTest, UnboundedMul) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); + StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + ASSERT_IS_OK(expected.status()); + Mul(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedMulUnsupportedImplicitBroadcast) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, 10]"); + StatusOr rhs = ParseShape("f32[1]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + Mul(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); + StatusOr> build_status = BuildHloModule(&b); + EXPECT_FALSE(build_status.ok()); + EXPECT_THAT(build_status.status().message(), + HasSubstr("Unbounded dynamic shapes not supported")); +} + +TEST_F(XlaBuilderTest, UnboundedPow) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); + StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); + 
ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + ASSERT_IS_OK(expected.status()); + Pow(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedPowUnsupportedImplicitBroadcast) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, 10]"); + StatusOr rhs = ParseShape("f32[1]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + Pow(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); + StatusOr> build_status = BuildHloModule(&b); + EXPECT_FALSE(build_status.ok()); + EXPECT_THAT(build_status.status().message(), + HasSubstr("Unbounded dynamic shapes not supported")); +} + +TEST_F(XlaBuilderTest, UnboundedReduce) { + XlaBuilder b(TestName()); + XlaOp input0 = Parameter(&b, 0, ParseShape("f32[7, 5]").value(), "input0"); + XlaOp input1 = Parameter(&b, 1, ParseShape("f32[?, 5]").value(), "input1"); + XlaOp input2 = Parameter(&b, 2, ParseShape("f32[7, ?]").value(), "input2"); + XlaOp init = Parameter(&b, 3, ShapeUtil::MakeShape(F32, {}), "init"); + + XlaBuilder bsum(TestName()); + XlaOp arg0 = Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "arg0"); + XlaOp arg1 = Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "arg1"); + XlaOp arg2 = Parameter(&bsum, 2, ShapeUtil::MakeShape(F32, {}), "arg2"); + XlaOp arg3 = Parameter(&bsum, 3, ShapeUtil::MakeShape(F32, {}), "arg3"); + XlaOp arg4 = Parameter(&bsum, 4, ShapeUtil::MakeShape(F32, {}), "arg4"); + XlaOp arg5 = Parameter(&bsum, 5, ShapeUtil::MakeShape(F32, {}), "arg5"); + + std::vector output_operands = {Add(arg0, arg1), Add(arg2, arg3), + Add(arg4, arg5)}; + Tuple(&bsum, absl::MakeSpan(output_operands)); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + Reduce(&b, {input0, input1, input2}, {init, init, init}, sum, {1}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + Shape shape = ShapeUtil::MakeShape(F32, {7}, {false}); + Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape}); + EXPECT_TRUE(ShapeUtil::Equal(result, expected)); +} + +TEST_F(XlaBuilderTest, UnboundedSlice) { + XlaBuilder b(TestName()); + StatusOr operand = ParseShape("f32[1, <=3, ?]"); + StatusOr expected = ParseShape("f32[1, <=2, 3]"); + ASSERT_IS_OK(operand.status()); + ASSERT_IS_OK(expected.status()); + Slice(Parameter(&b, 0, operand.value(), "operand"), + /*start_indices=*/{0, 1, 2}, + /*limit_indices=*/{1, 3, 5}, + /*strides=*/{1, 1, 1}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto result = module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedSub) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[1, ?, 2, ?, <=2, ?, ?]"); + StatusOr rhs = ParseShape("f32[?, 1, ?, 2, ?, <=2, ?]"); + StatusOr expected = ParseShape("f32[?, ?, 2, 2, <=2, <=2, ?]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + 
ASSERT_IS_OK(expected.status()); + Sub(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanString(result) + << " expected: " << ShapeUtil::HumanString(expected.value()); +} + +TEST_F(XlaBuilderTest, UnboundedSubUnsupportedImplicitBroadcast) { + XlaBuilder b(TestName()); + StatusOr lhs = ParseShape("f32[?, 10]"); + StatusOr rhs = ParseShape("f32[1]"); + ASSERT_IS_OK(lhs.status()); + ASSERT_IS_OK(rhs.status()); + Sub(Parameter(&b, 0, lhs.value(), "lhs"), + Parameter(&b, 1, rhs.value(), "rhs"), /*broadcast_dimensions=*/{1}); + StatusOr> build_status = BuildHloModule(&b); + EXPECT_FALSE(build_status.ok()); + EXPECT_THAT(build_status.status().message(), + HasSubstr("Unbounded dynamic shapes not supported")); +} + +TEST_F(XlaBuilderTest, UnboundedTranspose) { + XlaBuilder b(TestName()); + StatusOr operand = ParseShape("f32[1, ?, 2, ?, <=2]{4,3,2,1,0}"); + StatusOr expected = ParseShape("f32[<=2, 1, ?, 2, ?]{0,2,3,4,1}"); + ASSERT_IS_OK(operand.status()); + ASSERT_IS_OK(expected.status()); + Transpose(Parameter(&b, 0, operand.value(), "operand"), + /*permutation=*/{4, 0, 3, 2, 1}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ShapeUtil::Equal(result, expected.value())) + << "result: " << ShapeUtil::HumanStringWithLayout(result) + << " expected: " << ShapeUtil::HumanStringWithLayout(expected.value()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index e7b4805aa2a34f..4a1d998e5d76a4 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -136,6 +136,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_enable_dumping(true); opts.set_xla_gpu_enable_xla_runtime_executable(true); + opts.set_xla_gpu_enable_custom_fusions(false); opts.set_xla_gpu_nccl_termination_timeout_seconds(-1); opts.set_xla_gpu_enable_shared_constants(true); @@ -200,7 +201,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_collect_cost_model_stats(false); opts.set_xla_gpu_enable_split_k_autotuning(true); - opts.set_xla_gpu_single_wave_autotuning(true); opts.set_xla_gpu_enable_reduction_epilogue_fusion(true); opts.set_xla_gpu_enable_nccl_clique_optimization(false); opts.set_xla_gpu_cublas_fallback(true); @@ -211,7 +211,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_target_config_filename(""); opts.set_xla_gpu_enable_cub_radix_sort(true); opts.set_xla_gpu_enable_cudnn_layer_norm(false); - + opts.set_xla_gpu_threshold_for_windowed_einsum_mib(100000); return opts; } @@ -1066,6 +1066,18 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_xla_runtime_executable), debug_options->xla_gpu_enable_xla_runtime_executable(), "Whether to enable XLA runtime for XLA:GPU backend")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_custom_fusions", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_custom_fusions), + debug_options->xla_gpu_enable_custom_fusions(), + "Whether to enable XLA custom fusions")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_custom_fusions_re", + 
string_setter_for(&DebugOptions::set_xla_gpu_enable_custom_fusions_re), + debug_options->xla_gpu_enable_custom_fusions_re(), + "Limits custom fusion only to fusions which match this regular " + "expression. Default is all custom fusions registerered in a current " + "process.")); flag_list->push_back( tsl::Flag("xla_gpu_enable_gpu2_runtime", bool_setter_for(&DebugOptions::set_xla_gpu_enable_gpu2_runtime), @@ -1341,13 +1353,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_split_k_autotuning), debug_options->xla_gpu_enable_split_k_autotuning(), "Enable split_k autotuning for triton gemms.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_single_wave_autotuning", - bool_setter_for(&DebugOptions::set_xla_gpu_single_wave_autotuning), - debug_options->xla_gpu_single_wave_autotuning(), - "Enable single \"wave\" autotuning. This uses more memory for " - "compilation, but utilizes CPU cores better, so compilation can be " - "faster.")); flag_list->push_back(tsl::Flag( "xla_gpu_enable_reduction_epilogue_fusion", @@ -1415,6 +1420,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_cub_radix_sort), debug_options->xla_gpu_enable_cub_radix_sort(), "Enable radix sort using CUB for simple shapes")); + flag_list->push_back(tsl::Flag( + "xla_gpu_threshold_for_windowed_einsum_mib", + int64_setter_for( + &DebugOptions::set_xla_gpu_threshold_for_windowed_einsum_mib), + debug_options->xla_gpu_threshold_for_windowed_einsum_mib(), + "Threshold to enable windowed einsum (collective matmul) in MB." + "Default is 100000")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/executable_run_options.cc b/third_party/xla/xla/executable_run_options.cc index 795c5fc4176431..0cb33fce343363 100644 --- a/third_party/xla/xla/executable_run_options.cc +++ b/third_party/xla/xla/executable_run_options.cc @@ -124,6 +124,17 @@ ExecutableRunOptions::gpu_executable_run_options() const { return gpu_executable_run_options_; } +ExecutableRunOptions& ExecutableRunOptions::set_cpu_executable_run_options( + const cpu::CpuExecutableRunOptions* cpu_executable_run_options) { + cpu_executable_run_options_ = cpu_executable_run_options; + return *this; +} + +const cpu::CpuExecutableRunOptions* +ExecutableRunOptions::cpu_executable_run_options() const { + return cpu_executable_run_options_; +} + ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) { rng_seed_ = rng_seed; return *this; diff --git a/third_party/xla/xla/executable_run_options.h b/third_party/xla/xla/executable_run_options.h index 31ba23bf3b7a14..861e13d2a5c02f 100644 --- a/third_party/xla/xla/executable_run_options.h +++ b/third_party/xla/xla/executable_run_options.h @@ -51,6 +51,10 @@ class DeviceAssignment; class ExecutionProfile; class Shape; +namespace cpu { +class CpuExecutableRunOptions; +} // namespace cpu + namespace gpu { class GpuExecutableRunOptions; } // namespace gpu @@ -210,6 +214,12 @@ class ExecutableRunOptions { return recv_device_memory_function_; } + // CPU-backend specific options. These are kept out-of-line to avoid bloating + // the size of this dependency for CPU-only AOT builds. + ExecutableRunOptions& set_cpu_executable_run_options( + const cpu::CpuExecutableRunOptions* cpu_executable_run_options); + const cpu::CpuExecutableRunOptions* cpu_executable_run_options() const; + // GPU-backend specific options. 
These are kept out-of-line to avoid bloating // the size of this dependency for CPU-only AOT builds. ExecutableRunOptions& set_gpu_executable_run_options( @@ -231,6 +241,7 @@ class ExecutableRunOptions { SendDeviceMemoryFunction* send_device_memory_function_ = nullptr; RecvDeviceMemoryFunction* recv_device_memory_function_ = nullptr; RunId run_id_; + const cpu::CpuExecutableRunOptions* cpu_executable_run_options_ = nullptr; const gpu::GpuExecutableRunOptions* gpu_executable_run_options_ = nullptr; }; diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index af726018ff3db7..53770a9618d2dd 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -18,13 +18,17 @@ cc_library( hdrs = ["call_frame.h"], visibility = ["//visibility:public"], deps = [ + ":api", + "//xla:status", "//xla:types", "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", + "//xla/service:executable", "//xla/stream_executor:device_memory", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", ], @@ -32,7 +36,6 @@ cc_library( cc_library( name = "ffi", - srcs = ["ffi.cc"], hdrs = ["ffi.h"], visibility = ["//visibility:public"], deps = [ @@ -44,6 +47,34 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", + "//xla/hlo/ir:hlo", + "//xla/runtime:memref_view", + "//xla/service:executable", + "//xla/stream_executor:device_memory", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +cc_library( + name = "ffi_api", + srcs = ["ffi_api.cc"], + hdrs = ["ffi_api.h"], + visibility = ["//visibility:public"], + deps = [ + ":api", + ":call_frame", + "//xla:status", + "//xla:statusor", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/ffi/api:c_api", + "//xla/ffi/api:c_api_internal", + "//xla/hlo/ir:hlo", "//xla/runtime:memref_view", "//xla/service:executable", "//xla/stream_executor:device_memory", @@ -60,11 +91,10 @@ xla_cc_test( name = "ffi_test", srcs = ["ffi_test.cc"], deps = [ - ":api", ":call_frame", ":ffi", + ":ffi_api", "//xla:xla_data_proto_cc", - "//xla/ffi/api:c_api", "//xla/service:executable", "//xla/stream_executor:device_memory", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/ffi/api/BUILD b/third_party/xla/xla/ffi/api/BUILD index d18f89ed489303..fa35ce81f57128 100644 --- a/third_party/xla/xla/ffi/api/BUILD +++ b/third_party/xla/xla/ffi/api/BUILD @@ -1,3 +1,4 @@ +load("//xla:xla.bzl", "xla_cc_test") load("@local_tsl//tsl:tsl.default.bzl", "filegroup") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") @@ -5,6 +6,22 @@ package( default_visibility = ["//visibility:public"], ) +#===-------------------------------------------------------------------------------------------===// +# Public XLA FFI API +#===-------------------------------------------------------------------------------------------===// + +# XLA FFI is a header only library that does not have any dependencies on XLA. The intent is that +# users that do want to register custom FFI handlers with XLA should copy these headers to their +# project, build a shared object with an XLA FFI handler implementation, and load it at run time. 
+# +# `api.h` and `ffi.h` headers provide a C++ library for decoding XLA FFI C API structs into a more +# user friendly C++ types. Shared objects defining XLA FFI handlers should be built with private +# symbol visibility to avoid potential ODR violations coming from template instantiations of +# different XLA FFI versions. +# +# `ffi.h` defines builtin decoding for canonical XLA types, but users can add their own decodings +# with template specializations. + filegroup( name = "api_headers", srcs = ["api.h"], @@ -46,3 +63,26 @@ cc_library( ":c_api", ], ) + +#===-------------------------------------------------------------------------------------------===// +# Internal tests for XLA FFI API +#===-------------------------------------------------------------------------------------------===// + +xla_cc_test( + name = "ffi_test", + srcs = ["ffi_test.cc"], + deps = [ + ":ffi", + "//xla:xla_data_proto_cc", + "//xla/ffi:call_frame", + "//xla/ffi:ffi_api", + "//xla/stream_executor:device_memory", + "@com_google_absl//absl/log:check", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index 46105cf58dd602..b37a170f57638d 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -18,18 +18,21 @@ limitations under the License. #include #include +#include #include #include #include #include #include #include +#include #include #include #include #include #include #include +#include #include // This is a header-only base C++ library that defines templates for decoding @@ -54,6 +57,22 @@ limitations under the License. #include "xla/ffi/api/c_api.h" +#if __has_attribute(always_inline) +#define XLA_ATTRIBUTE_ALWAYS_INLINE inline __attribute__((always_inline)) +#elif defined(_MSC_VER) +#define XLA_ATTRIBUTE_ALWAYS_INLINE __forceinline +#else +#define XLA_ATTRIBUTE_ALWAYS_INLINE inline +#endif + +#if __has_attribute(noinline) +#define XLA_ATTRIBUTE_NEVER_INLINE __attribute__((noinline)) +#elif defined(_MSC_VER) +#define XLA_ATTRIBUTE_NEVER_INLINE __declspec(noinline) +#else +#define XLA_ATTRIBUTE_NEVER_INLINE +#endif + namespace xla::ffi { // Forward declare template defined below. @@ -147,17 +166,61 @@ XLA_FFI_Error* Ffi::CheckStructSize(const XLA_FFI_Api* api, // Type tags for distinguishing handler argument types //===----------------------------------------------------------------------===// +// Forward declare. +class Dictionary; + namespace internal { +// WARNING: A lot of template metaprogramming on top of C++ variadic templates +// parameter packs. We need this to be able to pattern match FFI handler +// signature at compile time. + +// A type tag to forward all remaining args as `RemainingArgs`. +struct RemainingArgsTag {}; + // A type tag to distinguish arguments tied to the attributes in the // `Binding` variadic template argument. template struct AttrTag {}; +// A type tag to forward all attributes as `Dictionary` (and optionally decode +// it into a custom struct). +template +struct AttrsTag {}; + // A type tag to distinguish arguments extracted from an execution context. template struct CtxTag {}; +//----------------------------------------------------------------------------// +// A template for counting tagged arguments in the Ts pack (i.e. attributes). 
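The tag structs above are what let the binding machinery pattern-match a handler signature purely at compile time, and the header goes on to define a helper that counts tagged entries in the Ts parameter pack. A generic, self-contained sketch of that idea is shown below; it is not XLA's exact code, and `NumTagged` plus the local `AttrTag` stand-in are illustrative names.

// Illustrative only: count how many types in Ts... are wrapped in Tag<...>,
// using partial specialization on the head of the pack.
#include <cstddef>
#include <cstdint>

template <typename T>
struct AttrTag {};  // local stand-in for the tag types declared above

template <template <typename> class Tag, typename... Ts>
struct NumTagged;

// Empty pack: nothing left to count.
template <template <typename> class Tag>
struct NumTagged<Tag> {
  static constexpr std::size_t value = 0;
};

// Head is wrapped in Tag: count it and recurse on the tail.
template <template <typename> class Tag, typename T, typename... Ts>
struct NumTagged<Tag, Tag<T>, Ts...> {
  static constexpr std::size_t value = 1 + NumTagged<Tag, Ts...>::value;
};

// Head is anything else: skip it and recurse on the tail.
template <template <typename> class Tag, typename T, typename... Ts>
struct NumTagged<Tag, T, Ts...> {
  static constexpr std::size_t value = NumTagged<Tag, Ts...>::value;
};

static_assert(
    NumTagged<AttrTag, float, AttrTag<int32_t>, AttrTag<float>>::value == 2,
    "two of the three pack entries are attribute-tagged");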
+//----------------------------------------------------------------------------// + +template