From e63ccd3cbb9e2479af94a69a0e2c9bb9b59a54e4 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Tue, 24 Oct 2023 10:47:23 -0700 Subject: [PATCH 01/36] Install CUDA 12.2 on Windows (#18044) ### Description ### Motivation and Context --- .../c-api-noopenmp-packaging-pipelines.yml | 2 +- .../azure-pipelines/post-merge-jobs.yml | 2 +- .../templates/jobs/set-winenv.yml | 40 +++++++++++-------- .../azure-pipelines/win-gpu-ci-pipeline.yml | 6 +-- .../win-gpu-reduce-op-ci-pipeline.yml | 2 +- ...tup_env_cuda_11.bat => setup_env_cuda.bat} | 6 +++ 6 files changed, 36 insertions(+), 22 deletions(-) rename tools/ci_build/github/windows/{setup_env_cuda_11.bat => setup_env_cuda.bat} (53%) diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index fdd8c0933373..b4edf088f31b 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -193,7 +193,7 @@ stages: DoCompliance: ${{ parameters.DoCompliance }} DoEsrp: ${{ parameters.DoEsrp }} stage_name_suffix: gpu - EnvSetupScript: setup_env_cuda_11.bat + EnvSetupScript: setup_env_cuda.bat buildArch: x64 msbuildPlatform: x64 packageName: x64-cuda diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index d24b0e053963..2a94499c7a26 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -67,7 +67,7 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda_11.bat + EnvSetupScript: setup_env_cuda.bat buildArch: x64 additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 msbuildPlatform: x64 diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml index ca5a52fa61ed..0c8fb91a24a3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml @@ -1,19 +1,27 @@ parameters: -- name: EnvSetupScript - type: string - -- name: DownloadCUDA - type: boolean - default: false + - name: EnvSetupScript + type: string + - name: DownloadCUDA + type: boolean + default: false + - name: PrimaryCUDAVersion + type: string + default: '11.8' + - name: SecondaryCUDAVersion + type: string + default: '12.2' steps: -- ${{ if eq(parameters.DownloadCUDA, 'true') }}: - - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v11.8" $(Agent.TempDirectory) - -- task: BatchScript@1 - displayName: 'setup env' - inputs: - filename: '$(Build.SourcesDirectory)\tools\ci_build\github\windows\${{ parameters.EnvSetupScript }}' - modifyEnvironment: true - workingFolder: '$(Build.BinariesDirectory)' + - ${{ if eq(parameters.DownloadCUDA, 'true') }}: + - powershell: | + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.PrimaryCUDAVersion }}" $(Agent.TempDirectory) + displayName: 'Download Primary CUDA SDK v${{ parameters.PrimaryCUDAVersion }}' + - powershell: | + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.SecondaryCUDAVersion }}" $(Agent.TempDirectory) + displayName: 'Download Secondary CUDA SDK v${{ parameters.SecondaryCUDAVersion }}' + - task: BatchScript@1 + displayName: 'setup env' + inputs: + filename: '$(Build.SourcesDirectory)\tools\ci_build\github\windows\${{ parameters.EnvSetupScript }}' + modifyEnvironment: true + workingFolder: '$(Build.BinariesDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml index 07b5388ea5cd..ae2a4b4cead3 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -40,7 +40,7 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda_11.bat + EnvSetupScript: setup_env_cuda.bat buildArch: x64 additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 msbuildPlatform: x64 @@ -57,7 +57,7 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda_11.bat + EnvSetupScript: setup_env_cuda.bat buildArch: x64 additionalBuildFlags: --enable_pybind --enable_training --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --skip_onnx_tests --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 msbuildPlatform: x64 @@ -76,7 +76,7 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda_11.bat + EnvSetupScript: setup_env_cuda.bat buildArch: x64 # note: need to specify `--gen_doc` when creating the build config so it has to be in additionalBuildFlags additionalBuildFlags: --gen_doc validate --skip_tests --enable_pybind --use_dml --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-reduce-op-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-reduce-op-ci-pipeline.yml index b5db8a520140..d0f9772da7ad 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-reduce-op-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-reduce-op-ci-pipeline.yml @@ -10,7 +10,7 @@ jobs: BuildConfig: 'MinSizeRel' variables: MsbuildArguments: '-detailedsummary -maxcpucount -consoleloggerparameters:PerformanceSummary' - EnvSetupScript: setup_env_cuda_11.bat + EnvSetupScript: setup_env_cuda.bat buildArch: x64 TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] timeoutInMinutes: 120 diff --git a/tools/ci_build/github/windows/setup_env_cuda_11.bat b/tools/ci_build/github/windows/setup_env_cuda.bat similarity index 53% rename from tools/ci_build/github/windows/setup_env_cuda_11.bat rename to tools/ci_build/github/windows/setup_env_cuda.bat index 1308e43a4f6d..96569cbe0f64 100644 --- a/tools/ci_build/github/windows/setup_env_cuda_11.bat +++ b/tools/ci_build/github/windows/setup_env_cuda.bat @@ -6,4 +6,10 @@ if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ { } else { set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64;%PATH% } +@REM The default version is still cuda v11.8, because set cuda v12.2 after it +if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ { + set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64 +} else { + set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64 +} set GRADLE_OPTS=-Dorg.gradle.daemon=false From abb329179adae0029ef492c251984fcfd78224c4 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Tue, 24 Oct 2023 10:50:12 -0700 Subject: [PATCH 02/36] Update win-wasm-ci.yml: increase the timeout value (#18023) --- tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml index 84c910ba5878..a5925d16564f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml @@ -20,7 +20,7 @@ parameters: default: false - name: TimeoutInMinutes - default: 180 + default: 240 - name: BuildJsep type: boolean From efa0cc2562c28e6376717b46ebc83dd29b68d348 Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Tue, 24 Oct 2023 10:58:54 -0700 Subject: [PATCH 03/36] implement isinf20 and isnan20 (#17874) --- docs/OperatorKernels.md | 6 +- include/onnxruntime/core/framework/float8.h | 5 +- .../providers/cpu/cpu_execution_provider.cc | 42 +++-- .../core/providers/cpu/tensor/isinf.cc | 101 ++++++++++- .../core/providers/cpu/tensor/isnan.cc | 81 ++++++++- .../test/providers/cpu/tensor/isinf_test.cc | 164 ++++++++++++------ .../test/providers/cpu/tensor/isnan_test.cc | 85 +++++++-- .../onnx_backend_test_series_filters.jsonc | 6 - 8 files changed, 389 insertions(+), 101 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index dea71d81f8df..ba610515ac28 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -156,8 +156,10 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)| -|IsInf|*in* X:**T1**
*out* Y:**T2**|10+|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| -|IsNaN|*in* X:**T1**
*out* Y:**T2**|13+|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| +|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float)| |||[1, 12]|**T** = tensor(float)| diff --git a/include/onnxruntime/core/framework/float8.h b/include/onnxruntime/core/framework/float8.h index 0fd04f28d44b..dd607cbbc695 100644 --- a/include/onnxruntime/core/framework/float8.h +++ b/include/onnxruntime/core/framework/float8.h @@ -208,9 +208,10 @@ struct Float8E4M3FNUZ { val = static_cast((b & 0x80000000) >> 24); // sign if ((b & 0x7fffffff) == 0x7f800000) { // infinity if (saturate) { + // the highest available value val |= 0x7F; } else { - // infinity + // NaN val = 0x80; } } else if ((b & 0x7F800000) == 0x7F800000) { // NaN @@ -362,8 +363,10 @@ struct Float8E5M2 { val = (b & 0x80000000) >> 24; // sign if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf if (saturate) { + // the highest available value val |= 0x7B; } else { + // the infinity val |= 0x7C; } } else if ((b & 0x7F800000) == 0x7F800000) { // NaN diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 3d03abf5b7eb..a54d999a100b 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -365,7 +365,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, Slice); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 11, Dropout); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 19, IsInf); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ReverseSequence); @@ -682,9 +682,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Ga class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterND); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Identity); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, IsNaN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, IsNaN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero); @@ -960,6 +960,16 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh // Opset 20 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN); +#endif +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf); // !!PLEASE READ BELOW!! Following that, add new entries above this comment @@ -1492,7 +1502,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Dropout)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#if !defined(DISABLE_FLOAT8_TYPES) + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc index bc99caa8036c..1b449f46927a 100644 --- a/onnxruntime/core/providers/cpu/tensor/isinf.cc +++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc @@ -14,15 +14,38 @@ namespace onnxruntime { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf namespace op_kernel_type_control { -ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS( - kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0, - float, double); +using IsInfTypesOpset10 = TypeList; + +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, IsInf, 10, Input, 0, + IsInfTypesOpset10); + +using IsInfTypesOpset20 = + TypeList< + float, + double +#if !defined(DISABLE_FLOAT8_TYPES) + , + Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ +#endif + >; + +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, + kOnnxDomain, + IsInf, + 20, + Input, + 0, + IsInfTypesOpset20); } // namespace op_kernel_type_control class IsInf final : public OpKernel { public: - using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, - IsInf, Input, 0); + using EnabledDataTypes10 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, + IsInf, 10, Input, 0); + using EnabledDataTypes20 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, + IsInf, 20, Input, 0); explicit IsInf(const OpKernelInfo& info); Status Compute(OpKernelContext* context) const override; @@ -30,14 +53,25 @@ class IsInf final : public OpKernel { private: int64_t detect_positive_{1}; int64_t detect_negative_{1}; + int opset_; }; -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( IsInf, 10, + 19, KernelDefBuilder() .TypeConstraint("T1", - BuildKernelDefConstraintsFromTypeList()) + BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsInf); + +ONNX_CPU_OPERATOR_KERNEL( + IsInf, + 20, + KernelDefBuilder() + .TypeConstraint("T1", + BuildKernelDefConstraintsFromTypeList()) .TypeConstraint("T2", DataTypeImpl::GetTensorType()), IsInf); @@ -46,6 +80,7 @@ IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) { ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive"); status = info.GetAttr("detect_negative", &detect_negative_); ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative"); + opset_ = info.node().SinceVersion(); } namespace isinf_internal { @@ -78,6 +113,49 @@ struct ComputeDispatchTarget { } } }; + +#if !defined(DISABLE_FLOAT8_TYPES) +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor&, Tensor& Y, bool, bool) const { + EigenMap(Y).array() = false; + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor&, Tensor& Y, bool, bool) const { + EigenMap(Y).array() = false; + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { + auto& dims = X.Shape(); + auto input = ConstEigenVectorMap(static_cast(static_cast(X.Data())), onnxruntime::narrow(dims.Size())); + auto output = EigenMap(Y); + + // S.11111.00 + if (detect_positive && detect_negative) { + output.array() = input.array() == 0b01111100 || input.array() == 0b11111100; + } else if (detect_positive) { + output.array() = input.array() == 0b01111100; + } else if (detect_negative) { + output.array() = input.array() == 0b11111100; + } else { + output.array() = false; + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor&, Tensor& Y, bool, bool) const { + EigenMap(Y).array() = false; + } +}; +#endif } // namespace isinf_internal Status IsInf::Compute(OpKernelContext* context) const { @@ -88,8 +166,13 @@ Status IsInf::Compute(OpKernelContext* context) const { using namespace isinf_internal; - utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()}; - dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0); + if (opset_ < 20) { + utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()}; + dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0); + } else { + utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()}; + dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/tensor/isnan.cc b/onnxruntime/core/providers/cpu/tensor/isnan.cc index 33d0f8eb6c1a..34495e382278 100644 --- a/onnxruntime/core/providers/cpu/tensor/isnan.cc +++ b/onnxruntime/core/providers/cpu/tensor/isnan.cc @@ -20,10 +20,20 @@ namespace onnxruntime { .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ IsNaN); +#define ADD_TYPED_ISNAN_OP_13(data_type) \ + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ + IsNaN, \ + 13, 19, \ + data_type, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + IsNaN); + #define ADD_TYPED_ISNAN_OP(data_type) \ ONNX_CPU_OPERATOR_TYPED_KERNEL( \ IsNaN, \ - 13, \ + 20, \ data_type, \ KernelDefBuilder() \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ @@ -33,10 +43,20 @@ namespace onnxruntime { ADD_TYPED_ISNAN_OP_9(float); ADD_TYPED_ISNAN_OP_9(double); ADD_TYPED_ISNAN_OP_9(MLFloat16); +ADD_TYPED_ISNAN_OP_13(float); +ADD_TYPED_ISNAN_OP_13(double); +ADD_TYPED_ISNAN_OP_13(MLFloat16); ADD_TYPED_ISNAN_OP(float); ADD_TYPED_ISNAN_OP(double); ADD_TYPED_ISNAN_OP(MLFloat16); +#if !defined(DISABLE_FLOAT8_TYPES) +ADD_TYPED_ISNAN_OP(Float8E4M3FN); +ADD_TYPED_ISNAN_OP(Float8E4M3FNUZ); +ADD_TYPED_ISNAN_OP(Float8E5M2); +ADD_TYPED_ISNAN_OP(Float8E5M2FNUZ); +#endif + template Status IsNaN::Compute(OpKernelContext* context) const { const auto* X_ptr = context->Input(0); @@ -70,4 +90,63 @@ Status IsNaN::Compute(OpKernelContext* context) const { return Status::OK(); } + +#if !defined(DISABLE_FLOAT8_TYPES) +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto& dims = X->Shape(); + auto& Y = *context->Output(0, dims); + + auto input = ConstEigenVectorMap(static_cast(static_cast(X->Data())), onnxruntime::narrow(dims.Size())); + auto output = EigenMap(Y); + + // S.1111.111 + std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return (c & 0x7f) == 0x7f; }); + return Status::OK(); +} + +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto X_data = X->Data(); + auto& dims = X->Shape(); + auto shape_size = dims.Size(); + auto& Y = *context->Output(0, dims); + + // 1.0000.000 + EigenMap(Y) = + ConstEigenVectorMap(static_cast(static_cast(X_data)), onnxruntime::narrow(shape_size)).array() == 0x80; + + return Status::OK(); +} + +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto& dims = X->Shape(); + auto& Y = *context->Output(0, dims); + + auto input = ConstEigenVectorMap(static_cast(static_cast(X->Data())), onnxruntime::narrow(dims.Size())); + auto output = EigenMap(Y); + + // S.11111.{01, 10, 11} + std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00); }); + return Status::OK(); +} + +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto X_data = X->Data(); + auto& dims = X->Shape(); + auto shape_size = dims.Size(); + auto& Y = *context->Output(0, dims); + + // 1.0000.000 + EigenMap(Y) = ConstEigenVectorMap(static_cast(static_cast(X_data)), onnxruntime::narrow(shape_size)).array() == 0x80; + + return Status::OK(); +} +#endif } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc index ddb392eb82e1..2e583c5d2547 100644 --- a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc @@ -17,85 +17,137 @@ constexpr double DOUBLE_INF = std::numeric_limits::infinity(); constexpr double DOUBLE_NINF = -std::numeric_limits::infinity(); constexpr double DOUBLE_NAN = std::numeric_limits::quiet_NaN(); -TEST(IsInfTest, test_isinf_float) { - // Defaults for detect_negative = 1 - // detect_positive = 1 - OpTester test("IsInf", 10); +template +void run_is_inf_test(int opset, int64_t detect_positive, int64_t detect_negative, const std::initializer_list& input, const std::initializer_list& output) { + OpTester test("IsInf", opset); + test.AddAttribute("detect_positive", detect_positive); + test.AddAttribute("detect_negative", detect_negative); + test.AddInput("X", {onnxruntime::narrow(input.size())}, input); + test.AddOutput("Y", {onnxruntime::narrow(output.size())}, output); + test.Run(); +} - std::vector input_dim{6}; - std::vector input = {-1.2f, FLOAT_NAN, FLOAT_INF, 2.8f, FLOAT_NINF, FLOAT_INF}; - test.AddInput("X", input_dim, input); +TEST(IsInfTest, test_isinf_float10) { + std::initializer_list input = {-1.2f, FLOAT_NAN, FLOAT_INF, 2.8f, FLOAT_NINF, FLOAT_INF}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(10, 1, 1, input, output); +} - std::vector output_dim(input_dim); - test.AddOutput("Y", output_dim, {false, false, true, false, true, true}); - test.Run(); +TEST(IsInfTest, test_isinf_float20) { + std::initializer_list input = {-1.2f, FLOAT_NAN, FLOAT_INF, 2.8f, FLOAT_NINF, FLOAT_INF}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(20, 1, 1, input, output); } -TEST(IsInfTest, test_isinf_double) { - // Defaults for detect_negative = 1 - // detect_positive = 1 - OpTester test("IsInf", 10); +TEST(IsInfTest, test_isinf_double10) { + std::initializer_list input = {-1.2, DOUBLE_NAN, DOUBLE_INF, 2.8, DOUBLE_NINF, DOUBLE_INF}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(10, 1, 1, input, output); +} - std::vector input_dim{6}; - std::vector input = {-1.2, DOUBLE_NAN, DOUBLE_INF, 2.8, DOUBLE_NINF, DOUBLE_INF}; - test.AddInput("X", input_dim, input); +TEST(IsInfTest, test_isinf_double20) { + std::initializer_list input = {-1.2, DOUBLE_NAN, DOUBLE_INF, 2.8, DOUBLE_NINF, DOUBLE_INF}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(20, 1, 1, input, output); +} - std::vector output_dim(input_dim); - test.AddOutput("Y", output_dim, {false, false, true, false, true, true}); - test.Run(); +TEST(IsInfTest, test_isinf_positive_float10) { + std::initializer_list input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(10, 1, 0, input, output); } -TEST(IsInfTest, test_isinf_positive_float) { - OpTester test("IsInf", 10); - test.AddAttribute("detect_negative", 0); +TEST(IsInfTest, test_isinf_positive_float20) { + std::initializer_list input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} - std::vector input_dim{6}; - std::vector input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF}; - test.AddInput("X", input_dim, input); +TEST(IsInfTest, test_isinf_positive_double10) { + std::initializer_list input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(10, 1, 0, input, output); +} - std::vector output_dim(input_dim); - test.AddOutput("Y", output_dim, {false, false, true, false, false, true}); - test.Run(); +TEST(IsInfTest, test_isinf_positive_double20) { + std::initializer_list input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} + +TEST(IsInfTest, test_isinf_negative_float10) { + std::initializer_list input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(10, 0, 1, input, output); } -TEST(IsInfTest, test_isinf_positive_double) { - OpTester test("IsInf", 10); - test.AddAttribute("detect_negative", 0); +TEST(IsInfTest, test_isinf_negative_float20) { + std::initializer_list input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(20, 0, 1, input, output); +} - std::vector input_dim{6}; - std::vector input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF}; - test.AddInput("X", input_dim, input); +TEST(IsInfTest, test_isinf_negative_double10) { + std::initializer_list input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(10, 0, 1, input, output); +} - std::vector output_dim(input_dim); - test.AddOutput("Y", output_dim, {false, false, true, false, false, true}); - test.Run(); +TEST(IsInfTest, test_isinf_negative_double20) { + std::initializer_list input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(20, 0, 1, input, output); } -TEST(IsInfTest, test_isinf_negative_float) { - OpTester test("IsInf", 10); - test.AddAttribute("detect_positive", 0); +#if !defined(DISABLE_FLOAT8_TYPES) +TEST(IsInfTest, test_Float8E4M3FN) { + std::initializer_list input = { + Float8E4M3FN(-1.0f), Float8E4M3FN(FLOAT_NAN, false), Float8E4M3FN(1.0f), Float8E4M3FN(FLOAT_NINF, false), Float8E4M3FN(FLOAT_NINF, false), Float8E4M3FN(FLOAT_INF, false)}; + std::initializer_list output = {false, false, false, false, false, false}; + run_is_inf_test(20, 1, 1, input, output); +} - std::vector input_dim{6}; - std::vector input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF}; - test.AddInput("X", input_dim, input); +TEST(IsInfTest, test_Float8E4M3FNUZ) { + std::initializer_list input = { + Float8E4M3FNUZ(-1.0f), Float8E4M3FNUZ(FLOAT_NAN, false), Float8E4M3FNUZ(1.0f), Float8E4M3FNUZ(FLOAT_NINF, false), Float8E4M3FNUZ(FLOAT_NINF, false), Float8E4M3FNUZ(FLOAT_INF, false)}; + std::initializer_list output = {false, false, false, false, false, false}; + run_is_inf_test(20, 1, 1, input, output); +} - std::vector output_dim(input_dim); - test.AddOutput("Y", output_dim, {false, false, false, false, true, false}); - test.Run(); +TEST(IsInfTest, test_Float8E5M2_detect_both) { + std::initializer_list input = { + Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)}; + std::initializer_list output = {false, true, false, true, false, true}; + run_is_inf_test(20, 1, 1, input, output); } -TEST(IsInfTest, test_isinf_negative_double) { - OpTester test("IsInf", 10); - test.AddAttribute("detect_positive", 0); +TEST(IsInfTest, test_Float8E5M2_detect_positive) { + std::initializer_list input = { + Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)}; + std::initializer_list output = {false, false, false, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} - std::vector input_dim{6}; - std::vector input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF}; - test.AddInput("X", input_dim, input); +TEST(IsInfTest, test_Float8E5M2_detect_negative) { + std::initializer_list input = { + Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)}; + std::initializer_list output = {false, true, false, true, false, false}; + run_is_inf_test(20, 0, 1, input, output); +} - std::vector output_dim(input_dim); - test.AddOutput("Y", output_dim, {false, false, false, false, true, false}); - test.Run(); +TEST(IsInfTest, test_Float8E5M2_none) { + std::initializer_list input = { + Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)}; + std::initializer_list output = {false, false, false, false, false, false}; + run_is_inf_test(20, 0, 0, input, output); } +TEST(IsInfTest, test_Float8E5M2FNUZ) { + std::initializer_list input = { + Float8E5M2FNUZ(-1.0f), Float8E5M2FNUZ(FLOAT_NINF, false), Float8E5M2FNUZ(1.0f), Float8E5M2FNUZ(FLOAT_NINF, false), Float8E5M2FNUZ(FLOAT_NAN, false), Float8E5M2FNUZ(FLOAT_INF, false)}; + std::initializer_list output = {false, false, false, false, false, false}; + run_is_inf_test(20, 1, 1, input, output); +} +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc index 0dffc452b519..0f1e5c07cdd9 100644 --- a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc @@ -9,29 +9,84 @@ namespace onnxruntime { namespace test { -TEST(IsNaNOpTest, IsNaNFloat) { - OpTester test("IsNaN", 9, kOnnxDomain); - std::vector dims{2, 2}; - test.AddInput("X", dims, {1.0f, NAN, 2.0f, NAN}); - test.AddOutput("Y", dims, {false, true, false, true}); +template +void run_is_nan_test(int opset, const std::vector& dims, const std::initializer_list& input, const std::initializer_list& output) { + OpTester test("IsNaN", opset, kOnnxDomain); + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); test.Run(); } -TEST(IsNaNOpTest, IsNaNFloat16) { - OpTester test("IsNaN", 9, kOnnxDomain); +TEST(IsNaNOpTest, IsNaNFloat9) { std::vector dims{2, 2}; - test.AddInput("X", dims, std::initializer_list({MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN})); - test.AddOutput("Y", dims, {false, true, false, true}); - test.Run(); + std::initializer_list input = {1.0f, NAN, 2.0f, NAN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(9, dims, input, output); } -TEST(IsNaNOpTest, IsNaNDouble) { - OpTester test("IsNaN", 9, kOnnxDomain); +TEST(IsNaNOpTest, IsNaNFloat20) { std::vector dims{2, 2}; - test.AddInput("X", dims, {1.0, NAN, 2.0, NAN}); - test.AddOutput("Y", dims, {false, true, false, true}); - test.Run(); + std::initializer_list input = {1.0f, NAN, 2.0f, NAN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaNFloat16_9) { + std::vector dims{2, 2}; + std::initializer_list input = {MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(9, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaNFloat16_20) { + std::vector dims{2, 2}; + std::initializer_list input = {MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaNDouble9) { + std::vector dims{2, 2}; + std::initializer_list input = {1.0, NAN, 2.0, NAN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(9, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaNDouble20) { + std::vector dims{2, 2}; + std::initializer_list input = {1.0, NAN, 2.0, NAN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); } +#if !defined(DISABLE_FLOAT8_TYPES) +TEST(IsNaNOpTest, IsNaNFloat8E4M3FN) { + std::vector dims{2, 2}; + std::initializer_list input = {Float8E4M3FN(1.0f), Float8E4M3FN(-NAN), Float8E4M3FN(2.0f), Float8E4M3FN(NAN)}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaN_Float8E4M3FNUZ) { + std::vector dims{2, 2}; + std::initializer_list input = {Float8E4M3FNUZ(1.0f), Float8E4M3FNUZ(-NAN), Float8E4M3FNUZ(2.0f), Float8E4M3FNUZ(-NAN)}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaNFloat8E5M2) { + std::vector dims{2, 2}; + std::initializer_list input = {Float8E5M2(1.0f), Float8E5M2(-NAN), Float8E5M2(2.0f), Float8E5M2(NAN)}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaN_Float8E5M2FNUZ) { + std::vector dims{2, 2}; + std::initializer_list input = {Float8E5M2FNUZ(1.0f), Float8E5M2FNUZ(-NAN), Float8E5M2FNUZ(2.0f), Float8E5M2FNUZ(NAN)}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index b3161a42bb3e..44db7c0078cf 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -283,12 +283,6 @@ "^test_dft_axis", "^test_dft", "^test_dft_inverse", - "^test_isinf", - "^test_isinf_float16", - "^test_isinf_negative", - "^test_isinf_positive", - "^test_isnan", - "^test_isnan_float16", "^test_reduce_max_bool_inputs", "^test_reduce_min_bool_inputs", "^test_reduce_min_empty_set", From 6ec45f2ba590fabad99159a44fd6e48a5a9b03f0 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Tue, 24 Oct 2023 13:04:08 -0700 Subject: [PATCH 04/36] Merge aiinfra-linux-ARM64-CPU-2019 and onnxruntime-linux-ARM64-CPU-2019 (#18069) ### Description Merge aiinfra-linux-ARM64-CPU-2019 and onnxruntime-linux-ARM64-CPU-2019 machines to a single one to ease management. --- .../github/azure-pipelines/py-package-test-pipeline.yml | 2 +- .../azure-pipelines/templates/linux-cpu-packaging-pipeline.yml | 2 +- .../github/azure-pipelines/templates/py-packaging-stage.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index 2161a9205f22..c8aac6e8b130 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -30,7 +30,7 @@ stages: - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'aarch64' - machine_pool: 'aiinfra-linux-ARM64-CPU-2019' + machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' base_image: 'arm64v8/almalinux:8' devtoolset_rootpath: /opt/rh/gcc-toolset-12/root ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml index 51d3a9ebc218..1cc5c48c5513 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml @@ -47,7 +47,7 @@ stages: OnnxruntimeCFlags: '-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -O3 -Wl,--strip-all' OnnxruntimeCXXFlags: '-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -O3 -Wl,--strip-all' OnnxruntimeNodejsBindingArch: 'arm64' - PoolName: 'aiinfra-linux-ARM64-CPU-2019' + PoolName: 'onnxruntime-linux-ARM64-CPU-2019' ArtifactNamePrefix: ${{ parameters.ArtifactNamePrefix }} PackageJava: ${{ parameters.PackageJava }} PackageNodeJS: ${{ parameters.PackageNodeJS }} diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 1e28ad08a5bd..1a67ace5e85f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -484,7 +484,7 @@ stages: - template: py-linux.yml parameters: arch: 'aarch64' - machine_pool: 'aiinfra-linux-ARM64-CPU-2019' + machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' base_image: 'arm64v8/almalinux:8' devtoolset_rootpath: /opt/rh/gcc-toolset-12/root ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 From 76e275baf44d5bd882fd298d3b86d824eb113435 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Tue, 24 Oct 2023 15:17:36 -0700 Subject: [PATCH 05/36] Merge Cuda docker files into a single one (#18020) ### Description ### Motivation and Context --- setup.py | 8 +- .../c-api-noopenmp-packaging-pipelines.yml | 6 +- .../azure-pipelines/linux-gpu-ci-pipeline.yml | 20 +- .../linux-gpu-tensorrt-ci-pipeline.yml | 10 +- .../linux-gpu-tensorrt-packaging-pipeline.yml | 10 +- .../templates/py-linux-gpu.yml | 11 +- .../py-packaging-linux-test-cuda.yml | 11 +- .../docker/Dockerfile.manylinux2_28_cuda | 51 +++-- .../docker/Dockerfile.manylinux2_28_cuda11 | 166 ---------------- ...kerfile.manylinux2_28_cuda11_6_tensorrt8_4 | 173 ----------------- ...kerfile.manylinux2_28_cuda11_6_tensorrt8_5 | 173 ----------------- ...kerfile.manylinux2_28_cuda11_8_tensorrt8_6 | 181 ------------------ 12 files changed, 90 insertions(+), 730 deletions(-) delete mode 100644 tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 delete mode 100644 tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 delete mode 100644 tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 delete mode 100644 tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 diff --git a/setup.py b/setup.py index 2eb8f212d730..b71836e0ee6e 100644 --- a/setup.py +++ b/setup.py @@ -192,11 +192,15 @@ def run(self): cuda_dependencies = [ "libcublas.so.11", + "libcublas.so.12", "libcublasLt.so.11", - "libcudnn.so.8", + "libcublasLt.so.12", "libcudart.so.11.0", - "libcurand.so.10", + "libcudart.so.12.0", + "libcudnn.so.8", "libcufft.so.10", + "libcufft.so.11", + "libcurand.so.10", ] rocm_dependencies = [ "librccl.so.1", diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index b4edf088f31b..129dbc833a0a 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -488,13 +488,13 @@ stages: Steps: - script: | tools/ci_build/get_docker_image.py \ - --dockerfile tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 \ + --dockerfile tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda \ --context tools/ci_build/github/linux/docker \ - --docker-build-args "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u ) --build-arg BUILD_UID=$( id -u )" \ + --docker-build-args "--network=host --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 --build-arg INSTALL_CUDNN=true --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 --build-arg BUILD_UID=$( id -u )" \ --container-registry onnxruntimebuildcache \ --multiple_repos \ --repository onnxruntimecuda118xtrt86build - displayName: "Get onnxruntimecuda118xtrt86build image for tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6" + displayName: "Get onnxruntimecuda118xtrt86build image for tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda" workingDirectory: $(Build.SourcesDirectory)/onnxruntime ContainerRegistry: onnxruntimebuildcache diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 71a580f348f6..1d4681d06438 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -58,9 +58,15 @@ jobs: - template: templates/get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u )" + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg INSTALL_CUDNN=true + --build-arg BUILD_UID=$( id -u ) + " Repository: onnxruntimecuda11build - task: Cache@2 @@ -154,9 +160,15 @@ jobs: - template: templates/get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u )" + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg INSTALL_CUDNN=true + --build-arg BUILD_UID=$( id -u ) + " Repository: onnxruntimecuda11build - task: CmdLine@2 diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml index 9450395f3cf7..16d4457c45eb 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml @@ -57,9 +57,15 @@ jobs: - template: templates/get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u )" + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg INSTALL_CUDNN=true + --build-arg BUILD_UID=$( id -u ) + " Repository: onnxruntimetensorrt86gpubuild - template: templates/linux-build-step-with-cache.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml index 445f739e81c4..0d58f6cee400 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml @@ -44,9 +44,15 @@ stages: submodules: recursive - template: get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u )" + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg INSTALL_CUDNN=true + --build-arg BUILD_UID=$( id -u ) + " Repository: onnxruntimecuda118xtrt86build - template: set-version-number-variables-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml index 3d5a71284fa6..33c82b5e8965 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml @@ -36,9 +36,16 @@ jobs: - template: get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }}" + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg INSTALL_CUDNN=true + --build-arg BUILD_UID=$( id -u ) + --build-arg PLATFORM=${{ parameters.arch }} + " Repository: onnxruntimecuda118xtrt86build${{ parameters.arch }} diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml index 43ed0172825b..a70e0c01e52f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -81,9 +81,16 @@ jobs: - template: get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }}" + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg INSTALL_CUDNN=true + --build-arg BUILD_UID=$( id -u ) + --build-arg PLATFORM=${{ parameters.arch }} + " Repository: onnxruntimecuda118xtrt86build${{ parameters.arch }} - task: Bash@3 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index 4d9c676674a0..7b2cada73648 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -5,11 +5,10 @@ ARG POLICY=manylinux_2_28 ARG PLATFORM=x86_64 ARG BASEIMAGE=nvidia/cuda:12.2.0-devel-ubi8 -ARG TRT_VERSION=8.6.1.6-1.cuda12.0 ARG DEVTOOLSET_ROOTPATH=/usr ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64 ARG PREPEND_PATH=/usr/local/cuda/binet - +ARG INSTALL_CUDNN=false #Build manylinux docker image begin FROM $BASEIMAGE AS runtime_base @@ -118,7 +117,7 @@ RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.10.5 FROM build_cpython AS build_cpython311 COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.0b5 +RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.2 FROM build_cpython AS all_python COPY build_scripts/install-pypy.sh \ @@ -155,23 +154,35 @@ CMD ["/bin/bash"] #Build manylinux docker image end -#Install TensorRT 8.6.1.6 -RUN CUDA_VERSION=$(nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p') &&\ - dnf -y install\ - libcudnn8-devel-*cuda${CUDA_VERSION}*\ - libcudnn8-*cuda${CUDA_VERSION}*\ - libnvinfer8-${TRT_VERSION}\ - libnvparsers8-${TRT_VERSION}\ - libnvonnxparsers8-${TRT_VERSION}\ - libnvinfer-plugin8-${TRT_VERSION}\ - libnvinfer-vc-plugin8-${TRT_VERSION}\ - libnvinfer-devel-${TRT_VERSION}\ - libnvparsers-devel-${TRT_VERSION}\ - libnvonnxparsers-devel-${TRT_VERSION}\ - libnvinfer-plugin-devel-${TRT_VERSION}\ - libnvinfer-vc-plugin-devel-${TRT_VERSION}\ - libnvinfer-headers-devel-${TRT_VERSION}\ - libnvinfer-headers-plugin-devel-${TRT_VERSION} + +#Install optinal Cudnn +RUN if [ "$INSTALL_CUDNN" = true ]; then \ + CUDA_VERSION=$(nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p') && \ + dnf -y install \ + libcudnn8-devel-*cuda${CUDA_VERSION}* \ + libcudnn8-*cuda${CUDA_VERSION}* ; \ +fi + +#Install TensorRT only if TRT_VERSION is not empty +RUN if [ -n "$TRT_VERSION" ]; then \ + echo "TRT_VERSION is $TRT_VERSION" && \ + dnf -y install \ + libnvinfer8-${TRT_VERSION} \ + libnvparsers8-${TRT_VERSION} \ + libnvonnxparsers8-${TRT_VERSION} \ + libnvinfer-plugin8-${TRT_VERSION} \ + libnvinfer-vc-plugin8-${TRT_VERSION} \ + libnvinfer-devel-${TRT_VERSION} \ + libnvparsers-devel-${TRT_VERSION} \ + libnvonnxparsers-devel-${TRT_VERSION} \ + libnvinfer-plugin-devel-${TRT_VERSION} \ + libnvinfer-vc-plugin-devel-${TRT_VERSION} \ + libnvinfer-headers-devel-${TRT_VERSION} \ + libnvinfer-headers-plugin-devel-${TRT_VERSION}; \ +else \ + echo "TRT_VERSION is none skipping Tensor RT Installation" ; \ +fi + ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 deleted file mode 100644 index 933b0211b0e6..000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 +++ /dev/null @@ -1,166 +0,0 @@ -ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 -ARG POLICY=manylinux_2_28 -ARG PLATFORM=x86_64 -ARG DEVTOOLSET_ROOTPATH= -ARG LD_LIBRARY_PATH_ARG= -ARG PREPEND_PATH= - -#We need both CUDA and manylinux. But the CUDA Toolkit End User License Agreement says NVIDIA CUDA Driver Libraries(libcuda.so, libnvidia-ptxjitcompiler.so) are only distributable in applications that meet this criteria: -#1. The application was developed starting from a NVIDIA CUDA container obtained from Docker Hub or the NVIDIA GPU Cloud, and -#2. The resulting application is packaged as a Docker container and distributed to users on Docker Hub or the NVIDIA GPU Cloud only. -#So we use CUDA as the base image then add manylinux on top of it. - -#Build manylinux2014 docker image begin -FROM $BASEIMAGE AS runtime_base -ARG POLICY -ARG PLATFORM -ARG DEVTOOLSET_ROOTPATH -ARG LD_LIBRARY_PATH_ARG -ARG PREPEND_PATH -LABEL maintainer="The ManyLinux project" - -ENV AUDITWHEEL_POLICY=${POLICY} AUDITWHEEL_ARCH=${PLATFORM} AUDITWHEEL_PLAT=${POLICY}_${PLATFORM} -ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LANGUAGE=en_US.UTF-8 -ENV DEVTOOLSET_ROOTPATH=${DEVTOOLSET_ROOTPATH} -ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG} -ENV PATH=${PREPEND_PATH}${PATH} -ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig - -# first copy the fixup mirrors script, keep the script around -COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors - -# setup entrypoint, this will wrap commands with `linux32` with i686 images -COPY build_scripts/install-entrypoint.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ - -RUN /build_scripts/install-entrypoint.sh && rm -rf /build_scripts -COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint -ENTRYPOINT ["manylinux-entrypoint"] - -COPY build_scripts/install-runtime-packages.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ - -COPY build_scripts/build_utils.sh /build_scripts/ - -COPY build_scripts/install-autoconf.sh /build_scripts/ -RUN export AUTOCONF_ROOT=autoconf-2.71 && \ - export AUTOCONF_HASH=431075ad0bf529ef13cb41e9042c542381103e80015686222b8a9d4abef42a1c && \ - export AUTOCONF_DOWNLOAD_URL=http://ftp.gnu.org/gnu/autoconf && \ - manylinux-entrypoint /build_scripts/install-autoconf.sh - -COPY build_scripts/install-automake.sh /build_scripts/ -RUN export AUTOMAKE_ROOT=automake-1.16.5 && \ - export AUTOMAKE_HASH=07bd24ad08a64bc17250ce09ec56e921d6343903943e99ccf63bbf0705e34605 && \ - export AUTOMAKE_DOWNLOAD_URL=http://ftp.gnu.org/gnu/automake && \ - manylinux-entrypoint /build_scripts/install-automake.sh - -COPY build_scripts/install-libtool.sh /build_scripts/ -RUN export LIBTOOL_ROOT=libtool-2.4.7 && \ - export LIBTOOL_HASH=04e96c2404ea70c590c546eba4202a4e12722c640016c12b9b2f1ce3d481e9a8 && \ - export LIBTOOL_DOWNLOAD_URL=http://ftp.gnu.org/gnu/libtool && \ - manylinux-entrypoint /build_scripts/install-libtool.sh - -COPY build_scripts/install-libxcrypt.sh /build_scripts/ -RUN export LIBXCRYPT_VERSION=4.4.28 && \ - export LIBXCRYPT_HASH=db7e37901969cb1d1e8020cb73a991ef81e48e31ea5b76a101862c806426b457 && \ - export LIBXCRYPT_DOWNLOAD_URL=https://github.com/besser82/libxcrypt/archive && \ - export PERL_ROOT=perl-5.34.0 && \ - export PERL_HASH=551efc818b968b05216024fb0b727ef2ad4c100f8cb6b43fab615fa78ae5be9a && \ - export PERL_DOWNLOAD_URL=https://www.cpan.org/src/5.0 && \ - manylinux-entrypoint /build_scripts/install-libxcrypt.sh - -FROM runtime_base AS build_base -COPY build_scripts/install-build-packages.sh /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-build-packages.sh - - -FROM build_base AS build_git -COPY build_scripts/build-git.sh /build_scripts/ -RUN export GIT_ROOT=git-2.36.2 && \ - export GIT_HASH=6dc2cdea5fb23d823ba4871cc23222c1db31dfbb6d6c6ff74c4128700df57c68 && \ - export GIT_DOWNLOAD_URL=https://www.kernel.org/pub/software/scm/git && \ - manylinux-entrypoint /build_scripts/build-git.sh - - -FROM build_base AS build_cpython -COPY build_scripts/build-sqlite3.sh /build_scripts/ -RUN export SQLITE_AUTOCONF_ROOT=sqlite-autoconf-3390200 && \ - export SQLITE_AUTOCONF_HASH=852be8a6183a17ba47cee0bbff7400b7aa5affd283bf3beefc34fcd088a239de && \ - export SQLITE_AUTOCONF_DOWNLOAD_URL=https://www.sqlite.org/2022 && \ - manylinux-entrypoint /build_scripts/build-sqlite3.sh - -COPY build_scripts/build-openssl.sh /build_scripts/ -RUN export OPENSSL_ROOT=openssl-1.1.1q && \ - export OPENSSL_HASH=d7939ce614029cdff0b6c20f0e2e5703158a489a72b2507b8bd51bf8c8fd10ca && \ - export OPENSSL_DOWNLOAD_URL=https://www.openssl.org/source && \ - manylinux-entrypoint /build_scripts/build-openssl.sh - -COPY build_scripts/build-cpython.sh /build_scripts/ - - -FROM build_cpython AS build_cpython38 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.8.13 - - -FROM build_cpython AS build_cpython39 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.9.13 - - -FROM build_cpython AS build_cpython310 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.10.5 - -FROM build_cpython AS build_cpython311 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.2 - -FROM build_cpython AS all_python -COPY build_scripts/install-pypy.sh \ - build_scripts/pypy.sha256 \ - build_scripts/finalize-python.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.8 7.3.9 -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.9 7.3.9 -COPY --from=build_cpython38 /opt/_internal /opt/_internal/ -COPY --from=build_cpython39 /opt/_internal /opt/_internal/ -COPY --from=build_cpython310 /opt/_internal /opt/_internal/ -COPY --from=build_cpython311 /opt/_internal /opt/_internal/ -RUN manylinux-entrypoint /build_scripts/finalize-python.sh - - -FROM runtime_base -COPY --from=build_git /manylinux-rootfs / -COPY --from=build_cpython /manylinux-rootfs / -COPY --from=all_python /opt/_internal /opt/_internal/ -COPY build_scripts/finalize.sh \ - build_scripts/python-tag-abi-tag.py \ - build_scripts/requirements3.8.txt \ - build_scripts/requirements3.9.txt \ - build_scripts/requirements3.10.txt \ - build_scripts/requirements3.11.txt \ - build_scripts/requirements-base-tools.txt \ - /build_scripts/ -COPY build_scripts/requirements-tools/* /build_scripts/requirements-tools/ -RUN manylinux-entrypoint /build_scripts/finalize.sh && rm -rf /build_scripts - -ENV SSL_CERT_FILE=/opt/_internal/certs.pem - -CMD ["/bin/bash"] - -#Build manylinux2014 docker image end -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 -#Add our own dependencies -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts - -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER -USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 deleted file mode 100644 index 003bb2324c04..000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 +++ /dev/null @@ -1,173 +0,0 @@ -ARG BASEIMAGE=nvidia/cuda:11.6.1-cudnn8-devel-centos7 -ARG POLICY=manylinux2014 -ARG PLATFORM=x86_64 -ARG DEVTOOLSET_ROOTPATH= -ARG LD_LIBRARY_PATH_ARG= -ARG PREPEND_PATH= - -#We need CUDA, TensorRT and manylinux. But the CUDA Toolkit End User License Agreement says NVIDIA CUDA Driver Libraries(libcuda.so, libnvidia-ptxjitcompiler.so) are only distributable in applications that meet this criteria: -#1. The application was developed starting from a NVIDIA CUDA container obtained from Docker Hub or the NVIDIA GPU Cloud, and -#2. The resulting application is packaged as a Docker container and distributed to users on Docker Hub or the NVIDIA GPU Cloud only. -#So we use CUDA as the base image then add manylinux and TensorRT on top of it. - -#Build manylinux2014 docker image begin -FROM $BASEIMAGE AS runtime_base -ARG POLICY -ARG PLATFORM -ARG DEVTOOLSET_ROOTPATH -ARG LD_LIBRARY_PATH_ARG -ARG PREPEND_PATH -LABEL maintainer="The ManyLinux project" - -ENV AUDITWHEEL_POLICY=${POLICY} AUDITWHEEL_ARCH=${PLATFORM} AUDITWHEEL_PLAT=${POLICY}_${PLATFORM} -ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LANGUAGE=en_US.UTF-8 -ENV DEVTOOLSET_ROOTPATH=${DEVTOOLSET_ROOTPATH} -ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG} -ENV PATH=${PREPEND_PATH}${PATH} -ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig - -# first copy the fixup mirrors script, keep the script around -COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors - -# setup entrypoint, this will wrap commands with `linux32` with i686 images -COPY build_scripts/install-entrypoint.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ - -RUN /build_scripts/install-entrypoint.sh && rm -rf /build_scripts -COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint -ENTRYPOINT ["manylinux-entrypoint"] - -COPY build_scripts/install-runtime-packages.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ - -COPY build_scripts/build_utils.sh /build_scripts/ - -COPY build_scripts/install-autoconf.sh /build_scripts/ -RUN export AUTOCONF_ROOT=autoconf-2.71 && \ - export AUTOCONF_HASH=431075ad0bf529ef13cb41e9042c542381103e80015686222b8a9d4abef42a1c && \ - export AUTOCONF_DOWNLOAD_URL=http://ftp.gnu.org/gnu/autoconf && \ - manylinux-entrypoint /build_scripts/install-autoconf.sh - -COPY build_scripts/install-automake.sh /build_scripts/ -RUN export AUTOMAKE_ROOT=automake-1.16.5 && \ - export AUTOMAKE_HASH=07bd24ad08a64bc17250ce09ec56e921d6343903943e99ccf63bbf0705e34605 && \ - export AUTOMAKE_DOWNLOAD_URL=http://ftp.gnu.org/gnu/automake && \ - manylinux-entrypoint /build_scripts/install-automake.sh - -COPY build_scripts/install-libtool.sh /build_scripts/ -RUN export LIBTOOL_ROOT=libtool-2.4.7 && \ - export LIBTOOL_HASH=04e96c2404ea70c590c546eba4202a4e12722c640016c12b9b2f1ce3d481e9a8 && \ - export LIBTOOL_DOWNLOAD_URL=http://ftp.gnu.org/gnu/libtool && \ - manylinux-entrypoint /build_scripts/install-libtool.sh - -COPY build_scripts/install-libxcrypt.sh /build_scripts/ -RUN export LIBXCRYPT_VERSION=4.4.28 && \ - export LIBXCRYPT_HASH=db7e37901969cb1d1e8020cb73a991ef81e48e31ea5b76a101862c806426b457 && \ - export LIBXCRYPT_DOWNLOAD_URL=https://github.com/besser82/libxcrypt/archive && \ - export PERL_ROOT=perl-5.34.0 && \ - export PERL_HASH=551efc818b968b05216024fb0b727ef2ad4c100f8cb6b43fab615fa78ae5be9a && \ - export PERL_DOWNLOAD_URL=https://www.cpan.org/src/5.0 && \ - manylinux-entrypoint /build_scripts/install-libxcrypt.sh - -FROM runtime_base AS build_base -COPY build_scripts/install-build-packages.sh /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-build-packages.sh - - -FROM build_base AS build_git -COPY build_scripts/build-git.sh /build_scripts/ -RUN export GIT_ROOT=git-2.36.2 && \ - export GIT_HASH=6dc2cdea5fb23d823ba4871cc23222c1db31dfbb6d6c6ff74c4128700df57c68 && \ - export GIT_DOWNLOAD_URL=https://www.kernel.org/pub/software/scm/git && \ - manylinux-entrypoint /build_scripts/build-git.sh - - -FROM build_base AS build_cpython -COPY build_scripts/build-sqlite3.sh /build_scripts/ -RUN export SQLITE_AUTOCONF_ROOT=sqlite-autoconf-3390200 && \ - export SQLITE_AUTOCONF_HASH=852be8a6183a17ba47cee0bbff7400b7aa5affd283bf3beefc34fcd088a239de && \ - export SQLITE_AUTOCONF_DOWNLOAD_URL=https://www.sqlite.org/2022 && \ - manylinux-entrypoint /build_scripts/build-sqlite3.sh - -COPY build_scripts/build-openssl.sh /build_scripts/ -RUN export OPENSSL_ROOT=openssl-1.1.1q && \ - export OPENSSL_HASH=d7939ce614029cdff0b6c20f0e2e5703158a489a72b2507b8bd51bf8c8fd10ca && \ - export OPENSSL_DOWNLOAD_URL=https://www.openssl.org/source && \ - manylinux-entrypoint /build_scripts/build-openssl.sh - -COPY build_scripts/build-cpython.sh /build_scripts/ - - -FROM build_cpython AS build_cpython38 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.8.13 - - -FROM build_cpython AS build_cpython39 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.9.13 - - -FROM build_cpython AS build_cpython310 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.10.5 - -FROM build_cpython AS build_cpython311 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.2 - -FROM build_cpython AS all_python -COPY build_scripts/install-pypy.sh \ - build_scripts/pypy.sha256 \ - build_scripts/finalize-python.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.8 7.3.9 -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.9 7.3.9 -COPY --from=build_cpython38 /opt/_internal /opt/_internal/ -COPY --from=build_cpython39 /opt/_internal /opt/_internal/ -COPY --from=build_cpython310 /opt/_internal /opt/_internal/ -COPY --from=build_cpython311 /opt/_internal /opt/_internal/ -RUN manylinux-entrypoint /build_scripts/finalize-python.sh - - -FROM runtime_base -COPY --from=build_git /manylinux-rootfs / -COPY --from=build_cpython /manylinux-rootfs / -COPY --from=all_python /opt/_internal /opt/_internal/ -COPY build_scripts/finalize.sh \ - build_scripts/python-tag-abi-tag.py \ - build_scripts/requirements3.8.txt \ - build_scripts/requirements3.9.txt \ - build_scripts/requirements3.10.txt \ - build_scripts/requirements3.11.txt \ - build_scripts/requirements-base-tools.txt \ - /build_scripts/ -COPY build_scripts/requirements-tools/* /build_scripts/requirements-tools/ -RUN manylinux-entrypoint /build_scripts/finalize.sh && rm -rf /build_scripts - -ENV SSL_CERT_FILE=/opt/_internal/certs.pem - -CMD ["/bin/bash"] - -#Build manylinux2014 docker image end - -#Install TensorRT 8.4.1.5 -#RUN yum install -y wget -RUN v="8.4.1-1.cuda11.6" &&\ - yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo &&\ - yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} \ - libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 -#Add our own dependencies -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts - -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER -USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 deleted file mode 100644 index 0337ffc5e00a..000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 +++ /dev/null @@ -1,173 +0,0 @@ -ARG BASEIMAGE=nvidia/cuda:11.6.1-cudnn8-devel-centos7 -ARG POLICY=manylinux2014 -ARG PLATFORM=x86_64 -ARG DEVTOOLSET_ROOTPATH= -ARG LD_LIBRARY_PATH_ARG= -ARG PREPEND_PATH= - -#We need CUDA, TensorRT and manylinux. But the CUDA Toolkit End User License Agreement says NVIDIA CUDA Driver Libraries(libcuda.so, libnvidia-ptxjitcompiler.so) are only distributable in applications that meet this criteria: -#1. The application was developed starting from a NVIDIA CUDA container obtained from Docker Hub or the NVIDIA GPU Cloud, and -#2. The resulting application is packaged as a Docker container and distributed to users on Docker Hub or the NVIDIA GPU Cloud only. -#So we use CUDA as the base image then add manylinux and TensorRT on top of it. - -#Build manylinux2014 docker image begin -FROM $BASEIMAGE AS runtime_base -ARG POLICY -ARG PLATFORM -ARG DEVTOOLSET_ROOTPATH -ARG LD_LIBRARY_PATH_ARG -ARG PREPEND_PATH -LABEL maintainer="The ManyLinux project" - -ENV AUDITWHEEL_POLICY=${POLICY} AUDITWHEEL_ARCH=${PLATFORM} AUDITWHEEL_PLAT=${POLICY}_${PLATFORM} -ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LANGUAGE=en_US.UTF-8 -ENV DEVTOOLSET_ROOTPATH=${DEVTOOLSET_ROOTPATH} -ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG} -ENV PATH=${PREPEND_PATH}${PATH} -ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig - -# first copy the fixup mirrors script, keep the script around -COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors - -# setup entrypoint, this will wrap commands with `linux32` with i686 images -COPY build_scripts/install-entrypoint.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ - -RUN /build_scripts/install-entrypoint.sh && rm -rf /build_scripts -COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint -ENTRYPOINT ["manylinux-entrypoint"] - -COPY build_scripts/install-runtime-packages.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ - -COPY build_scripts/build_utils.sh /build_scripts/ - -COPY build_scripts/install-autoconf.sh /build_scripts/ -RUN export AUTOCONF_ROOT=autoconf-2.71 && \ - export AUTOCONF_HASH=431075ad0bf529ef13cb41e9042c542381103e80015686222b8a9d4abef42a1c && \ - export AUTOCONF_DOWNLOAD_URL=http://ftp.gnu.org/gnu/autoconf && \ - manylinux-entrypoint /build_scripts/install-autoconf.sh - -COPY build_scripts/install-automake.sh /build_scripts/ -RUN export AUTOMAKE_ROOT=automake-1.16.5 && \ - export AUTOMAKE_HASH=07bd24ad08a64bc17250ce09ec56e921d6343903943e99ccf63bbf0705e34605 && \ - export AUTOMAKE_DOWNLOAD_URL=http://ftp.gnu.org/gnu/automake && \ - manylinux-entrypoint /build_scripts/install-automake.sh - -COPY build_scripts/install-libtool.sh /build_scripts/ -RUN export LIBTOOL_ROOT=libtool-2.4.7 && \ - export LIBTOOL_HASH=04e96c2404ea70c590c546eba4202a4e12722c640016c12b9b2f1ce3d481e9a8 && \ - export LIBTOOL_DOWNLOAD_URL=http://ftp.gnu.org/gnu/libtool && \ - manylinux-entrypoint /build_scripts/install-libtool.sh - -COPY build_scripts/install-libxcrypt.sh /build_scripts/ -RUN export LIBXCRYPT_VERSION=4.4.28 && \ - export LIBXCRYPT_HASH=db7e37901969cb1d1e8020cb73a991ef81e48e31ea5b76a101862c806426b457 && \ - export LIBXCRYPT_DOWNLOAD_URL=https://github.com/besser82/libxcrypt/archive && \ - export PERL_ROOT=perl-5.34.0 && \ - export PERL_HASH=551efc818b968b05216024fb0b727ef2ad4c100f8cb6b43fab615fa78ae5be9a && \ - export PERL_DOWNLOAD_URL=https://www.cpan.org/src/5.0 && \ - manylinux-entrypoint /build_scripts/install-libxcrypt.sh - -FROM runtime_base AS build_base -COPY build_scripts/install-build-packages.sh /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-build-packages.sh - - -FROM build_base AS build_git -COPY build_scripts/build-git.sh /build_scripts/ -RUN export GIT_ROOT=git-2.36.2 && \ - export GIT_HASH=6dc2cdea5fb23d823ba4871cc23222c1db31dfbb6d6c6ff74c4128700df57c68 && \ - export GIT_DOWNLOAD_URL=https://www.kernel.org/pub/software/scm/git && \ - manylinux-entrypoint /build_scripts/build-git.sh - - -FROM build_base AS build_cpython -COPY build_scripts/build-sqlite3.sh /build_scripts/ -RUN export SQLITE_AUTOCONF_ROOT=sqlite-autoconf-3390200 && \ - export SQLITE_AUTOCONF_HASH=852be8a6183a17ba47cee0bbff7400b7aa5affd283bf3beefc34fcd088a239de && \ - export SQLITE_AUTOCONF_DOWNLOAD_URL=https://www.sqlite.org/2022 && \ - manylinux-entrypoint /build_scripts/build-sqlite3.sh - -COPY build_scripts/build-openssl.sh /build_scripts/ -RUN export OPENSSL_ROOT=openssl-1.1.1q && \ - export OPENSSL_HASH=d7939ce614029cdff0b6c20f0e2e5703158a489a72b2507b8bd51bf8c8fd10ca && \ - export OPENSSL_DOWNLOAD_URL=https://www.openssl.org/source && \ - manylinux-entrypoint /build_scripts/build-openssl.sh - -COPY build_scripts/build-cpython.sh /build_scripts/ - - -FROM build_cpython AS build_cpython38 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.8.13 - - -FROM build_cpython AS build_cpython39 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.9.13 - - -FROM build_cpython AS build_cpython310 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.10.5 - -FROM build_cpython AS build_cpython311 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.2 - -FROM build_cpython AS all_python -COPY build_scripts/install-pypy.sh \ - build_scripts/pypy.sha256 \ - build_scripts/finalize-python.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.8 7.3.9 -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.9 7.3.9 -COPY --from=build_cpython38 /opt/_internal /opt/_internal/ -COPY --from=build_cpython39 /opt/_internal /opt/_internal/ -COPY --from=build_cpython310 /opt/_internal /opt/_internal/ -COPY --from=build_cpython311 /opt/_internal /opt/_internal/ -RUN manylinux-entrypoint /build_scripts/finalize-python.sh - - -FROM runtime_base -COPY --from=build_git /manylinux-rootfs / -COPY --from=build_cpython /manylinux-rootfs / -COPY --from=all_python /opt/_internal /opt/_internal/ -COPY build_scripts/finalize.sh \ - build_scripts/python-tag-abi-tag.py \ - build_scripts/requirements3.8.txt \ - build_scripts/requirements3.9.txt \ - build_scripts/requirements3.10.txt \ - build_scripts/requirements3.11.txt \ - build_scripts/requirements-base-tools.txt \ - /build_scripts/ -COPY build_scripts/requirements-tools/* /build_scripts/requirements-tools/ -RUN manylinux-entrypoint /build_scripts/finalize.sh && rm -rf /build_scripts - -ENV SSL_CERT_FILE=/opt/_internal/certs.pem - -CMD ["/bin/bash"] - -#Build manylinux2014 docker image end - -#Install TensorRT 8.5.1.7 -#RUN yum install -y wget -RUN v="8.5.1-1.cuda11.8" &&\ - yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo &&\ - yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} \ - libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 -#Add our own dependencies -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts - -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER -USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 deleted file mode 100644 index 70765c667ab8..000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 +++ /dev/null @@ -1,181 +0,0 @@ -# This file is deprecated and will be replaced by tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda -ARG BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 -ARG POLICY=manylinux_2_28 -ARG PLATFORM=x86_64 -ARG DEVTOOLSET_ROOTPATH= -ARG LD_LIBRARY_PATH_ARG= -ARG PREPEND_PATH= - -#We need CUDA, TensorRT and manylinux. But the CUDA Toolkit End User License Agreement says NVIDIA CUDA Driver Libraries(libcuda.so, libnvidia-ptxjitcompiler.so) are only distributable in applications that meet this criteria: -#1. The application was developed starting from a NVIDIA CUDA container obtained from Docker Hub or the NVIDIA GPU Cloud, and -#2. The resulting application is packaged as a Docker container and distributed to users on Docker Hub or the NVIDIA GPU Cloud only. -#So we use CUDA as the base image then add manylinux and TensorRT on top of it. - -#Build manylinux2014 docker image begin -FROM $BASEIMAGE AS runtime_base -ARG POLICY -ARG PLATFORM -ARG DEVTOOLSET_ROOTPATH -ARG LD_LIBRARY_PATH_ARG -ARG PREPEND_PATH -LABEL maintainer="The ManyLinux project" - -ENV AUDITWHEEL_POLICY=${POLICY} AUDITWHEEL_ARCH=${PLATFORM} AUDITWHEEL_PLAT=${POLICY}_${PLATFORM} -ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LANGUAGE=en_US.UTF-8 -ENV DEVTOOLSET_ROOTPATH=${DEVTOOLSET_ROOTPATH} -ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG} -ENV PATH=${PREPEND_PATH}${PATH} -ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig - -# first copy the fixup mirrors script, keep the script around -COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors - -# setup entrypoint, this will wrap commands with `linux32` with i686 images -COPY build_scripts/install-entrypoint.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ - -RUN /build_scripts/install-entrypoint.sh && rm -rf /build_scripts -COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint -ENTRYPOINT ["manylinux-entrypoint"] - -COPY build_scripts/install-runtime-packages.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ - -COPY build_scripts/build_utils.sh /build_scripts/ - -COPY build_scripts/install-autoconf.sh /build_scripts/ -RUN export AUTOCONF_ROOT=autoconf-2.71 && \ - export AUTOCONF_HASH=431075ad0bf529ef13cb41e9042c542381103e80015686222b8a9d4abef42a1c && \ - export AUTOCONF_DOWNLOAD_URL=http://ftp.gnu.org/gnu/autoconf && \ - manylinux-entrypoint /build_scripts/install-autoconf.sh - -COPY build_scripts/install-automake.sh /build_scripts/ -RUN export AUTOMAKE_ROOT=automake-1.16.5 && \ - export AUTOMAKE_HASH=07bd24ad08a64bc17250ce09ec56e921d6343903943e99ccf63bbf0705e34605 && \ - export AUTOMAKE_DOWNLOAD_URL=http://ftp.gnu.org/gnu/automake && \ - manylinux-entrypoint /build_scripts/install-automake.sh - -COPY build_scripts/install-libtool.sh /build_scripts/ -RUN export LIBTOOL_ROOT=libtool-2.4.7 && \ - export LIBTOOL_HASH=04e96c2404ea70c590c546eba4202a4e12722c640016c12b9b2f1ce3d481e9a8 && \ - export LIBTOOL_DOWNLOAD_URL=http://ftp.gnu.org/gnu/libtool && \ - manylinux-entrypoint /build_scripts/install-libtool.sh - -COPY build_scripts/install-libxcrypt.sh /build_scripts/ -RUN export LIBXCRYPT_VERSION=4.4.28 && \ - export LIBXCRYPT_HASH=db7e37901969cb1d1e8020cb73a991ef81e48e31ea5b76a101862c806426b457 && \ - export LIBXCRYPT_DOWNLOAD_URL=https://github.com/besser82/libxcrypt/archive && \ - export PERL_ROOT=perl-5.34.0 && \ - export PERL_HASH=551efc818b968b05216024fb0b727ef2ad4c100f8cb6b43fab615fa78ae5be9a && \ - export PERL_DOWNLOAD_URL=https://www.cpan.org/src/5.0 && \ - manylinux-entrypoint /build_scripts/install-libxcrypt.sh - -FROM runtime_base AS build_base -COPY build_scripts/install-build-packages.sh /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-build-packages.sh - - -FROM build_base AS build_git -COPY build_scripts/build-git.sh /build_scripts/ -RUN export GIT_ROOT=git-2.36.2 && \ - export GIT_HASH=6dc2cdea5fb23d823ba4871cc23222c1db31dfbb6d6c6ff74c4128700df57c68 && \ - export GIT_DOWNLOAD_URL=https://www.kernel.org/pub/software/scm/git && \ - manylinux-entrypoint /build_scripts/build-git.sh - - -FROM build_base AS build_cpython -COPY build_scripts/build-sqlite3.sh /build_scripts/ -RUN export SQLITE_AUTOCONF_ROOT=sqlite-autoconf-3390200 && \ - export SQLITE_AUTOCONF_HASH=852be8a6183a17ba47cee0bbff7400b7aa5affd283bf3beefc34fcd088a239de && \ - export SQLITE_AUTOCONF_DOWNLOAD_URL=https://www.sqlite.org/2022 && \ - manylinux-entrypoint /build_scripts/build-sqlite3.sh - -COPY build_scripts/build-openssl.sh /build_scripts/ -RUN export OPENSSL_ROOT=openssl-1.1.1q && \ - export OPENSSL_HASH=d7939ce614029cdff0b6c20f0e2e5703158a489a72b2507b8bd51bf8c8fd10ca && \ - export OPENSSL_DOWNLOAD_URL=https://www.openssl.org/source && \ - manylinux-entrypoint /build_scripts/build-openssl.sh - -COPY build_scripts/build-cpython.sh /build_scripts/ - - -FROM build_cpython AS build_cpython37 -COPY build_scripts/cpython-pubkeys.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.7.13 - - -FROM build_cpython AS build_cpython38 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.8.13 - - -FROM build_cpython AS build_cpython39 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.9.13 - - -FROM build_cpython AS build_cpython310 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.10.5 - -FROM build_cpython AS build_cpython311 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.0b5 - -FROM build_cpython AS all_python -COPY build_scripts/install-pypy.sh \ - build_scripts/pypy.sha256 \ - build_scripts/finalize-python.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.7 7.3.9 -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.8 7.3.9 -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.9 7.3.9 -COPY --from=build_cpython37 /opt/_internal /opt/_internal/ -COPY --from=build_cpython38 /opt/_internal /opt/_internal/ -COPY --from=build_cpython39 /opt/_internal /opt/_internal/ -COPY --from=build_cpython310 /opt/_internal /opt/_internal/ -COPY --from=build_cpython311 /opt/_internal /opt/_internal/ -RUN manylinux-entrypoint /build_scripts/finalize-python.sh - - -FROM runtime_base -COPY --from=build_git /manylinux-rootfs / -COPY --from=build_cpython /manylinux-rootfs / -COPY --from=all_python /opt/_internal /opt/_internal/ -COPY build_scripts/finalize.sh \ - build_scripts/python-tag-abi-tag.py \ - build_scripts/requirements3.7.txt \ - build_scripts/requirements3.8.txt \ - build_scripts/requirements3.9.txt \ - build_scripts/requirements3.10.txt \ - build_scripts/requirements3.11.txt \ - build_scripts/requirements-base-tools.txt \ - /build_scripts/ -COPY build_scripts/requirements-tools/* /build_scripts/requirements-tools/ -RUN manylinux-entrypoint /build_scripts/finalize.sh && rm -rf /build_scripts - -ENV SSL_CERT_FILE=/opt/_internal/certs.pem - -CMD ["/bin/bash"] - -#Build manylinux2014 docker image end - -#Install TensorRT 8.6.1.6 -RUN v="8.6.1.6-1.cuda11.8" && CUDA_VERSION=$(nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p') \ - && dnf -y install libcudnn8-devel-*cuda$CUDA_VERSION* libcudnn8-*cuda$CUDA_VERSION* libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} libnvinfer-vc-plugin8-${v}\ - libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} libnvinfer-vc-plugin-devel-${v} libnvinfer-headers-devel-${v} libnvinfer-headers-plugin-devel-${v} -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 -#Add our own dependencies -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts - -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER -USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH -ENV CUDA_MODULE_LOADING "LAZY" From ae8561979f494029c863dafb67bae05639ebff60 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Tue, 24 Oct 2023 19:41:10 -0700 Subject: [PATCH 06/36] Introduce new optimizer MatMul + BatchNormalization (#17915) ### Description Introduce new ORT L1 optimizer under RewriteRule category to fuse MatMul + BatchNormalization node. This optimizer look for a specific pattern observed in one of the impacting customer models and fuse the Matmul and Batchnormalization node into a Gemm node. For details on the pattern matching and fusion please refer to the comment section of `matmul_bn_fusion.cc`. To visualize, this optimizer will replace following subgraph to a Gemm node.
               MatMul                  GEMM
                 |                       |
              Reshape ^     --->      Reshape ^
                 |                       |
            Transpose ^             Transpose ^
                 |
       BatchNormalization
Note: ^ means there can be >=0 occurrence(s) of that node.
Few example fusable pattern:
* - MatMul -> Reshape -> Transpose -> BatchNormalization ---> GEMM ->
Reshape -> Transpose
* - MatMul -> Reshape -> BatchNormalization ---> GEMM -> Reshape
* - MatMul -> Transpose -> BatchNormalization ---> GEMM -> Transpose
* - MatMul -> Reshape -> Reshape -> BatchNormalization ---> GEMM ->
Reshape -> Reshape
* - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization --->
GEMM -> Reshape -> Transpose -> Reshape
* - MatMul -> BatchNormalization ---> GEMM
Note: This optimizer may evolve in the future to be more generic in terms of the pattern matching. ### Motivation and Context - Why is this change required? What problem does it solve? One of the user of ORT+DML ep needs this to better target the model to DML. But this transformation applies more broadly, so added L1 optimizer. --- .../core/optimizer/graph_transformer_utils.cc | 2 + onnxruntime/core/optimizer/initializer.cc | 28 +- onnxruntime/core/optimizer/initializer.h | 2 +- .../core/optimizer/matmul_bn_fusion.cc | 230 +++++++++++++++ onnxruntime/core/optimizer/matmul_bn_fusion.h | 27 ++ .../test/optimizer/graph_transform_test.cc | 263 ++++++++++++++++++ .../fusion/fuse-matmul-bn-directly.onnx | Bin 0 -> 513 bytes .../fuse-matmul-bn-non-ignorable-node.onnx | Bin 0 -> 593 bytes .../fusion/fuse-matmul-bn-only-reshape.onnx | Bin 0 -> 639 bytes .../fusion/fuse-matmul-bn-only-transpose.onnx | Bin 0 -> 579 bytes .../fusion/fuse-matmul-bn-with-reshape.onnx | Bin 0 -> 709 bytes 11 files changed, 543 insertions(+), 9 deletions(-) create mode 100644 onnxruntime/core/optimizer/matmul_bn_fusion.cc create mode 100644 onnxruntime/core/optimizer/matmul_bn_fusion.h create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index c4416068e245..5a441b1d1701 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -50,6 +50,7 @@ #include "core/optimizer/matmul_integer_to_float.h" #include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/matmul_transpose_fusion.h" +#include "core/optimizer/matmul_bn_fusion.h" #include "core/optimizer/nchwc_transformer.h" #include "core/optimizer/noop_elimination.h" #include "core/optimizer/not_where_fusion.h" @@ -127,6 +128,7 @@ InlinedVector> GenerateRewriteRules( rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); break; diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index c8da15f65a6d..9e807ddc7be5 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -291,7 +291,11 @@ Initializer& Initializer::sqrt() { namespace { template struct ScaleByAxis { - void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks) const { + void operator()(Tensor& data, + const Tensor& scalers, + const size_t block_size, + const size_t num_blocks, + const bool column_major) const { ToNumeric to_numeric; const auto scaler_size = scalers.Shape().Size(); T* dst = data.MutableData(); @@ -303,24 +307,32 @@ struct ScaleByAxis { } } else { for (size_t block_offset = 0, i = 0; i < num_blocks; i++) { - const auto numeric_scaler = to_numeric(scalers_data[i]); - for (size_t j = 0; j < block_size; ++j, ++block_offset) { - dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + if (column_major) { + for (size_t j = 0; j < block_size; ++j, ++block_offset) { + const auto numeric_scaler = to_numeric(scalers_data[j]); + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } + } else { + const auto numeric_scaler = to_numeric(scalers_data[i]); + for (size_t j = 0; j < block_size; ++j, ++block_offset) { + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } } } } } }; - } // namespace -void Initializer::scale_by_axis(const Initializer& scalers, int axis) { +void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool column_major) { ORT_ENFORCE(axis >= 0, "Axis must be non-negative"); const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); const size_t num_blocks = size() / block_size; - ORT_ENFORCE(scalers.size() == 1 || scalers.size() == num_blocks, "Invalid other(scalers) size"); + ORT_ENFORCE(scalers.size() == 1 || + (column_major ? scalers.size() == block_size : scalers.size() == num_blocks), + "Invalid other(scalers) size"); utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, scalers.data_, block_size, num_blocks); + t_disp.Invoke(data_, scalers.data_, block_size, num_blocks, column_major); } #endif // ORT_EXTENDED_MINIMAL_BUILD } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h index dfe054ba1ace..78e3fd6a3d24 100644 --- a/onnxruntime/core/optimizer/initializer.h +++ b/onnxruntime/core/optimizer/initializer.h @@ -86,7 +86,7 @@ class Initializer final { Initializer& sqrt(); - void scale_by_axis(const Initializer& other, int axis); + void scale_by_axis(const Initializer& other, int axis, bool column_major = false); #endif // ORT_EXTENDED_MINIMAL_BUILD private: std::string name_; diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc new file mode 100644 index 000000000000..e944522c9c33 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/matmul_bn_fusion.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +namespace { +const std::vector>> ignorable_nodes{ + {"Reshape", {1, 5, 13, 14, 19}}, + {"Transpose", {1, 13}}}; +const std::pair> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}}; +} // namespace + +bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) { + const Node* curr_node = graph.GetNode(curr_node_index); + + // curr_node has different execution provider then it's parent or + // has output edge != 1 (this condition will handle the case when ignorable node + // is graph output i.e. a graph like this "MatMul->Transpose") + if (curr_node->GetExecutionProviderType() != root_node.GetExecutionProviderType() || + curr_node->GetOutputEdgesCount() != 1) { + return false; + } + + // curr_node can be any of the ignorable_nodes. + for (size_t index = 0; index < ignorable_nodes.size(); index++) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, ignorable_nodes[index].first, ignorable_nodes[index].second)) { + return true; + } + } + + return false; +} + +std::optional MatchPath(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) { + while (NodeIsIgnorable(graph, root_node, curr_node_index)) { + curr_node_index = graph.GetNode(curr_node_index)->OutputNodesBegin()->Index(); + } + + // curr_node is neither ignorable nor dest + const Node* curr_node = graph.GetNode(curr_node_index); + if (curr_node->OpType() != dest.first) { + return std::nullopt; + } + + if (curr_node->GetExecutionProviderType() == root_node.GetExecutionProviderType() && + graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, dest.first, dest.second)) { + return curr_node_index; + } + + // either curr_node has different execution provider or + // has invalid opset. + return std::nullopt; +} + +/* + * Given a MatMul node, it will verify the following pattern. + * MatMul GEMM + * | | + * Reshape ^ ---> Reshape ^ + * | | + * Transpose ^ Transpose ^ + * | + * BatchNormalization + * Note: ^ means there can be 0 or any occurrences of that node. + * Few example fusable pattern: + * - MatMul -> Reshape -> Transpose -> BatchNormalization ---> GEMM -> Reshape -> Transpose + * - MatMul -> Reshape -> BatchNormalization ---> GEMM -> Reshape + * - MatMul -> Transpose -> BatchNormalization ---> GEMM -> Transpose + * - MatMul -> Reshape -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Reshape + * - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Transpose -> Reshape + * - MatMul -> BatchNormalization ---> GEMM + * Other Conditions: + * - B tensor of MatMul should be constant. + * - scale, B, mean, var tensors of BatchNormalization should be constant. + * - Every node in the path, except the BatchNormalization, should have only 1 output edge. + */ +bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) || + node.GetOutputEdgesCount() != 1) { + return false; + } + + if (graph.NodeProducesGraphOutput(node)) { + return false; + } + + // because is not producing graph output, it means it will have a child node + NodeIndex child_node_index = node.OutputNodesBegin()->Index(); + std::optional batch_norm_index = MatchPath(graph, node, child_node_index); + if (!batch_norm_index.has_value()) { + return false; + } + + const Node* batch_norm_node = graph.GetNode(*batch_norm_index); + + // Check that the appropriate inputs to the Matmul and BN nodes are constants. + if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[2]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[3]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[4])) { + return false; + } + + // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse. + const auto& output_defs = batch_norm_node->OutputDefs(); + if (output_defs.size() > 1) { + for (size_t i = 1, end = output_defs.size(); i < end; ++i) { + if (output_defs[i] != nullptr && output_defs[i]->Exists()) { + return false; + } + } + } + + return true; +} + +/* + * BatchNormalization: [https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc] + * Scale * ((Input - Mean) / sqrt(Variance + Epsilon)) + Bias // ignore the FusedActivation in the above definition, that's very specific to DML + * Expanding out the terms: + * Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias + * Here, + * [Scale/sqrt(Variance + Epsilon)] is constant, and let's call it `alpha` + * [(Scale / sqrt(Variance + Epsilon)) * -Mean + Bias] is also constant, and let's call it `beta` + * Output = alpha * Input + beta, Input = B tensor of MatMul. + * + */ +Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + NodeIndex child_node_index = matmul_node.OutputNodesBegin()->Index(); + NodeIndex batch_norm_node_index = MatchPath(graph, matmul_node, child_node_index).value(); + + Node& batch_norm_node = *graph.GetNode(batch_norm_node_index); // need mutable node, that's why extracting node from graph + + // only perform fusion if epsilon is present and is of float_32 type + auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon"); + if (epsilon_attribute == batch_norm_node.GetAttributes().end() || + epsilon_attribute->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) { + return Status::OK(); + } + const float epsilon = epsilon_attribute->second.f(); + + const onnx::TensorProto* scale_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[1]->Name()); + ORT_ENFORCE(scale_tensor); + const onnx::TensorProto* bias_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[2]->Name()); + ORT_ENFORCE(bias_tensor); + const onnx::TensorProto* mean_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[3]->Name()); + ORT_ENFORCE(mean_tensor); + const onnx::TensorProto* var_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[4]->Name()); + ORT_ENFORCE(var_tensor); + const onnx::TensorProto* matmul_b_tensor = graph_utils::GetConstantInitializer(graph, matmul_node.InputDefs()[1]->Name()); + ORT_ENFORCE(matmul_b_tensor); + + if (!optimizer_utils::IsFloatingPointDataType(*matmul_b_tensor) || + !optimizer_utils::IsFloatingPointDataType(*scale_tensor) || + !optimizer_utils::IsFloatingPointDataType(*bias_tensor) || + !optimizer_utils::IsFloatingPointDataType(*mean_tensor) || + !optimizer_utils::IsFloatingPointDataType(*var_tensor) || + scale_tensor->dims_size() != 1 || + bias_tensor->dims_size() != 1 || + mean_tensor->dims_size() != 1 || + var_tensor->dims_size() != 1 || + scale_tensor->dims(0) != matmul_b_tensor->dims(1) || + bias_tensor->dims(0) != matmul_b_tensor->dims(1) || + mean_tensor->dims(0) != matmul_b_tensor->dims(1) || + var_tensor->dims(0) != matmul_b_tensor->dims(1)) { + return Status::OK(); + } + + /* + * temp = scale / sqrt(var + epsilon) + * output = (temp * Input) - ((temp * mean) + bias) + */ + Initializer scale(*scale_tensor, graph.ModelPath()); + Initializer bias(*bias_tensor, graph.ModelPath()); + Initializer mean(*mean_tensor, graph.ModelPath()); + Initializer var(*var_tensor, graph.ModelPath()); + Initializer matmul_b(*matmul_b_tensor, graph.ModelPath()); + + var.add(epsilon); + var.sqrt(); + scale.div(var); // this is the temp + matmul_b.scale_by_axis(scale, 1, true); + + mean.mul(scale); + bias.sub(mean); + + // create B tensorProto for new Gemm node from initializer. + ONNX_NAMESPACE::TensorProto new_gemm_b_tensor(*matmul_b_tensor); + matmul_b.ToProto(new_gemm_b_tensor); + const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name()); + new_gemm_b_tensor.set_name(new_gemm_b_name); + NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor); + + // create bias tensorProto for new Gemm node from initializer. + ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor(*bias_tensor); + bias.ToProto(new_gemm_bias_tensor); + const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias"); + new_gemm_bias_tensor.set_name(new_gemm_bias_name); + NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor); + + Node& gemm_node = graph.AddNode( + graph.GenerateNodeArgName("MatMulBnFusion_Gemm"), + "Gemm", + "Generated from Matmul BatchNormalization fusion", + {matmul_node.MutableInputDefs()[0], &new_gemm_b_node_arg, &new_gemm_bias_node_arg}, + matmul_node.MutableOutputDefs(), + nullptr, + kOnnxDomain); + + // Remove MatMul node. + Node* node = graph.GetNode(matmul_node.Index()); + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(matmul_node.Index()); + + // Delete optional empty output defs. + // Delete BatchNormalization node and update the input of the child of BatchNormalization + batch_norm_node.MutableOutputDefs().resize(1); + NodeIndex batch_norm_parent_index = graph.GetNode(child_node_index)->OpType() == "BatchNormalization" ? gemm_node.Index() : batch_norm_node.InputNodesBegin()->Index(); + graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(batch_norm_parent_index), batch_norm_node); + + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h new file mode 100644 index 000000000000..7a43483cf37d --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { +/* + * This fusion submerges a BatchNormalization operator to it's super + * precedding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition() + * is true. + */ +class MatmulBNFusion : public RewriteRule { + public: + MatmulBNFusion() : RewriteRule("MatMul_BatchNormalization_Fusion") {} + + std::vector TargetOpTypes() const noexcept override { + return {"MatMul"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 6acf631d53cd..46b95a127b75 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -31,6 +31,7 @@ #include "core/optimizer/conv_add_act_fusion.h" #include "core/optimizer/conv_add_fusion.h" #include "core/optimizer/conv_bn_fusion.h" +#include "core/optimizer/matmul_bn_fusion.h" #include "core/optimizer/conv_mul_fusion.h" #include "core/optimizer/div_mul_fusion.h" #include "core/optimizer/dropout_elimination.h" @@ -1079,6 +1080,268 @@ TEST_F(GraphTransformationTests, FuseConvBNNoBias) { } } +TEST_F(GraphTransformationTests, FuseMatmulBNWithInBetweenNodes) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "MatMul") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the MatMul node"; + } + } +} + +TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithInBetweenNodes) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "MatMul") { + expected_output_name = node.OutputDefs()[0]->Name(); + } else if (node.OpType() == "BatchNormalization") { + node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg("", nullptr)); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the MatMul node"; + } + } +} + +// should not fuse +TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithInBetweenNodes) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + // additional non-empty output to batchNormalization + ONNX_NAMESPACE::TypeProto optional_output_tensor_type; + optional_output_tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TypeProto::kTensorType); + auto& arg = graph.GetOrCreateNodeArg("bn_optional_output", &optional_output_tensor_type); + node.MutableOutputDefs().push_back(&arg); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 1); + ASSERT_EQ(op_to_count["MatMul"], 1); + ASSERT_EQ(op_to_count["Gemm"], 0); +} + +TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the last node"; + } + } +} + +TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "MatMul") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the MatMul node"; + } + } +} + +TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyTranspose) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "MatMul") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the MatMul node"; + } + } +} + +TEST_F(GraphTransformationTests, FuseMatmulBNWithoutBatchNormalization) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + graph_utils::RemoveNode(graph, node); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["MatMul"], 1); +} + +// should not fuse +TEST_F(GraphTransformationTests, FuseMatmulBNWithNonIgnorableNode) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-non-ignorable-node.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 1); + ASSERT_EQ(op_to_count["MatMul"], 1); + ASSERT_EQ(op_to_count["Gemm"], 0); +} + TEST_F(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-no-bias.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fa11adaac8d95db4772990bac6f5b0b072f4d5c1 GIT binary patch literal 513 zcmd zSDarY#0wS3FD(JeE3x?|miU(DaQSngN^o%`<;52#C+4Jbu>)C2nTf? zq%5&Whz)9pkW*qwa)w`iQEp;RW>sPd&@n<{FpKkG&I7wutCoY6gH?c0DW&9@y}fs& zzrATnsojapNc+btLhLz8((JXXI_#F_bK0x1?6p(Ye{Q$n%?jIp7hc)TKdWx{F3r&H z+?8o_>{ zSDarY#0wS3FD(JeE3x?|miU(DaA|R&N(k|1rljVTWR_IMLsfEkLIt=&xX>lJIFj<> zi<1*`Qn}cHtfb7uVlX2&H8GEi4JcBUSR}*=q@iXBIVF}PXZYn8u2svy%E2nYsFc!j%G0hwo57}ap`u+&w6Oi|3GViX_qf@! zybQG0I2>y4k;iL)_kNE3_wx$&tYe@Bh!we$Lz!TZ@_d?5;5D+V6{= zV3%pwVCNR)V}GFJiJcSoRJ-K@CU%z=KCp8_4^Ax=u;n5Q3=Q_^*a;peTFMYrb=SLX z+22jE+ccNa&c@Z%zU4-=y%Q|JV2aw9{@H~1McZ4hIcs}X#K69_o&}-^6qs5{5R;0# z3hgU&)9ra*iP{&LMB2BRci1^W13F3^8u>zeTs$0%LL6L79E?EBnk2ym4Oes-Cl)RS G0S*9lCcbw7 literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c361a42700a30b12eeeacefe8da3825a1e9b3ab7 GIT binary patch literal 639 zcmd zSDarY#0wS3FD(JeE3x?|miU(Da2au-N^r3kXCxM+#v2In7o|d&5FG|e>_HF#t`IJC zIWCT*y!hhe#GF(vb|5P$GqD)V$W2Ym<6;AflqD7kaRO%ZT9OGc24NQswKn0BEW9N z`yU1v85vofU}31G0(P1R14Dy7I(C8wnU*p{)fxLfyYog2_W4&F?IlH6?J8~<+B?C5 z6Q;!T*3o{AR=O=~<8iytTpoxbP*7_rK}?#|t7^aKc8I;q6Gyw$c{gn>(pBu8 ypwSQ|4oxCLd|W&nj6xh-OdO0r%$lUY1r2o}F)k6Hi~v*yBnwvI#KOfOzySb5e#_?o literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f70ae2e6229e7a89eb90cb8d68360e35e4c32291 GIT binary patch literal 579 zcmd zSDarY#0wS3FD(JeE3x?|miU(DaM^I7N(c#-6eZ>r7vvYG#zT}EC~=0sgtU0MSPD{$ zavc~Q7#1+Ha|Lr@D(B)z%8M^fPRvQ=Vh6I4G82ozjNH`3JT5k%NLgZ$5F6APLQaV# z$r*n6MY)MNnN^7;K<@~N!JM84^B~x_TD2Uk9IOJ2N-3i2U)s&qRAr3Ak4n`nmO_Jb( RhA@&0D@fLfg^NLe0|4BqxTOF9 literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8e4bc49514548c604b36029f780f7ff1a56db148 GIT binary patch literal 709 zcmd zSDarY#0wS3FD(JeE3x?|miU(Da2au-N^r3kXCxM+#v2In7o|d&5FG|e>_HF#E+;N@ zIU&K4qQt!7g8bstc$jj|5SWmbAQwwPYEiBOg9EbzqXW|dMs}_+E=(<497%cc#mR{| zsa)(pR#IkSF_@8?nwZDM1{5hvEE3`b(ojDLIVF}PXZYn8{d>;u)k9_-R68Ut6h7pqW!TWKOYnLXR9R84eZ!b}QchFu=&j$m#@( z2`v?{(?l2;8tl=r6FkVYlp(5YTTJaU71!F`j(KX=uE%OG<|u9N1Pe}>qNNu#?VDGY z+kf6~Y%f=BV{d;x(hjT$6x3Qu5R=rr>+N}uHrviU9%8RkaM$M1!XSGmXf#BLLlc@1 o9~Tb?qYwud69*#@vnDBUK|@_gj7tP4BLI_u(u__lTnqvn0In0|n*aa+ literal 0 HcmV?d00001 From 2c6b31c5aa05bdce26ccd1af58bb194f880166ed Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 25 Oct 2023 15:11:02 +0800 Subject: [PATCH 07/36] FP16 optimizer automatically detect DeepSpeed compatibility (#18084) ### FP16 optimizer automatically detect DeepSpeed compatibility Optimum/Transformers are using accelerate lib to prepare models, so our FP16 optimizer wrapper does not work for long time. Because the namespace is `accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper`, which underlying is still calling into DeepSpeed stage1and2 optimizer. This PR includes following changes: 1. Add `accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper` in the modifier registry, plus a check on its contained `optimizer` property MUST be DeepSpeed stage 1 and 2 optimizer. (let's cover Stage 3 optimizer later) 2. For DeepSpeed version > 0.9.1, we will store the source code in a version list. As long as the related function in DeepSpeed remains unchanged during its new release, we won't need manually upgrade the version check any more. If some day, the source code did not match, a warning will be raised to users, to add a new version of source code in the list. With the above change, we will have our FP16 Optimizer working again in Optimum. ![image](https://github.com/microsoft/onnxruntime/assets/10530022/d35b4aa9-b371-46f1-98ae-73114f91179b) --- .lintrunner.toml | 2 + .../python/training/optim/_ds_code_store.py | 81 ++++++++++++++++++ .../python/training/optim/_ds_modifier.py | 85 +++++++++++++++++-- .../training/optim/_modifier_registry.py | 58 +++++++++++-- .../python/training/optim/fp16_optimizer.py | 28 ++---- 5 files changed, 223 insertions(+), 31 deletions(-) create mode 100644 orttraining/orttraining/python/training/optim/_ds_code_store.py diff --git a/.lintrunner.toml b/.lintrunner.toml index c44a66200ad1..4e5d077b08ff 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -45,6 +45,7 @@ exclude_patterns = [ 'cmake/external/**', # ignore generated flatbuffers code 'onnxruntime/core/flatbuffers/ort_flatbuffers_py/**', + 'orttraining/orttraining/python/training/optim/_ds_code_store.py', ] command = [ 'python', @@ -76,6 +77,7 @@ exclude_patterns = [ 'cmake/**', 'orttraining/*', 'onnxruntime/core/flatbuffers/**', + 'orttraining/orttraining/python/training/optim/_ds_code_store.py', ] command = [ 'python', diff --git a/orttraining/orttraining/python/training/optim/_ds_code_store.py b/orttraining/orttraining/python/training/optim/_ds_code_store.py new file mode 100644 index 000000000000..dc1e20bc3dcf --- /dev/null +++ b/orttraining/orttraining/python/training/optim/_ds_code_store.py @@ -0,0 +1,81 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# +# Copyright 2020 The Microsoft DeepSpeed Team +# +# !!!IMPORTANT: This file is a copy of the original one in DeepSpeed repo at given version, +# It is used to compare with the source code of current installed DeepSpeed during runtime. +# Please don't modify it or do any code formatting for it. +# 'orttraining/orttraining/python/training/optim/_ds_code_store.py' is removed from lintrunner config by intention. +# -------------------------------------------------------------------------- + +# Wrap code in this to make sure the indentation is correct compared with raw DeepSpeed. + +class Stage1And2_DeepSpeedZeroOptimizer_0_9_2: + + def has_overflow_serial(self, params, is_grad_list=False): + for p in params: + if p.grad is not None and self._has_inf_or_nan(p.grad.data): + return True + + return False + + + def get_grad_norm_direct(self, gradients, params, norm_type=2): + """Clips gradient norm of an iterable of parameters. + + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(g.data.abs().max() for g in gradients) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group) + + # Take max across all GPUs. + self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX) + total_norm = total_norm_cuda[0].item() + else: + total_norm = 0.0 + # if dist.get_rank() == 0: + # logger.info(f"Total Norm beginning {total_norm}") + for g, p in zip(gradients, params): + # Pipeline parallelism may replicate parameters. Avoid multi-counting. + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: + continue + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + param_norm = g.data.double().norm(2) + total_norm += param_norm.item()**2 + # Sum across all model parallel GPUs. + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) + + total_norm = total_norm_cuda[0].item()**(1. / norm_type) + + if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + + def has_overflow_partitioned_grads_serial(self): + for i in range(len(self.bit16_groups)): + for j, grad in enumerate(self.averaged_gradients[i]): + if grad is not None and self._has_inf_or_nan(grad.data, j): + return True + return False diff --git a/orttraining/orttraining/python/training/optim/_ds_modifier.py b/orttraining/orttraining/python/training/optim/_ds_modifier.py index 6b1c98cc02a5..20f4f814e547 100644 --- a/orttraining/orttraining/python/training/optim/_ds_modifier.py +++ b/orttraining/orttraining/python/training/optim/_ds_modifier.py @@ -10,6 +10,9 @@ # - has_overflow_partitioned_grads_serial : https://github.com/microsoft/DeepSpeed/blob/d8e9ef6f99e27bb95e10bd146d145b3372b4cfda/deepspeed/runtime/zero/stage2.py#L1799 # -------------------------------------------------------------------------- +from __future__ import annotations + +import inspect import types import warnings @@ -17,12 +20,69 @@ from numpy import inf from packaging.version import Version +from ._ds_code_store import Stage1And2_DeepSpeedZeroOptimizer_0_9_2 from ._modifier import FP16OptimizerModifier, check_overflow, check_overflow_for_grads from ._multi_tensor_apply import MultiTensorApply multi_tensor_applier = MultiTensorApply(2048 * 32) +def _get_normalized_str(function) -> str: + return inspect.getsource(function) + + +def _dynamic_checks(cur_ds_version: Version, optimizer) -> bool: + _functions_to_override = ["has_overflow_serial", "get_grad_norm_direct", "has_overflow_partitioned_grads_serial"] + + _version_to_source_code_map = {"0.9.2": Stage1And2_DeepSpeedZeroOptimizer_0_9_2} + + # Try to find the biggest version that is smaller than or equal to cur_ds_version. + # then compare the source code (in case the found version is the latest version supported); + # If current code does not match the found version, return False, and raise a warning to + # add the new version to the list. + versions = [Version(v) for v in _version_to_source_code_map] + sorted_versions = sorted(versions, reverse=True) + version_to_compare = None + for sv in sorted_versions: + if cur_ds_version >= sv: + version_to_compare = sv + break + + if version_to_compare is None: + warnings.warn( + "Unable to find a DeepSpeed version that is smaller than or equal to the current version " + f"{cur_ds_version}. Skip modifying optimizer.", + UserWarning, + ) + return False + + v_optimizer_cls = _version_to_source_code_map[str(version_to_compare)] + all_match = True + for func_name in _functions_to_override: + if not getattr(optimizer, func_name): + warnings.warn( + f"DeepSpeed function {func_name} is not found in optimizer. Skip modifying optimizer.", UserWarning + ) + all_match = False + cur_code_str = _get_normalized_str(getattr(optimizer, func_name)) + v_code_str = _get_normalized_str(getattr(v_optimizer_cls, func_name)) + if cur_code_str != v_code_str: + warnings.warn( + f"DeepSpeed function {func_name} has changed after version {version_to_compare}. " + f"Please append new version {cur_ds_version} in _version_to_source_code_map and _ds_code_store.py.\n" + f"---[{func_name}] Old Source Code Start----\n" + f"{v_code_str}\n" + f"---{func_name} Old Source Code End----\n" + f"---[{func_name}] New Source Code Start----\n" + f"{cur_code_str}\n" + f"---{func_name} New Source Code End----", + UserWarning, + ) + all_match = False + + return all_match + + class DeepSpeedZeROModifier(FP16OptimizerModifier): def __init__(self, optimizer, **kwargs) -> None: super().__init__(optimizer) @@ -30,19 +90,32 @@ def __init__(self, optimizer, **kwargs) -> None: def can_be_modified(self): import deepspeed + # Note 1: # This modifier relies on the implementation of has_overflow_serial, get_grad_norm_direct, # and has_overflow_partitioned_grads_serial # in https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage_1_and_2.py. - # Everytime if we want to update this version supporting list to a newer version, - # we need to check if the implementation of these functions are changed. - # An easy way to check is to check the history of this file, if there is no change during the update, + # The minimum version supported is 0.4.0, all versions in between [0.4.0, 0.9.1] + # are manually checked to make sure the implementation of these functions are "logically" not changed. + # The way we did the check is to check the history of this file, if there is no change during the update, # it's safe to update the version supporting list. Otherwise, or the file is moved or renamed, # we need to check the implementation of these functions in detail. + # + # Note 2: + # Since version 0.9.2, we added dynamic source code check, by comparing installed version of code with + # the source code in our code store. If the source code is changed, we will raise a warning to ask user + # to add the new version to the code store. Otherwise, we will override the functions. + ds_version = Version(deepspeed.__version__) - if ds_version > Version("0.9.1") or ds_version < Version("0.4.0"): + if ds_version < Version("0.4.0"): + warnings.warn( + f"Skip modifying optimizer because of unsupported DeepSpeed version {ds_version}, " + "minimum supported version: 0.4.0, current version", + UserWarning, + ) + return False + if ds_version > Version("0.9.1") and not _dynamic_checks(ds_version, self._optimizer): warnings.warn( - "Skip modifying optimizer because of unsupported DeepSpeed version {}, " - "supported version: 0.4.0 - 0.9.1.".format(deepspeed.__version__), + f"Skip modifying optimizer because of unsupported DeepSpeed version {ds_version}.", UserWarning, ) return False diff --git a/orttraining/orttraining/python/training/optim/_modifier_registry.py b/orttraining/orttraining/python/training/optim/_modifier_registry.py index 4a3a33ecc051..a88740dac60b 100644 --- a/orttraining/orttraining/python/training/optim/_modifier_registry.py +++ b/orttraining/orttraining/python/training/optim/_modifier_registry.py @@ -3,13 +3,59 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations + +import warnings +from typing import ClassVar + from ._apex_amp_modifier import ApexAMPModifier from ._ds_modifier import DeepSpeedZeROModifier from ._megatron_modifier import LegacyMegatronLMModifier +from ._modifier import FP16OptimizerModifier + + +class _AccelerateDeepSpeedZeROModifier(DeepSpeedZeROModifier): + """ + Modifier for wrapper of DeepSpeed Optimizer in accelerator. + https://github.com/huggingface/accelerate/blob/7843286f2e1c50735d259fbc0084a7f1c85e00e3/src/accelerate/utils/deepspeed.py#L182C19-L182C19 + """ + + def __init__(self, accelerator_optimizer, **kwargs) -> None: + super().__init__(accelerator_optimizer.optimizer) + + +def get_full_qualified_type_name(o): + klass = o.__class__ + module = klass.__module__ + if module == "builtins": + return klass.__qualname__ + return module + "." + klass.__qualname__ + + +class OptimizerModifierTypeRegistry: + _MAP: ClassVar[dict[str, FP16OptimizerModifier]] = { + "megatron.fp16.fp16.FP16_Optimizer": LegacyMegatronLMModifier, + "deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer": DeepSpeedZeROModifier, + "deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer": DeepSpeedZeROModifier, + "apex.amp.optimizer.unique_name_as_id": ApexAMPModifier, + } + + @staticmethod + def create_modifier(optimizer_full_qualified_name: str, optimizer, **kwargs) -> FP16OptimizerModifier | None: + """Create modifier for optimizer.""" + if optimizer_full_qualified_name in OptimizerModifierTypeRegistry._MAP: + return OptimizerModifierTypeRegistry._MAP[optimizer_full_qualified_name](optimizer, **kwargs) + + if optimizer_full_qualified_name == "accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper": + if ( + hasattr(optimizer, "optimizer") + and get_full_qualified_type_name(optimizer.optimizer) in OptimizerModifierTypeRegistry._MAP + ): + return _AccelerateDeepSpeedZeROModifier(optimizer, **kwargs) -OptimizerModifierTypeRegistry = { - "megatron.fp16.fp16.FP16_Optimizer": LegacyMegatronLMModifier, - "deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer": DeepSpeedZeROModifier, - "deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer": DeepSpeedZeROModifier, - "apex.amp.optimizer.unique_name_as_id": ApexAMPModifier, -} + warnings.warn( + "Skip modifying optimizer because of optimizer name not found in the registry: " + f"{optimizer_full_qualified_name}", + UserWarning, + ) + return None diff --git a/orttraining/orttraining/python/training/optim/fp16_optimizer.py b/orttraining/orttraining/python/training/optim/fp16_optimizer.py index 2a5dfbc2189d..fc93eadc3211 100644 --- a/orttraining/orttraining/python/training/optim/fp16_optimizer.py +++ b/orttraining/orttraining/python/training/optim/fp16_optimizer.py @@ -3,9 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -import warnings -from ._modifier_registry import OptimizerModifierTypeRegistry +from ._modifier_registry import OptimizerModifierTypeRegistry, get_full_qualified_type_name def FP16_Optimizer(optimizer, **kwargs): # noqa: N802 @@ -80,22 +79,13 @@ def FP16_Optimizer(optimizer, **kwargs): # noqa: N802 """ - def get_full_qualified_type_name(o): - if hasattr(optimizer, "_amp_stash"): - return "apex.amp.optimizer.unique_name_as_id" - - klass = o.__class__ - module = klass.__module__ - if module == "builtins": - return klass.__qualname__ - return module + "." + klass.__qualname__ - - optimizer_full_qualified_name = get_full_qualified_type_name(optimizer) - if optimizer_full_qualified_name not in OptimizerModifierTypeRegistry: - warnings.warn("Skip modifying optimizer because of optimizer name not found in registry.", UserWarning) - return optimizer - - modifier = OptimizerModifierTypeRegistry[optimizer_full_qualified_name](optimizer, **kwargs) - modifier.apply() + optimizer_full_qualified_name = ( + "apex.amp.optimizer.unique_name_as_id" + if hasattr(optimizer, "_amp_stash") + else get_full_qualified_type_name(optimizer) + ) + modifier = OptimizerModifierTypeRegistry.create_modifier(optimizer_full_qualified_name, optimizer, **kwargs) + if modifier is not None: + modifier.apply() return optimizer From 706e13e0c95a730181bca62c348d3283a9194e11 Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Wed, 25 Oct 2023 10:46:04 -0700 Subject: [PATCH 08/36] implement affinegrid cpu kernel (#17777) --- docs/OperatorKernels.md | 1 + .../providers/cpu/cpu_execution_provider.cc | 4 + .../core/providers/cpu/tensor/affine_grid.cc | 151 ++++++++++++++++ .../core/providers/cpu/tensor/affine_grid.h | 25 +++ .../providers/cpu/tensor/affine_grid_test.cc | 165 ++++++++++++++++++ .../cpu/tensor/affine_grid_test_gen.py | 111 ++++++++++++ 6 files changed, 457 insertions(+) create mode 100644 onnxruntime/core/providers/cpu/tensor/affine_grid.cc create mode 100644 onnxruntime/core/providers/cpu/tensor/affine_grid.h create mode 100644 onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc create mode 100644 onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index ba610515ac28..b3a4cb0c8b4b 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -25,6 +25,7 @@ Do not modify directly.* |||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| +|AffineGrid|*in* theta:**T1**
*in* size:**T2**
*out* grid:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| |ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index a54d999a100b..2ca3b1cdf817 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -960,6 +960,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh // Opset 20 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, AffineGrid); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN); @@ -2399,6 +2401,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { // Opset 20 BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/tensor/affine_grid.cc b/onnxruntime/core/providers/cpu/tensor/affine_grid.cc new file mode 100644 index 000000000000..15900ba55398 --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/affine_grid.cc @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cpu/tensor/affine_grid.h" + +#include "core/common/common.h" +#include "core/providers/op_kernel_type_control.h" +#include "core/util/math_cpuonly.h" +#include +#include "Eigen/src/Core/Map.h" +#include +#include "core/common/eigen_common_wrapper.h" + +namespace onnxruntime { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + AffineGrid, \ + 20, \ + T, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + AffineGrid); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(double) + +template +void generate_base_grid_2d(int64_t H, int64_t W, bool align_corners, Eigen::Matrix& base_grid) { + Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast(W), -1, 1); + if (!align_corners) { + row_vec = row_vec * (W - 1) / W; + } + Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast(H), -1, 1); + if (!align_corners) { + col_vec = col_vec * (H - 1) / H; + } + + base_grid.resize(static_cast(H * W), 2); + for (Eigen::Index j = 0; j < H; j++) { + for (Eigen::Index i = 0; i < W; i++) { + base_grid.row(j * static_cast(W) + i) << row_vec(i), col_vec(j); + } + } +} + +template +void generate_base_grid_3d(int64_t D, int64_t H, int64_t W, bool align_corners, Eigen::Matrix& base_grid) { + Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast(W), -1, 1); + if (!align_corners) { + row_vec = row_vec * (W - 1) / W; + } + Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast(H), -1, 1); + if (!align_corners) { + col_vec = col_vec * (H - 1) / H; + } + Eigen::VectorXf slice_vec = Eigen::VectorXf::LinSpaced(static_cast(D), -1, 1); + if (!align_corners) { + slice_vec = slice_vec * (D - 1) / D; + } + + base_grid.resize(static_cast(D * H * W), 3); + for (Eigen::Index k = 0; k < D; k++) { + for (Eigen::Index j = 0; j < H; j++) { + for (Eigen::Index i = 0; i < W; i++) { + base_grid.row(k * static_cast(H * W) + j * static_cast(W) + i) << row_vec(i), col_vec(j), slice_vec(k); + } + } + } +} + +template +void affine_grid_generator_2d(const Tensor* theta, const Eigen::Matrix& base_grid_transposed, int64_t batch_num, int64_t H, int64_t W, Tensor* grid) { + const Eigen::StorageOptions option = Eigen::RowMajor; + auto theta_batch_offset = batch_num * 2 * 3; + const T* theta_data = theta->Data() + theta_batch_offset; + const Eigen::Matrix theta_R{{theta_data[0], theta_data[1]}, {theta_data[3], theta_data[4]}}; + const Eigen::Array theta_T(theta_data[2], theta_data[5]); + + auto grid_batch_offset = batch_num * H * W * 2; + T* grid_data = grid->MutableData() + grid_batch_offset; + Eigen::Map> grid_matrix(grid_data, narrow(H * W), 2); + grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose(); +} + +template +void affine_grid_generator_3d(const Tensor* theta, const Eigen::Matrix& base_grid_transposed, int64_t batch_num, int64_t D, int64_t H, int64_t W, Tensor* grid) { + const Eigen::StorageOptions option = Eigen::RowMajor; + auto theta_batch_offset = batch_num * 3 * 4; + const T* theta_data = theta->Data() + theta_batch_offset; + const Eigen::Matrix theta_R{ + {theta_data[0], theta_data[1], theta_data[2]}, + {theta_data[4], theta_data[5], theta_data[6]}, + {theta_data[8], theta_data[9], theta_data[10]}}; + const Eigen::Array theta_T(theta_data[3], theta_data[7], theta_data[11]); + + auto grid_batch_offset = batch_num * D * H * W * 3; + T* grid_data = grid->MutableData() + grid_batch_offset; + Eigen::Map> grid_matrix(grid_data, narrow(D * H * W), 3); + grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose(); +} + +template +Status AffineGrid::Compute(OpKernelContext* context) const { + const Tensor* theta = context->Input(0); + const TensorShape& theta_shape = theta->Shape(); + if (theta_shape.NumDimensions() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "AffineGrid : Input theta tensor dimension is not 3"); + } + + const Tensor* size = context->Input(1); + const TensorShape& size_shape = size->Shape(); + const int64_t* size_data = size->Data(); + + if (size_shape.GetDims()[0] == 4 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) { + int64_t N = size_data[0], H = size_data[2], W = size_data[3]; + + TensorShape grid_shape{N, H, W, 2}; + auto grid = context->Output(0, grid_shape); + + Eigen::Matrix base_grid; + generate_base_grid_2d(H, W, align_corners_, base_grid); + Eigen::Matrix base_grid_transposed = base_grid.transpose(); + + std::function fn = [theta, base_grid_transposed, H, W, grid](ptrdiff_t batch_num) { + affine_grid_generator_2d(theta, base_grid_transposed, batch_num, H, W, grid); + }; + + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow(N), std::move(fn), 0); + } else if (size_shape.GetDims()[0] == 5 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) { + int64_t N = size_data[0], D = size_data[2], H = size_data[3], W = size_data[4]; + + TensorShape grid_shape{N, D, H, W, 3}; + auto grid = context->Output(0, grid_shape); + + Eigen::Matrix base_grid; + generate_base_grid_3d(D, H, W, align_corners_, base_grid); + Eigen::Matrix base_grid_transposed = base_grid.transpose(); + + std::function fn = [theta, base_grid_transposed, D, H, W, grid](ptrdiff_t batch_num) { + affine_grid_generator_3d(theta, base_grid_transposed, batch_num, D, H, W, grid); + }; + + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow(N), std::move(fn), 0); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "AffineGrid : Invalidate size - length of size should be 4 or 5."); + } + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/affine_grid.h b/onnxruntime/core/providers/cpu/tensor/affine_grid.h new file mode 100644 index 000000000000..5ffe660e986f --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/affine_grid.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +template +class AffineGrid final : public OpKernel { + public: + AffineGrid(const OpKernelInfo& info) : OpKernel(info) { + int64_t align_corners = info.GetAttrOrDefault("align_corners", 0); + align_corners_ = (align_corners != 0); + } + + Status Compute(OpKernelContext* context) const override; + + private: + bool align_corners_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc b/onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc new file mode 100644 index 000000000000..e37e784f2893 --- /dev/null +++ b/onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc @@ -0,0 +1,165 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/util/math.h" +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { +TEST(AffineGridTest, 2d) { + OpTester test("AffineGrid", 20); + test.AddInput("theta", {1, 2, 3}, {1.0f, 0.0, 0.0f, 0.0f, 1.0, 0.0f}); + test.AddInput("size", {4}, {1, 1, 2, 3}); + test.AddOutput("grid", {1, 2, 3, 2}, + {-0.6667f, -0.5000f, 0.0000f, -0.5000f, 0.6667f, -0.5000f, -0.6667f, 0.5000f, 0.0000f, 0.5000f, 0.6667f, 0.5000f}); + test.Run(); +} + +// following tests code is generated with: +// python onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py +TEST(AffineGridTest, test_2d_0) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {1, 2, 3}, {1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f}); + test.AddInput("size", {4}, {1, 1, 3, 2}); + test.AddOutput("grid", {1, 3, 2, 2}, {-0.3228f, -0.9151f, 1.1544f, -0.7414f, -0.4386f, -0.5868f, 1.0386f, -0.4132f, -0.5544f, -0.2586f, 0.9228f, -0.0849f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_1) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {2, 2, 3}, {1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f, 1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f}); + test.AddInput("size", {4}, {2, 10, 2, 3}); + test.AddOutput("grid", {2, 2, 3, 2}, {-0.5980f, -0.8620f, 0.3868f, -0.7462f, 1.3716f, -0.6304f, -0.7716f, -0.3696f, 0.2132f, -0.2538f, 1.1980f, -0.1380f, -0.5980f, -0.8620f, 0.3868f, -0.7462f, 1.3716f, -0.6304f, -0.7716f, -0.3696f, 0.2132f, -0.2538f, 1.1980f, -0.1380f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_2) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {1, 2, 3}, {1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f}); + test.AddInput("size", {4}, {1, 1, 3, 2}); + test.AddOutput("grid", {1, 3, 2, 2}, {-0.6726f, -2.7663f, 0.8274f, -1.9003f, -1.2500f, -0.9330f, 0.2500f, -0.0670f, -1.8274f, 0.9003f, -0.3274f, 1.7663f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_3) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {2, 2, 3}, {1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f, 1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f}); + test.AddInput("size", {4}, {2, 10, 2, 3}); + test.AddOutput("grid", {2, 2, 3, 2}, {-1.0670f, -2.4524f, -0.0670f, -1.8750f, 0.9330f, -1.2976f, -1.9330f, 0.2976f, -0.9330f, 0.8750f, 0.0670f, 1.4524f, -1.0670f, -2.4524f, -0.0670f, -1.8750f, 0.9330f, -1.2976f, -1.9330f, 0.2976f, -0.9330f, 0.8750f, 0.0670f, 1.4524f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_4) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {1, 2, 3}, {1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f}); + test.AddInput("size", {4}, {1, 1, 3, 2}); + test.AddOutput("grid", {1, 3, 2, 2}, {-1.0036f, -1.1661f, 1.9509f, -0.8188f, -1.1772f, -0.6736f, 1.7772f, -0.3264f, -1.3509f, -0.1812f, 1.6036f, 0.1661f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_5) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {2, 2, 3}, {1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f, 1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f}); + test.AddInput("size", {4}, {2, 10, 2, 3}); + test.AddOutput("grid", {2, 2, 3, 2}, {-1.0036f, -1.1661f, 0.4736f, -0.9924f, 1.9509f, -0.8188f, -1.3509f, -0.1812f, 0.1264f, -0.0076f, 1.6036f, 0.1661f, -1.0036f, -1.1661f, 0.4736f, -0.9924f, 1.9509f, -0.8188f, -1.3509f, -0.1812f, 0.1264f, -0.0076f, 1.6036f, 0.1661f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_6) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {1, 2, 3}, {1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f}); + test.AddInput("size", {4}, {1, 1, 3, 2}); + test.AddOutput("grid", {1, 3, 2, 2}, {-1.1340f, -4.1160f, 1.8660f, -2.3840f, -2.0000f, -1.3660f, 1.0000f, 0.3660f, -2.8660f, 1.3840f, 0.1340f, 3.1160f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_7) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {2, 2, 3}, {1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f, 1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f}); + test.AddInput("size", {4}, {2, 10, 2, 3}); + test.AddOutput("grid", {2, 2, 3, 2}, {-1.1340f, -4.1160f, 0.3660f, -3.2500f, 1.8660f, -2.3840f, -2.8660f, 1.3840f, -1.3660f, 2.2500f, 0.1340f, 3.1160f, -1.1340f, -4.1160f, 0.3660f, -3.2500f, 1.8660f, -2.3840f, -2.8660f, 1.3840f, -1.3660f, 2.2500f, 0.1340f, 3.1160f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_0) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {1, 3, 4}, {1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f}); + test.AddInput("size", {5}, {1, 1, 3, 2, 2}); + test.AddOutput("grid", {1, 3, 2, 2, 3}, {-0.7468f, -1.3266f, 1.5323f, 0.6627f, -1.2078f, 1.3639f, -0.7468f, 0.6430f, 1.6191f, 0.6627f, 0.7618f, 1.4507f, -0.4048f, -1.5442f, 1.8408f, 1.0048f, -1.4254f, 1.6724f, -0.4048f, 0.4254f, 1.9276f, 1.0048f, 0.5442f, 1.7592f, -0.0627f, -1.7618f, 2.1493f, 1.3468f, -1.6430f, 1.9809f, -0.0627f, 0.2078f, 2.2361f, 1.3468f, 0.3266f, 2.0677f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_1) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {2, 3, 4}, {1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f, 1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f}); + test.AddInput("size", {5}, {2, 10, 2, 2, 3}); + test.AddOutput("grid", {2, 2, 2, 3, 3}, {-0.8962f, -1.4008f, 1.6375f, 0.0435f, -1.3216f, 1.5252f, 0.9832f, -1.2424f, 1.4130f, -0.8962f, 0.5688f, 1.7243f, 0.0435f, 0.6480f, 1.6121f, 0.9832f, 0.7272f, 1.4998f, -0.3832f, -1.7272f, 2.1002f, 0.5565f, -1.6480f, 1.9879f, 1.4962f, -1.5688f, 1.8757f, -0.3832f, 0.2424f, 2.1870f, 0.5565f, 0.3216f, 2.0748f, 1.4962f, 0.4008f, 1.9625f, -0.8962f, -1.4008f, 1.6375f, 0.0435f, -1.3216f, 1.5252f, 0.9832f, -1.2424f, 1.4130f, -0.8962f, 0.5688f, 1.7243f, 0.0435f, 0.6480f, 1.6121f, 0.9832f, 0.7272f, 1.4998f, -0.3832f, -1.7272f, 2.1002f, 0.5565f, -1.6480f, 1.9879f, 1.4962f, -1.5688f, 1.8757f, -0.3832f, 0.2424f, 2.1870f, 0.5565f, 0.3216f, 2.0748f, 1.4962f, 0.4008f, 1.9625f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_2) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {1, 3, 4}, {0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f}); + test.AddInput("size", {5}, {1, 1, 3, 2, 2}); + test.AddOutput("grid", {1, 3, 2, 2, 3}, {-0.5299f, 0.8995f, -4.3568f, -0.2701f, -0.3995f, -2.9818f, -0.5299f, 2.3995f, 0.4064f, -0.2701f, 1.1005f, 1.7814f, -0.6299f, -0.6005f, -2.7691f, -0.3701f, -1.8995f, -1.3941f, -0.6299f, 0.8995f, 1.9941f, -0.3701f, -0.3995f, 3.3691f, -0.7299f, -2.1005f, -1.1814f, -0.4701f, -3.3995f, 0.1936f, -0.7299f, -0.6005f, 3.5818f, -0.4701f, -1.8995f, 4.9568f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_3) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {2, 3, 4}, {0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f, 0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f}); + test.AddInput("size", {5}, {2, 10, 2, 2, 3}); + test.AddOutput("grid", {2, 2, 2, 3, 3}, {-0.5982f, 0.7410f, -4.1890f, -0.4250f, -0.1250f, -3.2724f, -0.2518f, -0.9910f, -2.3557f, -0.5982f, 2.2410f, 0.5741f, -0.4250f, 1.3750f, 1.4908f, -0.2518f, 0.5090f, 2.4075f, -0.7482f, -1.5090f, -1.8075f, -0.5750f, -2.3750f, -0.8908f, -0.4018f, -3.2410f, 0.0259f, -0.7482f, -0.0090f, 2.9557f, -0.5750f, -0.8750f, 3.8724f, -0.4018f, -1.7410f, 4.7890f, -0.5982f, 0.7410f, -4.1890f, -0.4250f, -0.1250f, -3.2724f, -0.2518f, -0.9910f, -2.3557f, -0.5982f, 2.2410f, 0.5741f, -0.4250f, 1.3750f, 1.4908f, -0.2518f, 0.5090f, 2.4075f, -0.7482f, -1.5090f, -1.8075f, -0.5750f, -2.3750f, -0.8908f, -0.4018f, -3.2410f, 0.0259f, -0.7482f, -0.0090f, 2.9557f, -0.5750f, -0.8750f, 3.8724f, -0.4018f, -1.7410f, 4.7890f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_4) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {1, 3, 4}, {1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f}); + test.AddInput("size", {5}, {1, 1, 3, 2, 2}); + test.AddOutput("grid", {1, 3, 2, 2, 3}, {-1.6226f, -2.2620f, 1.4189f, 1.1965f, -2.0245f, 1.0821f, -1.6226f, 1.6772f, 1.5925f, 1.1965f, 1.9147f, 1.2557f, -1.1095f, -2.5884f, 1.8816f, 1.7095f, -2.3508f, 1.5448f, -1.1095f, 1.3508f, 2.0552f, 1.7095f, 1.5884f, 1.7184f, -0.5965f, -2.9147f, 2.3443f, 2.2226f, -2.6772f, 2.0075f, -0.5965f, 1.0245f, 2.5179f, 2.2226f, 1.2620f, 2.1811f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_5) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {2, 3, 4}, {1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f, 1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f}); + test.AddInput("size", {5}, {2, 10, 2, 2, 3}); + test.AddOutput("grid", {2, 2, 2, 3, 3}, {-1.6226f, -2.2620f, 1.4189f, -0.2130f, -2.1433f, 1.2505f, 1.1965f, -2.0245f, 1.0821f, -1.6226f, 1.6772f, 1.5925f, -0.2130f, 1.7960f, 1.4241f, 1.1965f, 1.9147f, 1.2557f, -0.5965f, -2.9147f, 2.3443f, 0.8130f, -2.7960f, 2.1759f, 2.2226f, -2.6772f, 2.0075f, -0.5965f, 1.0245f, 2.5179f, 0.8130f, 1.1433f, 2.3495f, 2.2226f, 1.2620f, 2.1811f, -1.6226f, -2.2620f, 1.4189f, -0.2130f, -2.1433f, 1.2505f, 1.1965f, -2.0245f, 1.0821f, -1.6226f, 1.6772f, 1.5925f, -0.2130f, 1.7960f, 1.4241f, 1.1965f, 1.9147f, 1.2557f, -0.5965f, -2.9147f, 2.3443f, 0.8130f, -2.7960f, 2.1759f, 2.2226f, -2.6772f, 2.0075f, -0.5965f, 1.0245f, 2.5179f, 0.8130f, 1.1433f, 2.3495f, 2.2226f, 1.2620f, 2.1811f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_6) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {1, 3, 4}, {0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f}); + test.AddInput("size", {5}, {1, 1, 3, 2, 2}); + test.AddOutput("grid", {1, 3, 2, 2, 3}, {-0.6098f, 1.5490f, -8.2197f, -0.0902f, -1.0490f, -5.4697f, -0.6098f, 4.5490f, 1.3066f, -0.0902f, 1.9510f, 4.0566f, -0.7598f, -0.7010f, -5.8381f, -0.2402f, -3.2990f, -3.0881f, -0.7598f, 2.2990f, 3.6881f, -0.2402f, -0.2990f, 6.4381f, -0.9098f, -2.9510f, -3.4566f, -0.3902f, -5.5490f, -0.7066f, -0.9098f, 0.0490f, 6.0697f, -0.3902f, -2.5490f, 8.8197f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_7) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {2, 3, 4}, {0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f, 0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f}); + test.AddInput("size", {5}, {2, 10, 2, 2, 3}); + test.AddOutput("grid", {2, 2, 2, 3, 3}, {-0.6098f, 1.5490f, -8.2197f, -0.3500f, 0.2500f, -6.8447f, -0.0902f, -1.0490f, -5.4697f, -0.6098f, 4.5490f, 1.3066f, -0.3500f, 3.2500f, 2.6816f, -0.0902f, 1.9510f, 4.0566f, -0.9098f, -2.9510f, -3.4566f, -0.6500f, -4.2500f, -2.0816f, -0.3902f, -5.5490f, -0.7066f, -0.9098f, 0.0490f, 6.0697f, -0.6500f, -1.2500f, 7.4447f, -0.3902f, -2.5490f, 8.8197f, -0.6098f, 1.5490f, -8.2197f, -0.3500f, 0.2500f, -6.8447f, -0.0902f, -1.0490f, -5.4697f, -0.6098f, 4.5490f, 1.3066f, -0.3500f, 3.2500f, 2.6816f, -0.0902f, 1.9510f, 4.0566f, -0.9098f, -2.9510f, -3.4566f, -0.6500f, -4.2500f, -2.0816f, -0.3902f, -5.5490f, -0.7066f, -0.9098f, 0.0490f, 6.0697f, -0.6500f, -1.2500f, 7.4447f, -0.3902f, -2.5490f, 8.8197f}); + test.Run(); +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py b/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py new file mode 100644 index 000000000000..22bad6f1be53 --- /dev/null +++ b/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py @@ -0,0 +1,111 @@ +import argparse + +import numpy as np +import torch +from torch.nn.functional import affine_grid + +opset_version = 20 +parser = argparse.ArgumentParser(description="Generate test cases for the AffineGrid operator.") +parser.add_argument("--dim", type=int, choices=[2, 3], help="Dimension of the test cases (2 or 3)") +args = parser.parse_args() + +if args.dim is None or args.dim == 2: + align_corners_options = [False, True] + angles = [10, 60] + translations = [np.array([0.3, -0.5]), np.array([-0.5, -0.5])] + scales = [np.array([1.5, 0.5]), np.array([3.0, 5.5])] + sizes = [[1, 1, 3, 2], [2, 10, 2, 3]] + test_count = 0 + + for align_corners in align_corners_options: + for angle, translation, scale in zip(angles, translations, scales): + for size in sizes: + theta = np.array([], dtype=np.float32) + for _ in range(size[0]): + angle_radian = (angle / 180.0) * np.pi + theta = np.append( + theta, + [ + np.cos(angle_radian) * scale[0], + -np.sin(angle_radian), + translation[0], + np.sin(angle_radian), + np.cos(angle_radian) * scale[1], + translation[1], + ], + ) + theta = theta.reshape(size[0], 2, 3) + theta = torch.Tensor(theta) + grid = affine_grid(theta, size, align_corners=align_corners) + + # Print the C++ code for the test case + print(f"TEST(AffineGridTest, test_2d_{test_count}) {{") + print(f' OpTester test("AffineGrid", {opset_version});') + print(f' test.AddAttribute("align_corners", (int64_t){1 if align_corners else 0});') + print( + f" test.AddInput(\"theta\", {{{theta.shape[0]}, {theta.shape[1]}, {theta.shape[2]}}}, {{{', '.join([f'{x:.6f}f' for x in theta.flatten()])}}});" + ) + print( + f' test.AddInput("size", {{{len(size)}}}, {{{size[0]}, {size[1]}, {size[2]}, {size[3]}}});' + ) + print( + f" test.AddOutput(\"grid\", {{{size[0]}, {size[2]}, {size[3]}, 2}}, {{{', '.join([f'{x:.4f}f' for x in grid.flatten()])}}});" + ) + print(" test.Run();") + print("}\n") + test_count += 1 + + +if args.dim is None or args.dim == 3: + align_corners_options = [False, True] + angles = [[10, 20], [60, -30]] + translations = [np.array([0.3, -0.5, 1.8]), np.array([-0.5, -0.5, 0.3])] + scales = [np.array([1.5, 2.0, 0.5]), np.array([0.3, 3.0, 5.5])] + sizes = [[1, 1, 3, 2, 2], [2, 10, 2, 2, 3]] + test_count = 0 + + for align_corners in align_corners_options: + for angle, translation, scale in zip(angles, translations, scales): + for size in sizes: + theta = np.array([], dtype=np.float32) + for _ in range(size[0]): + angle_radian_x = (angle[0] / 180.0) * np.pi + angle_radian_y = (angle[1] / 180.0) * np.pi + rot_matrix_x = np.array( + [ + [1, 0, 0], + [0, np.cos(angle_radian_x), -np.sin(angle_radian_x)], + [0, np.sin(angle_radian_x), np.cos(angle_radian_x)], + ] + ) + rot_matrix_y = np.array( + [ + [np.cos(angle_radian_y), 0, np.sin(angle_radian_y)], + [0, 1, 0], + [-np.sin(angle_radian_y), 0, np.cos(angle_radian_y)], + ] + ) + rot_matrix = np.matmul(rot_matrix_x, rot_matrix_y) + rot_matrix = rot_matrix * scale.reshape(3, 1) + rot_matrix = np.append(rot_matrix, np.reshape(translation, (3, 1)), axis=1) + theta = np.append(theta, rot_matrix.flatten()) + theta = theta.reshape(size[0], 3, 4) + theta = torch.Tensor(theta) + grid = affine_grid(theta, size, align_corners=align_corners) + + # Print the C++ code for the test case + print(f"TEST(AffineGridTest, test_3d_{test_count}) {{") + print(f' OpTester test("AffineGrid", {opset_version});') + print(f' test.AddAttribute("align_corners", (int64_t){1 if align_corners else 0});') + print( + f" test.AddInput(\"theta\", {{{theta.shape[0]}, {theta.shape[1]}, {theta.shape[2]}}}, {{{', '.join([f'{x:.6f}f' for x in theta.flatten()])}}});" + ) + print( + f' test.AddInput("size", {{{len(size)}}}, {{{size[0]}, {size[1]}, {size[2]}, {size[3]}, {size[4]}}});' + ) + print( + f" test.AddOutput(\"grid\", {{{size[0]}, {size[2]}, {size[3]}, {size[4]}, 3}}, {{{', '.join([f'{x:.4f}f' for x in grid.flatten()])}}});" + ) + print(" test.Run();") + print("}\n") + test_count += 1 From d88d52eeadcdc2813e478dbb2d8dc6f5575258d6 Mon Sep 17 00:00:00 2001 From: snadampal <87143774+snadampal@users.noreply.github.com> Date: Wed, 25 Oct 2023 13:34:57 -0500 Subject: [PATCH 09/36] [aarch64] Remove mmla kernel support from apple (#18082) ### Description The mmla kernels require additional ISA flags and are currently supported only on Linux ### Motivation and Context more context is in https://github.com/microsoft/onnxruntime/pull/15270 cc: @skottmckay , @chenfucn , @snnn --- cmake/onnxruntime_mlas.cmake | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 06237c8010fb..a62b1b259d10 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -325,9 +325,7 @@ else() ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S - ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S @@ -336,24 +334,26 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ) if (NOT APPLE) set(mlas_platform_srcs ${mlas_platform_srcs} ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/pooling_fp16.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) From d30d4d372a33640aed78eb4bfe05db31e36f6e2a Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Wed, 25 Oct 2023 15:34:58 -0700 Subject: [PATCH 10/36] Add MatMul FP4 and NF4 Support (#18066) ### Description Add a contrib op MatMulBnb4 (FP4 and NF4) and related toolchain to support quantization on weight. This PR adds: - schema for contrib op MatMulBnb4 which can support FP4 (4-bit floating point) and NF4 (4-bit NormalFloat) quantization on weight. - a naive implementation for MatMulBnb4 on CPU and GPU, i.e., implemented like MatMul(A, Dequantize(B)). - a special implementation for GemV for MatMulBnb4 and related benchmark tool. - tool to quantize model to FP4 or NF4. --- cmake/onnxruntime_rocm_hipify.cmake | 5 + docs/ContribOperators.md | 57 +++++ docs/OperatorKernels.md | 2 + .../contrib_ops/cpu/cpu_contrib_kernels.cc | 2 + .../quantization/blockwise_quant_block_bnb4.h | 202 +++++++++++++++ .../quantization/dequantize_blockwise_bnb4.h | 143 +++++++++++ .../cpu/quantization/matmul_bnb4.cc | 109 ++++++++ .../contrib_ops/cuda/cuda_contrib_kernels.cc | 4 + .../quantization/dequantize_blockwise_bnb4.cu | 129 ++++++++++ .../dequantize_blockwise_bnb4.cuh | 26 ++ .../cuda/quantization/matmul_bnb4.cc | 144 +++++++++++ .../cuda/quantization/matmul_bnb4.cu | 192 ++++++++++++++ .../cuda/quantization/matmul_bnb4.cuh | 26 ++ .../core/graph/contrib_ops/contrib_defs.cc | 35 +++ .../python/onnxruntime_pybind_quant.cc | 31 +++ .../kernels/cuda/dequant_blockwise_bnb4.cu | 89 +++++++ .../kernels/cuda/matmul_bnb4.cu | 96 +++++++ .../kernels/dequantize_blockwise_bnb4.py | 92 +++++++ .../kernel_explorer/kernels/matmul_bnb4.py | 136 ++++++++++ .../quantization/matmul_bnb4_quantizer.py | 240 ++++++++++++++++++ .../test/contrib_ops/matmul_bnb4_test.cc | 151 +++++++++++ .../quantization/test_op_matmul_bnb4.py | 186 ++++++++++++++ .../test_quantizeblockwise_bnb4.py | 139 ++++++++++ 23 files changed, 2236 insertions(+) create mode 100644 onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h create mode 100644 onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h create mode 100644 onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc create mode 100644 onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu create mode 100644 onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh create mode 100644 onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc create mode 100644 onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu create mode 100644 onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py create mode 100644 onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py create mode 100644 onnxruntime/test/contrib_ops/matmul_bnb4_test.cc create mode 100644 onnxruntime/test/python/quantization/test_op_matmul_bnb4.py create mode 100644 onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index fe3e577b4fc3..de1458c12001 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -54,6 +54,11 @@ set(contrib_ops_excluded_files "quantization/attention_quantization_impl.cuh" "quantization/dequantize_blockwise.cuh" "quantization/dequantize_blockwise.cu" + "quantization/dequantize_blockwise_bnb4.cuh" + "quantization/dequantize_blockwise_bnb4.cu" + "quantization/matmul_bnb4.cc" + "quantization/matmul_bnb4.cuh" + "quantization/matmul_bnb4.cu" "quantization/matmul_nbits.cc" "quantization/matmul_nbits.cuh" "quantization/matmul_nbits.cu" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 5805333a0868..1a76c18a6a8e 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -47,6 +47,7 @@ Do not modify directly.* * com.microsoft.Inverse * com.microsoft.Irfft * com.microsoft.LongformerAttention + * com.microsoft.MatMulBnb4 * com.microsoft.MatMulFpQ4 * com.microsoft.MatMulInteger16 * com.microsoft.MatMulIntegerToFloat @@ -2504,6 +2505,62 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.MatMulBnb4** + + MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's quantization constants or scales are specified by input 'absmax'. + + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
size of each input feature
+
N : int (required)
+
size of each output feature
+
block_size : int (required)
+
number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.
+
quant_type : int (required)
+
quantization data type. 0 for FP4, 1 for NF4.
+
+ +#### Inputs + +
+
A : T1
+
The input tensor, not quantized
+
B : T2
+
1-dimensional quantized data for weight
+
absmax : T1
+
quantization constants
+
+ +#### Outputs + +
+
Y : T1
+
tensor. The output tensor has the same rank as the input.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16)
+
Constrain input and output types to float/half_float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
+ + ### **com.microsoft.MatMulFpQ4** Matrix product with right hand matrix being pre-packed and quantized int4 data blob. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index b3a4cb0c8b4b..84249df92231 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -457,6 +457,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| @@ -852,6 +853,7 @@ Do not modify directly.* |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index f77e403f26dd..f9d9b13f0fed 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -30,6 +30,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4); #ifndef ORT_MINIMAL_BUILD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4); #endif @@ -270,6 +271,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifndef ORT_MINIMAL_BUILD BuildKernelCreateInfo, #endif diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h new file mode 100644 index 000000000000..cb8e97a592d8 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +#if defined(_MSC_VER) +#define FORCEINLINE __forceinline +#else +#define FORCEINLINE __attribute__((always_inline)) inline +#endif + +typedef enum Bnb_DataType_t { + FP4 = 0, + NF4 = 1, +} Bnb_DataType_t; + +FORCEINLINE uint8_t QuantizeOneFP4(float x) { + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assum input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to noice if you add an extra + // zero somewhere! + + uint8_t sign = x < 0 ? 0b1000 : 0b0000; + x = fabsf(x); + if (x > 0.29166667f) { + if (x > 0.583333f) { + if (x > 0.8333333f) { + return 0b0011 + sign; + } else { + return 0b0010 + sign; + } + } else if (x > 0.4166667f) { + return 0b101 + sign; + } else { + return 0b100 + sign; + } + } else if (x > 0.0859375f) { + if (x > 0.20833333f) { + return 0b0111 + sign; + } else { + return 0b0110 + sign; + } + } else if (x > 0.00260417f) { + return 0b0001 + sign; + } else { + return 0b0000 + sign; + } +} + +FORCEINLINE uint8_t QuantizeOneNF4(float x) { + if (x > 0.03979014977812767f) { + if (x > 0.3893125355243683f) { // 1 + if (x > 0.6427869200706482f) { // 11 + if (x > 0.8614784181118011f) { // 111 + return 0b1111; + } else { + return 0b1110; + } + } else if (x > 0.5016634166240692f) { // 110 + return 0b1101; + } else { + return 0b1100; + } + } else if (x > 0.2035212516784668f) { // 10 + if (x > 0.2920137718319893f) { // 101 + return 0b1011; + } else { + return 0b1010; + } + } else if (x > 0.1202552504837513f) { // 100 + return 0b1001; + } else { + return 0b1000; + } + } else if (x > -0.33967943489551544f) { // 0 + if (x > -0.13791173323988914f) { // 01 + if (x > -0.045525018125772476f) { // 011 + return 0b0111; + } else { + return 0b0110; + } + } else if (x > -0.23460740596055984f) { // 010 + return 0b0101; + } else { + return 0b0100; + } + } else if (x > -0.6106329262256622f) { // 00 + if (x > -0.4599952697753906f) { // 001 + return 0b0011; + } else { + return 0b0010; + } + } else if (x > -0.8480964004993439f) { // 000 + return 0b0001; + } else { + return 0b0000; + } +} + +template +FORCEINLINE uint8_t QuantizeOneBnb4(float x) { + if constexpr (DATA_TYPE == FP4) + return QuantizeOneFP4(x); + else + return QuantizeOneNF4(x); +} + +template +FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) { + float local_absmax = 0.0f; + + int32_t block_len = std::min(block_size, numel - block_idx * block_size); + int32_t src_offset = block_idx * block_size; + int32_t dst_offset = block_idx * block_size / 2; + + for (int32_t idx = 0; idx < block_len; idx++) { + const float v = static_cast(src[src_offset + idx]); + local_absmax = fmaxf(local_absmax, fabsf(v)); + } + + absmax_block = static_cast(local_absmax); + const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f; + + for (int32_t idx = 0; idx < block_len; idx += 2) { + const float v0 = static_cast(src[src_offset + idx]) * reciprocal_absmax; + const uint8_t vi0 = QuantizeOneBnb4(v0); + + const float v1 = (idx + 1 < block_len) ? static_cast(src[src_offset + idx + 1]) * reciprocal_absmax : 0; + const uint8_t vi1 = QuantizeOneBnb4(v1); + + dst[dst_offset + idx / 2] = (vi0 << 4) | vi1; + } +} + +static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f, + 0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f, + -0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f, + -0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f}; + +static float nf4_qaunt_map[16] = {-1.0f, + -0.6961928009986877f, + -0.5250730514526367f, + -0.39491748809814453f, + -0.28444138169288635f, + -0.18477343022823334f, + -0.09105003625154495f, + 0.0f, + 0.07958029955625534f, + 0.16093020141124725f, + 0.24611230194568634f, + 0.33791524171829224f, + 0.44070982933044434f, + 0.5626170039176941f, + 0.7229568362236023f, + 1.0f}; + +template +FORCEINLINE T DequantizeOneBnb4(uint8_t x) { + if constexpr (DATA_TYPE == FP4) + return static_cast(fp4_qaunt_map[x]); + else + return static_cast(nf4_qaunt_map[x]); +} + +template +FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) { + int32_t block_len = std::min(block_size, numel - block_idx * block_size); + int32_t src_offset = block_idx * block_size / 2; + int32_t dst_offset = block_idx * block_size; + + for (int32_t idx = 0; idx < block_len; idx += 2) { + const uint8_t val = src[src_offset + idx / 2]; + + dst[dst_offset + idx] = DequantizeOneBnb4(val >> 4) * absmax_block; + if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4(val & 0xF) * absmax_block; + } +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h new file mode 100644 index 000000000000..5ddb77e5b5ee --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "blockwise_quant_block_bnb4.h" + +#include + +#include "core/common/safeint.h" +#include "core/framework/float16.h" +#include "core/platform/threadpool.h" +#include + +namespace onnxruntime { +namespace contrib { + +template +void QuantizeBlockwiseBnb4( + uint8_t* dst, // shape: [(N * K + 1) / 2] + const T* src, // shape: [N, K] + T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t numel = N * K; + int32_t total_block_count = (numel + block_size - 1) / block_size; + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + QuantizeBlockBnb4( + src, + dst, + absmax[block_idx], + static_cast(block_idx), + numel); + }, + 0); +} + +#define QuantizeBlockwiseBn4DataTyped(block_size, quant_type) \ + if (quant_type == FP4) \ + QuantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); \ + else \ + QuantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); + +template +void QuantizeBlockwiseBnb4( + uint8_t* dst, // shape: [(N * K + 1) / 2] + const T* src, // shape: [N, K] + T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + if (block_size == 16) { + QuantizeBlockwiseBn4DataTyped(16, quant_type); + } else if (block_size == 32) { + QuantizeBlockwiseBn4DataTyped(32, quant_type); + } else if (block_size == 64) { + QuantizeBlockwiseBn4DataTyped(64, quant_type); + } else if (block_size == 128) { + QuantizeBlockwiseBn4DataTyped(128, quant_type); + } else if (block_size == 256) { + QuantizeBlockwiseBn4DataTyped(256, quant_type); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef QuantizeBlockwiseBn4DataTyped + +template +void DequantizeBlockwiseBnb4( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [(N * K + 1) / 2)] + const T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t numel = N * K; + int32_t total_block_count = (numel + block_size - 1) / block_size; + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + DequantizeBlockBnb4( + src, + dst, + absmax[block_idx], + static_cast(block_idx), + numel); + }, + 0); +} + +#define DequantizeBlockwiseBn4DataTyped(block_size, quant_type) \ + if (quant_type == FP4) \ + DequantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); \ + else \ + DequantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); + +template +void DequantizeBlockwiseBnb4( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [(N * K + 1) / 2)] + const T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + if (block_size == 16) { + DequantizeBlockwiseBn4DataTyped(16, quant_type); + } else if (block_size == 32) { + DequantizeBlockwiseBn4DataTyped(32, quant_type); + } else if (block_size == 64) { + DequantizeBlockwiseBn4DataTyped(64, quant_type); + } else if (block_size == 128) { + DequantizeBlockwiseBn4DataTyped(128, quant_type); + } else if (block_size == 256) { + DequantizeBlockwiseBn4DataTyped(256, quant_type); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef DequantizeBlockwiseBn4DataTyped + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc new file mode 100644 index 000000000000..2f3ede49c365 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/common.h" +#include "dequantize_blockwise_bnb4.h" +#include "core/mlas/inc/mlas.h" + +namespace onnxruntime { +namespace contrib { + +class MatMulBnb4 final : public OpKernel { + public: + MatMulBnb4(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("quant_type", &quant_type_)); + ORT_ENFORCE( + quant_type_ == FP4 || quant_type_ == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + } + + Status Compute(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t quant_type_; +}; + +Status MatMulBnb4::Compute(OpKernelContext* ctx) const { + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + const Tensor* a = ctx->Input(0); + const Tensor* b_quant = ctx->Input(1); + const Tensor* absmax = ctx->Input(2); + + const float* a_data = a->Data(); + const uint8_t* b_quant_data = b_quant->Data(); + const float* absmax_data = absmax->Data(); + + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + DequantizeBlockwiseBnb4( + tmp_b_data_ptr.get(), + b_quant_data, + absmax_data, + static_cast(block_size_), + static_cast(quant_type_), + static_cast(N_), + static_cast(K_), + thread_pool); + + constexpr bool transa = false; + constexpr bool transb = true; + TensorShape b_shape({N_, K_}); + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(transa); + const size_t ldb = helper.Ldb(transb); + + // TODO: implement with native kernel + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), max_len, thread_pool); + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index c52f869d6a9d..e762a80cb0e2 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -118,6 +118,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping); @@ -279,6 +281,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu new file mode 100644 index 000000000000..e58723f0b31e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h" +#include "dequantize_blockwise_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + T host_quant_map[16]; + switch (quant_type) { + case FP4: + for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast(fp4_qaunt_map[i]); + break; + case NF4: + for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast(nf4_qaunt_map[i]); + break; + } + CUDA_CALL_THROW(cudaMemcpyAsync(quant_map_buffer, host_quant_map, sizeof(T) * 16, cudaMemcpyHostToDevice, stream)); + + return Status::OK(); +} + +template Status SetBnbQuantMap(int quant_type, float* quant_map_buffer, cudaStream_t stream); + +template Status SetBnbQuantMap(int quant_type, half* quant_map_buffer, cudaStream_t stream); + +template +__global__ void kDequantizeBlockwise( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + const int block_size, + const int n) { + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH * 2]; + uint8_t qvals[NUM_PER_TH]; + T local_abs_max = T(0.0f); + + typedef cub::BlockLoad LoadChar; + typedef cub::BlockStore StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (unsigned int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) { + valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i; + valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2; + + local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]); + + __syncthreads(); + LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128); + + #pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max; + vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max; + #else + // half multiplication not supported + vals[j * 2] = static_cast(static_cast(quant_map[qvals[j] >> 4]) * static_cast(local_abs_max)); + vals[j * 2 + 1] = + static_cast(static_cast(quant_map[qvals[j] & 0x0F]) * static_cast(local_abs_max)); + #endif + } + + __syncthreads(); + StoreT(storet).Store(&(output[i * 2]), vals, valid_items_store); + } +} + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream) { + int tile_size = 1024; + kDequantizeBlockwise<<<(numel + tile_size - 1) / tile_size, 64, 0, stream>>>( + quant_map, + output, + quant_data, + absmax, + block_size / 2, + numel); + + return Status::OK(); +} + +template Status DequantizeBnb4( + const float* quant_map, + float* output, + const uint8_t* quant_data, + const float* absmax, + int block_size, + int numel, + cudaStream_t stream); + +template Status DequantizeBnb4( + const half* quant_map, + half* output, + const uint8_t* quant_data, + const half *absmax, + int block_size, + int numel, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh new file mode 100644 index 000000000000..4aef3ab699f9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream); + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc new file mode 100644 index 000000000000..bd5b6e0a8a1c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h" +#include "matmul_bnb4.cuh" +#include "dequantize_blockwise_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulBnb4 final : public CudaKernel { + public: + MatMulBnb4(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("quant_type", &quant_type_)); + ORT_ENFORCE( + quant_type_ == FP4 || quant_type_ == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t quant_type_; +}; + +template +Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(0); + const Tensor* b_quant = ctx->Input(1); + const Tensor* absmax = ctx->Input(2); + + const auto* a_data = a->Data(); + const uint8_t* b_quant_data = b_quant->Data(); + const auto* absmax_data = absmax->Data(); + + typedef typename ToCudaType::MappedType CudaT; + + // TODO: find a better way to create the quant_map without using a buffer + // don't want to use malloc directly so asking from the caller + // can create a __device__ static array for float but doesn't work for half + IAllocatorUniquePtr quant_map_buffer = GetScratchBuffer(16, ctx->GetComputeStream()); + auto* quant_map_buffer_data = quant_map_buffer.get(); + ORT_RETURN_IF_ERROR(SetBnbQuantMap( + SafeInt(quant_type_), + reinterpret_cast(quant_map_buffer_data), + static_cast(ctx->GetComputeStream()->GetHandle()))); + + constexpr bool transa = false; + constexpr bool transb = true; + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR( + helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* Y = ctx->Output(0, helper.OutputShape()); + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) return Status::OK(); + + bool is_4bit_done = TryMatMulBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle())); + + if (!is_4bit_done) { + IAllocatorUniquePtr b_dequant_ptr = GetScratchBuffer(N_ * K_, ctx->GetComputeStream()); + auto* b_dequant_data = b_dequant_ptr.get(); + ORT_RETURN_IF_ERROR(DequantizeBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(b_dequant_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(block_size_), + SafeInt(N_ * K_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_dequant_data), + SafeInt(K_), + reinterpret_cast(a_data), + helper.Lda(transa), + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp())); + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu new file mode 100644 index 000000000000..1d9aa75ff370 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include +#include +#include +#include "matmul_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define num_values_4bit 32 +template +__global__ void kgemm_4bit_inference_naive( + int M, + int N, + int K, + const T* __restrict__ A, + const uint8_t* B, + const T* absmax, + const T* datatype, + T* out, + int lda, + int ldb, + int ldc, + int block_size) { + // per threadblock: + // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1x32 * 32x4 -> 1x4 outputs per thread block + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32]; + + const int warp_idx = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; + const int row_B = (THREADS / 32) * blockIdx.x + warp_idx; + const int num_values_8bit = num_values_4bit / 2; + float local_C = 0.0f; + + uint8_t local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit / 4]; + T local_A[num_values_4bit / 4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + for (int i = threadIdx.x; i < 16; i++) quant_map[i] = T(datatype[i]); + __syncthreads(); + + // A: [1, K] + // B: [N, K] + for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) { + int inner_idx_halved = inner_idx / 2; + int offset_B = ldb * row_B; + int absidx = ((2 * offset_B) + inner_idx) / block_size; + local_absmax = __ldg(&(absmax[absidx])); + + if (row_B < N) { + if ((inner_idx_halved + num_values_8bit) < (K / 2)) { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = + reinterpret_cast(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; + } else { + #pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + if ((inner_idx_halved) + j < (K / 2)) + local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } else { + #pragma unroll + for (int j = 0; j < (num_values_8bit); j++) local_B_4bit[j] = 0b01110111; + } + + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int k = 0; k < num_values_8bit / 4; k++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; + local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; + #else + // half multiplication not supported + local_B[k * 2] = + static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4]) * + static_cast(local_absmax)); + local_B[k * 2 + 1] = + static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F]) * + static_cast(local_absmax)); + #endif + } + + if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { + // this is also relatively important for performance + if (BITS == 16) { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 4) + i]; + } else { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; + reinterpret_cast(local_A)[1] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; + } + } else { + #pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + if (inner_idx + (i * num_values_4bit / 4) + k < K) + local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; + else + local_A[k] = T(0.0f); + } + } + + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + local_C += static_cast(local_A[k] * local_B[k]); + #else + // half multiplication not supported + local_C += static_cast(local_A[k]) * static_cast(local_B[k]); + #endif + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if (row_B < N && warp_lane == 0) out[row_B] = T(local_C); +} + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (k % block_size != 0 || m > 1) { + return false; + } + // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] + if (block_size % 32 != 0 || block_size > 4096) { + return false; + } + + int lda = k; + int ldb = (k + 1) / 2; + int ldc = n; + int num_blocks = (n + 3) / 4; + + constexpr int bits = std::is_same_v ? 16 : 32; + kgemm_4bit_inference_naive<<>>( + m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size); + + return true; +} + +template bool TryMatMulBnb4( + const float* quant_map, + float* output, + const float* a_data, + const uint8_t* b_data_quant, + const float* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +template bool TryMatMulBnb4( + const half* quant_map, + half* output, + const half* a_data, + const uint8_t* b_data_quant, + const half* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh new file mode 100644 index 000000000000..743234282fbf --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5e5eee568fa2..681a728f823d 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3239,6 +3239,41 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); }); + static const char* MatMulBnb4_ver1_doc = R"DOC( +MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's quantization constants or scales are specified by input 'absmax'. + +Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. +Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulBnb4) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulBnb4_ver1_doc) + .Attr("K", "size of each input feature", AttributeProto::INT) + .Attr("N", "size of each output feature", AttributeProto::INT) + .Attr("block_size", "number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT) + .Input(0, "A", "The input tensor, not quantized", "T1") + .Input(1, "B", "1-dimensional quantized data for weight", "T2") + .Input(2, "absmax", "quantization constants", "T1") + .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + // Shape inference + int64_t in_features = getAttribute(ctx, "K", -1); + int64_t out_features = getAttribute(ctx, "N", -1); + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); + }); + #ifdef ENABLE_ATEN ONNX_CONTRIB_OPERATOR_SCHEMA(ATen) .SetDomain(kPytorchAtenDomain) diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 52ea677d5141..04dfa9b51e11 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -6,6 +6,7 @@ #include #include "contrib_ops/cpu/quantization/dequantize_blockwise.h" +#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h" #include "core/util/thread_utils.h" namespace pybind11 { @@ -64,9 +65,39 @@ void QuantizeMatMul4BitsBlockwise( tp.get()); } +template +void QuantizeMatMulBnb4Blockwise( + py::array_t dst, + py::array_t src, + py::array_t absmax, + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K) { + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + + py::buffer_info dst_buf = dst.request(); + py::buffer_info src_buf = src.request(); + py::buffer_info absmax_buf = absmax.request(); + + contrib::QuantizeBlockwiseBnb4( + static_cast(dst_buf.ptr), + static_cast(src_buf.ptr), + static_cast(absmax_buf.ptr), + block_size, + quant_type, + N, + K, + tp.get()); +} + void CreateQuantPybindModule(py::module& m) { m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise); m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise); + m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); + m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); } } // namespace python diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu new file mode 100644 index 000000000000..3504ce1bebe8 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. + +#include +#include +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/device_array.h" +#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct DequantizeBnb4Params : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + int quant_type_; + T* output_; + const uint8_t* quant_; + const T* absmax_; + T* quant_map_buffer_; + int n_; + int k_; +}; + +template +class DequantizeBnb4 : public IKernelExplorer { + public: + DequantizeBnb4( + int quant_type, + DeviceArray& output, + DeviceArray& quant, + DeviceArray& absmax, + DeviceArray& quant_map_buffer, + int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.quant_type_ = quant_type; + params_.output_ = static_cast(output.ptr()); + params_.quant_ = static_cast(quant.ptr()); + params_.absmax_ = static_cast(absmax.ptr()); + params_.quant_map_buffer_ = static_cast(quant_map_buffer.ptr()); + params_.n_ = n; + params_.k_ = k; + } + + void Run() override { + ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap( + params_.quant_type_, + params_.quant_map_buffer_, + params_.StreamHandle())); + ORT_THROW_IF_ERROR(contrib::cuda::DequantizeBnb4( + params_.quant_map_buffer_, + params_.output_, + params_.quant_, + params_.absmax_, + 64, + params_.n_ * params_.k_, + params_.StreamHandle())); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = DequantizeBnb4Params; + ParamsT params_{}; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run); + +KE_REGISTER(m) { + REGISTER_OP(DequantizeBnb4, half); + REGISTER_OP(DequantizeBnb4, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu new file mode 100644 index 000000000000..e4cd83565357 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. + +#include +#include +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh" +#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" +#include "contrib_ops/cuda/quantization/matmul_bnb4.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct MatrixFloatBnb4Params : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + int quant_type_; + T* output_; + const T* a_; + const uint8_t* b_; + const T* absmax_; + T* quant_map_buffer_; + int m_; + int n_; + int k_; +}; + +template +class MatrixFloatBnb4 : public IKernelExplorer { + public: + MatrixFloatBnb4(DeviceArray& output, + DeviceArray& a, + DeviceArray& b, + DeviceArray& absmax, + DeviceArray& quant_map_buffer, + int quant_type, int m, int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.output_ = static_cast(output.ptr()); + params_.a_ = static_cast(a.ptr()); + params_.b_ = static_cast(b.ptr()); + params_.absmax_ = static_cast(absmax.ptr()); + params_.quant_map_buffer_ = static_cast(quant_map_buffer.ptr()); + params_.quant_type_ = quant_type; + params_.m_ = m; + params_.n_ = n; + params_.k_ = k; + } + + void Run() override { + ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap( + params_.quant_type_, + params_.quant_map_buffer_, + params_.StreamHandle())); + contrib::cuda::TryMatMulBnb4( + params_.quant_map_buffer_, + params_.output_, + params_.a_, + params_.b_, + params_.absmax_, + params_.m_, + params_.n_, + params_.k_, + 64, + params_.StreamHandle()); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = MatrixFloatBnb4Params; + ParamsT params_{}; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run); + +KE_REGISTER(m) { + REGISTER_OP(MatrixFloatBnb4, half); + REGISTER_OP(MatrixFloatBnb4, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py new file mode 100644 index 000000000000..140151aadcc0 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py @@ -0,0 +1,92 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import sys +from dataclasses import dataclass + +import kernel_explorer as ke +import numpy as np +from utils import dtype_to_bytes + + +def dtype_to_funcs(dtype): + type_map = { + "float16": list(filter(lambda x: "DequantizeBnb4_half" in x, dir(ke))), + "float32": list(filter(lambda x: "DequantizeBnb4_float" in x, dir(ke))), + } + return type_map[dtype] + + +quant_enums = {"FP4": 0, "NF4": 1} + + +dtypes = ["float16", "float32"] +quant_types = ["FP4", "NF4"] + + +@dataclass +class DequantizeBnb4Metric(ke.BandwidthMetric): + quant_type: str + n: int + k: int + + def report(self): + return ( + f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s" + f" {self.quant_type} {self.dtype} n={self.n} k={self.k} {self.name}" + ) + + +def profile_dequantize_int4_func(qt, n, k, dtype, func): + np.random.seed(0) + block_size = 64 + numel = n * k + output = np.random.rand(n, k).astype(dtype) + quant = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8") + absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype) + quant_map_buffer = np.zeros(16).astype(dtype) + + output_d = ke.DeviceArray(output) + quant_d = ke.DeviceArray(quant) + absmax_d = ke.DeviceArray(absmax) + quant_map_buffer_d = ke.DeviceArray(quant_map_buffer) + f = getattr(ke, func) + my_op = f(quant_enums[qt], output_d, quant_d, absmax_d, quant_map_buffer_d, n, k) + duration_ms = my_op.Profile() + total_bytes = numel / 2 + (numel + numel / block_size) * dtype_to_bytes(dtype) + + ke.report(DequantizeBnb4Metric(func, dtype, duration_ms, total_bytes, qt, n, k)) + + +def profile_with_args(qt, n, k, dtype, sort): + with ke.benchmark(sort): + for func in dtype_to_funcs(dtype): + profile_dequantize_int4_func(qt, n, k, dtype, func) + + +def profile(): + for qt in quant_types: + for dt in dtypes: + for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): + profile_with_args(qt, n, k, dt, True) + print() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + group = parser.add_argument_group("profile with args") + group.add_argument("n", type=int) + group.add_argument("k", type=int) + group.add_argument("quant_type", choices=quant_types) + group.add_argument("dtype", choices=dtypes) + group.add_argument("--sort", action="store_true") + + if len(sys.argv) == 1: + profile() + else: + args = parser.parse_args() + profile_with_args(args.quant_type, args.n, args.k, args.dtype, args.sort) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py new file mode 100644 index 000000000000..4a9489050fd6 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py @@ -0,0 +1,136 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import sys +from dataclasses import dataclass + +import kernel_explorer as ke +import numpy as np +from utils import dtype_to_bytes + + +def dtype_to_funcs(dtype): + type_map = { + "float16": list(filter(lambda x: "MatrixFloatBnb4_half" in x, dir(ke))), + "float32": list(filter(lambda x: "MatrixFloatBnb4_float" in x, dir(ke))), + } + return type_map[dtype] + + +def dtype_to_funcs_cublas(dtype): + type_map = { + "float16": list(filter(lambda x: "GemmBenchmark_half" in x, dir(ke))), + "float32": list(filter(lambda x: "GemmBenchmark_float" in x, dir(ke))), + } + return type_map[dtype] + + +quant_enums = {"FP4": 0, "NF4": 1} + + +dtypes = ["float16", "float32"] +quant_types = ["FP4", "NF4"] + + +@dataclass +class MatrixMulMetric(ke.BandwidthMetric): + m: int + n: int + k: int + + def report(self): + return ( + f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}" + ) + + +@dataclass +class MatrixFpBnb4Metric(MatrixMulMetric): + quant_type: str + + def report(self): + return ( + f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s" + f" {self.quant_type} {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}" + ) + + +def profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func): + np.random.seed(0) + block_size = 64 + numel = n * k + output = np.random.rand(m, n).astype(dtype) + a = np.random.rand(m, k).astype(dtype) + b = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8") + absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype) + quant_map_buffer = np.zeros(16).astype(dtype) + + output_d = ke.DeviceArray(output) + a_d = ke.DeviceArray(a) + b_d = ke.DeviceArray(b) + absmax_d = ke.DeviceArray(absmax) + quant_map_buffer_d = ke.DeviceArray(quant_map_buffer) + f = getattr(ke, func) + + my_op = f(output_d, a_d, b_d, absmax_d, quant_map_buffer_d, quant_enums[qt], m, n, k) + duration_ms = my_op.Profile() + total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype)) + + ke.report(MatrixFpBnb4Metric(func, dtype, duration_ms, total_bytes, m, n, k, qt)) + + +def profile_gemm_func(m, n, k, dtype, func): + np.random.seed(0) + output = np.random.rand(m, n).astype(dtype) + a = np.random.rand(m, k).astype(dtype) + b = np.random.rand(k, n).astype(dtype) + + output_d = ke.DeviceArray(output) + a_d = ke.DeviceArray(a) + b_d = ke.DeviceArray(b) + f = getattr(ke, func) + my_op = f(output_d, a_d, b_d, m, n, k) + duration_ms = my_op.Profile() + total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype)) + + ke.report(MatrixMulMetric(func, dtype, duration_ms, total_bytes, m, n, k)) + + +def profile_with_args(qt, m, n, k, dtype, sort): + with ke.benchmark(sort): + for func in dtype_to_funcs(dtype): + profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func) + + for func in dtype_to_funcs_cublas(dtype): + profile_gemm_func(m, n, k, dtype, func) + + +def profile(): + dims_m = [1] + for qt in quant_types: + for dt in dtypes: + for m in dims_m: + for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): + profile_with_args(qt, m, n, k, dt, False) + print() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + group = parser.add_argument_group("profile with args") + group.add_argument("m", type=int) + group.add_argument("n", type=int) + group.add_argument("k", type=int) + group.add_argument("quant_type", choices=quant_types) + group.add_argument("dtype", choices=dtypes) + group.add_argument("--sort", action="store_true") + + if len(sys.argv) == 1: + profile() + else: + args = parser.parse_args() + profile_with_args(args.quant_type, args.m, args.n, args.k, args.dtype, args.sort) diff --git a/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py new file mode 100644 index 000000000000..951746a08930 --- /dev/null +++ b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py @@ -0,0 +1,240 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import argparse +import logging +import os +from typing import List, Tuple + +import numpy as np +import numpy.typing as npt +import onnx +from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto + +from onnxruntime.capi._pybind_state import quantize_matmul_bnb4 + +from .onnx_model import ONNXModel +from .quant_utils import attribute_to_kwarg + +logger = logging.getLogger(__name__) + + +class MatMulBnb4Quantizer: + """Perform 4b quantization of constant MatMul weights using FP4 or NF4 data type""" + + ################## + # quantization types, must be consistent with native code type + # Bnb_DataType_t defined in blockwise_quant_block_bnb4.h + + # 4b floating point with bias of 3 + FP4 = 0 + + # 4b NormalFloat + NF4 = 1 + + def __init__(self, model: ModelProto, quant_type: int, block_size: int, nodes_to_exclude=None): + nodes_to_exclude = nodes_to_exclude or [] + assert quant_type in [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4] + self.model = ONNXModel(model) + self.quant_type = quant_type + self.block_size = block_size + self.nodes_to_exclude = set(nodes_to_exclude) + + @staticmethod + def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]: + for gid in range(len(graph_path) - 1, -1, -1): + graph = graph_path[gid] + for tensor in graph.initializer: + if tensor.name == name: + return tensor, graph + return None, None + + def bnb4_block_quant(self, fpweight: npt.ArrayLike) -> np.ndarray: + """4b quantize fp32/fp16 weight""" + + if len(fpweight.shape) != 2: + raise ValueError("Current bnb4 block quantization only supports 2D tensors!") + # need to copy since the transposed weight still has the original memory layout + # Linear4bit quantizes its weight data which is the transposed weight + fpweight_t = fpweight.transpose().copy() + + rows, cols = fpweight.shape + numel = rows * cols + block_size = self.block_size + num_blocks = (numel + block_size - 1) // block_size + quantized_numel = (numel + 1) // 2 + + packed = np.zeros(quantized_numel, dtype="uint8") + absmax = np.zeros(num_blocks, dtype=fpweight.dtype) + # block wise quantization, fpweight_t is flattened and divided into blocks + quantize_matmul_bnb4(packed, fpweight_t, absmax, block_size, self.quant_type, cols, rows) + + return (packed, absmax) + + def _bnb4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto: + """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + + if node.op_type != "MatMul": + return node # only care about MatMul for now + + logger.debug(f"start to quantize {node.name} ...") + if node.name in self.nodes_to_exclude: + logger.debug(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") + return node + + inputB = node.input[1] # noqa: N806 + B, Bs_graph = MatMulBnb4Quantizer.__get_initializer(inputB, graph_stack) # noqa: N806 + if B is None: + logger.debug("MatMul doesn't have const weight. Skip to quantize") + return node # only care about constant weight + + B_array = onnx.numpy_helper.to_array(B) # noqa: N806 + if len(B_array.shape) != 2: + logger.debug("MatMul weight is not 2D. Skip to quantize") + return node # can only process 2-D matrix + + packed, absmax = self.bnb4_block_quant(B_array) + B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 + B_quant.name = B.name + "_Bnb4" + for input in Bs_graph.input: + if input.name == inputB: + Bs_graph.input.remove(input) + break + + absmax_tensor = onnx.numpy_helper.from_array(absmax) + absmax_tensor.name = B.name + "_absmax" + + Bs_graph.initializer.extend([B_quant, absmax_tensor]) + + kwargs = {} + rows, cols = B_array.shape + kwargs["K"] = rows + kwargs["N"] = cols + kwargs["block_size"] = self.block_size + kwargs["quant_type"] = self.quant_type + + matmul_bnb4_node = onnx.helper.make_node( + "MatMulBnb4", + inputs=[node.input[0], B_quant.name, absmax_tensor.name], + outputs=[node.output[0]], + name=node.name + "_Bnb4" if node.name else "", + domain="com.microsoft", + **kwargs, + ) + + logger.debug(f"complete quantization of {node.name} ...") + + return matmul_bnb4_node + + def _process_subgraph(self, graph_stack: List[GraphProto]): + new_nodes = [] + graph = graph_stack[-1] + + for node in graph.node: + graph_attrs = [ + attr + for attr in node.attribute + if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS + ] + if len(graph_attrs): + kwargs = {} + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + # recursive call to take care of sub-graph + graph_stack.append(attr.g) + kv = {attr.name: self._process_subgraph(graph_stack)} + elif attr.type == onnx.AttributeProto.GRAPHS: + value = [] + for subgraph in attr.graphs: + # recursive call to take care of sub-graph + graph_stack.append(subgraph) + value.extend([self._process_subgraph(graph_stack)]) + kv = {attr.name: value} + else: + kv = attribute_to_kwarg(attr) + kwargs.update(kv) + node = onnx.helper.make_node( # noqa: PLW2901 + node.op_type, node.input, node.output, name=node.name, **kwargs + ) + + new_nodes.append(self._bnb4_matmul_node_weight(node, graph_stack)) + + graph.ClearField("node") + graph.node.extend(new_nodes) + graph_stack.pop() + return graph + + def process(self): + # use a stack to keep track of sub-graphs + graph_stack = [self.model.graph()] + opset_import = self.model.opset_import() + + has_ms_domain = False + for opset in opset_import: + if opset.domain == "com.microsoft": + has_ms_domain = True + if not has_ms_domain: + opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) + + self._process_subgraph(graph_stack) + self.model.clean_initializers() + + +def parse_args(): + parser = argparse.ArgumentParser( + description="""Blockwise FP4/NF4 quantization for MatMul 2D weight matrices. + +A weight matrix is partitioned into blocks, where each block is a contiguous +subset inside the flattened transposed weight matrix. Each block is quantized +into a set of 4b integers with an absolute value scaling factor. +""" + ) + + parser.add_argument("--input_model", required=True, help="Path to the input model file") + parser.add_argument("--output_model", required=True, help="Path to the output model file") + parser.add_argument( + "--quant_type", + required=False, + default=1, + options=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4], + help="Quantization data type. 0: FP4, 1: NF4", + ) + parser.add_argument( + "--block_size", + required=False, + default=64, + description="Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64", + ) + parser.add_argument("-v", "--verbose", required=False, action="store_true") + parser.set_defaults(verbose=False) + parser.add_argument( + "--nodes_to_exclude", + nargs="+", + type=str, + required=False, + default=[], + help="Specify the nodes to be excluded from quantization with node names", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + if args.verbose: + logger.setLevel(logging.DEBUG) + + input_model_path = args.input_model + output_model_path = args.output_model + + if os.path.exists(output_model_path): + logger.error(f"file {output_model_path} already exists") + raise Exception(f"file {output_model_path} already exists") + + model = onnx.load(input_model_path) + quant = MatMulBnb4Quantizer(model, args.quant_type, args.block_size, nodes_to_exclude=args.nodes_to_exclude) + quant.process() + quant.model.save_model_to_file(output_model_path, True) diff --git a/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc new file mode 100644 index 000000000000..e739b17d5885 --- /dev/null +++ b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef ORT_MINIMAL_BUILD + +#include "core/common/span_utils.h" +#include "core/framework/tensor.h" +#include "core/mlas/inc/mlas_q4.h" +#include "core/mlas/inc/mlas.h" +#include "core/session/inference_session.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/framework/test_utils.h" +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" +#include "core/util/qmath.h" +#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h" + +#include +#include + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +namespace onnxruntime { +namespace test { + +void QuantizeDequantizeBnb4(std::vector& raw_vals, // N X K + std::vector& quant_vals, + std::vector& absmax, + int32_t quant_type, + int32_t N, + int32_t K, + int32_t block_size) { + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + + contrib::QuantizeBlockwiseBnb4( + quant_vals.data(), + raw_vals.data(), + absmax.data(), + block_size, + quant_type, + N, + K, + tp.get()); + + contrib::DequantizeBlockwiseBnb4( + raw_vals.data(), + quant_vals.data(), + absmax.data(), + block_size, + quant_type, + N, + K, + tp.get()); +} + +void RunTest(int64_t quant_type, int64_t M, int64_t N, int64_t K, int64_t block_size, bool use_float16) { + RandomValueGenerator random{1234}; + std::vector input0_vals(random.Gaussian(std::vector({M, K}), 0.0f, 0.25f)); + // quantizer expects transposed weights, N X K + std::vector input1_f_vals(random.Gaussian(std::vector({N, K}), 0.0f, 0.25f)); + + int64_t numel = N * K; + int64_t quantized_numel = (numel + 1) / 2; + int64_t total_block_count = (numel + block_size - 1) / block_size; + std::vector input1_vals(quantized_numel); + std::vector absmax(total_block_count); + + QuantizeDequantizeBnb4(input1_f_vals, + input1_vals, + absmax, + static_cast(quant_type), + static_cast(N), + static_cast(K), + static_cast(block_size)); + + std::vector expected_vals(M * N); + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + sum += input0_vals[m * K + k] * input1_f_vals[n * K + k]; + } + expected_vals[m * N + n] = sum; + } + } + + OpTester test("MatMulBnb4", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("quant_type", quant_type); + if (use_float16) { + test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); + test.AddInput("B", {quantized_numel}, input1_vals, true); + test.AddInput("absmax", {total_block_count}, ToFloat16(absmax), true); + + test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); + test.SetOutputAbsErr("Y", 0.02f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } else { + test.AddInput("A", {M, K}, input0_vals, false); + test.AddInput("B", {quantized_numel}, input1_vals, true); + test.AddInput("absmax", {total_block_count}, absmax, true); + + test.AddOutput("Y", {M, N}, expected_vals); + + test.Run(); + } +} + +TEST(MatMulBnb4, Float32) { + for (auto qt : {0, 1}) { + for (auto M : {1, 2, 100}) { + for (auto N : {1, 2, 32, 288}) { + for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto block_size : {16, 32, 64, 128}) { + RunTest(qt, M, N, K, block_size, false); + } + } + } + } + } +} + +#if defined(USE_CUDA) +TEST(MatMulBnb4, Float16) { + for (auto qt : {0, 1}) { + for (auto M : {1, 2, 100}) { + for (auto N : {1, 2, 32, 288}) { + for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto block_size : {16, 32, 64, 128}) { + RunTest(qt, M, N, K, block_size, true); + } + } + } + } + } +} + +#endif +} // namespace test +} // namespace onnxruntime + +#endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py b/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py new file mode 100644 index 000000000000..88432d75c653 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import tempfile +import unittest +from importlib.util import find_spec +from pathlib import Path +from typing import Dict, Tuple, Union + +import numpy as np +import onnx +from onnx import TensorProto, helper +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count + +from onnxruntime.quantization import quant_utils + +quant_maps = { + 0: [ + 0.00000000, + 5.208333333e-03, + 0.66666667, + 1.00000000, + 0.33333333, + 0.50000000, + 0.16666667, + 0.25000000, + -0.00000000, + -5.208333333e-03, + -0.66666667, + -1.00000000, + -0.33333333, + -0.50000000, + -0.16666667, + -0.25000000, + ], + 1: [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ], +} + + +class TestOpMatMulBnb4(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="test_matmulbnb4.") + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def fill_bnb4_data(self, shape: Tuple[int, int], quant_type: int) -> np.ndarray: + rows, cols = shape + line = np.zeros(shape) + line = line.reshape(-1) + quant_map = np.array(quant_maps[quant_type], dtype=np.float32) + + v = 0 + for i in range(line.shape[0]): + line[i] = quant_map[v] + v += 1 + if v >= 16: + v = 0 + + # bnb quantization quantizes weight.T after flattening + line = line.reshape(cols, rows).transpose() + return line.reshape(shape) + + def input_feeds(self, n: int, name2shape: Dict[str, Union[int, Tuple[int, ...]]]) -> TestDataFeeds: + input_data_list = [] + for _i in range(n): + inputs = {} + for name, shape in name2shape.items(): + inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) + input_data_list.extend([inputs]) + dr = TestDataFeeds(input_data_list) + return dr + + def construct_model_matmul(self, output_model_path: str, quant_type: int) -> None: + # (input) + # | + # MatMul + # | + # (output) + input_name = "input" + output_name = "output" + initializers = [] + + def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str): + weight_data = self.fill_bnb4_data(weight_shape, quant_type).astype(np.float32) + initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) + return onnx.helper.make_node( + "MatMul", + [input_name, weight_name], + [output_name], + ) + + # for this to work (in_features * out_features) % block_size == 0 + in_features = 52 + out_features = 288 + # make MatMul node + matmul_node = make_matmul( + input_name, + [in_features, out_features], + "linear1.weight", + output_name, + ) + + # make graph + input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [-1, in_features]) + output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [-1, out_features]) + graph_name = "matmul_bnb4_test" + graph = helper.make_graph( + [matmul_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + model.ir_version = 7 # use stable onnx ir version + + onnx.save(model, output_model_path) + + def quant_test(self, quant_type: int, block_size: int): + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath(f"matmul_fp32_{quant_type}.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, quant_type) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + + model_bnb4_path = str( + Path(self._tmp_model_dir.name).joinpath(f"MatMulBnb4_{quant_type}_{block_size}.onnx").absolute() + ) + + # Quantize fp32 model to bnb4 model + from onnxruntime.quantization import matmul_bnb4_quantizer + + model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) + quant = matmul_bnb4_quantizer.MatMulBnb4Quantizer(model, quant_type, block_size) + quant.process() + quant.model.save_model_to_file(model_bnb4_path, False) + + quant_nodes = {"MatMulBnb4": 1} + check_op_type_count(self, model_bnb4_path, **quant_nodes) + + data_reader.rewind() + + try: + check_model_correctness(self, model_fp32_path, model_bnb4_path, data_reader.get_next()) + except Exception as exception: + raise exception + + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4" + ) + def test_quantize_matmul_bnb4_fp4(self): + np.random.seed(13) + self.quant_test(0, 64) + + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4" + ) + def test_quantize_matmul_bnb4_nf4(self): + np.random.seed(13) + self.quant_test(1, 64) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py new file mode 100644 index 000000000000..9e9d05fae027 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest +from importlib.util import find_spec + +import numpy as np +import numpy.typing as npt + +quant_enums = {"FP4": 0, "NF4": 1} + + +def quantize_block_fp4(block: npt.ArrayLike): + # quantize a block of float32 values to uint8 by simulating a binary search using pivots + # could have used (block[:,None] - quant_map).argmin(axis=1) but there are some mismatches due to + # floating point precision + # block: 1-D array of normalized [-1,1] float32 values, len(block) % 2 == 0 + + # pivots to find the quantization index + # only half of the pivots are needed since the other half is symmetric + pivots = np.array( + [0.00260417, 0.0859375, 0.20833333, 0.29166667, 0.4166667, 0.583333, 0.8333333, 1], dtype=np.float32 + ) + # indices are not 0,1,2,3,4,5,6,7 because it is a floating point data type + pivot_indices = np.array([0, 1, 6, 7, 4, 5, 2, 3], dtype=np.uint8) + + # signs of the block + signs = (block < 0).astype(np.uint8) * 8 + + # find the uint8 quantization index + # argmax finds the first occurrence of True + quant_indices = pivot_indices[(np.abs(block)[:, None] <= pivots).argmax(axis=1)] + signs + + return np.bitwise_or(np.left_shift(quant_indices[::2], 4), quant_indices[1::2]) + + +def quantize_block_nf4(block: npt.ArrayLike): + pivots = np.array( + [ + -0.8480964004993439, + -0.6106329262256622, + -0.4599952697753906, + -0.33967943489551544, + -0.23460740596055984, + -0.13791173323988914, + -0.045525018125772476, + 0.03979014977812767, + 0.1202552504837513, + 0.2035212516784668, + 0.2920137718319893, + 0.3893125355243683, + 0.5016634166240692, + 0.6427869200706482, + 0.8614784181118011, + 1.0, + ], + dtype=np.float32, + ) + + quant_indices = (block[:, None] <= pivots).argmax(axis=1).astype(np.uint8) + + return np.bitwise_or(np.left_shift(quant_indices[::2], 4), quant_indices[1::2]) + + +def quantize_blockwise_bnb4_ref(matrix_float: npt.ArrayLike, block_size: int, quant_type: str, target=None): + if len(matrix_float.shape) != 2: + raise ValueError("Current bnb4 block quantization only supports 2D tensors!") + + numel = matrix_float.size + num_blocks = (numel + block_size - 1) // block_size + quantized_numel = (numel + 1) // 2 + + packed = np.zeros(quantized_numel, dtype=np.uint8) + absmax = np.zeros(num_blocks, dtype=matrix_float.dtype) + + flattened_matrix_float = matrix_float.flatten() + for block_idx in range(num_blocks): + block_len = min(block_size, numel - block_idx * block_size) + block = np.float32(flattened_matrix_float[block_idx * block_size : block_idx * block_size + block_len]) + + block_absmax = np.max(np.abs(block)) + reciprocal_absmax = 1.0 / block_absmax if block_absmax != 0 else 0.0 + absmax[block_idx] = block_absmax + + if block_len % 2 != 0: + block = np.append(block, 0.0) + block_len += 1 + + block *= reciprocal_absmax + start = block_idx * block_size // 2 + end = start + block_len // 2 + if quant_type == "FP4": + packed[start:end] = quantize_block_fp4(block) + else: + packed[start:end] = quantize_block_nf4(block) + + return (packed, absmax) + + +def quantize_blockwise_bnb4_target(matrix_float: npt.ArrayLike, block_size: int, quant_type: str): + if len(matrix_float.shape) != 2: + raise ValueError("Current int4 block quantization only supports 2D tensors!") + quant_type_enum = quant_enums[quant_type] + + n, k = matrix_float.shape # already transposed + numel = n * k + num_blocks = (numel + block_size - 1) // block_size + quantized_numel = (numel + 1) // 2 + + packed = np.zeros(quantized_numel, dtype="uint8") + absmax = np.zeros(num_blocks, dtype=matrix_float.dtype) + from onnxruntime.capi._pybind_state import quantize_matmul_bnb4 + + quantize_matmul_bnb4(packed, matrix_float, absmax, block_size, quant_type_enum, n, k) + return (packed, absmax) + + +class TestQuantizeBlockwiseBnb4(unittest.TestCase): + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4" + ) + def test_quantize_blockwise_bnb4(self): + for quant_type in ["FP4", "NF4"]: + for k, n in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]: + for block_size in [16, 32, 64, 128]: + for type in [np.float32, np.float16]: + matrix_float = np.random.uniform(-1, 1, (k, n)).astype(type) + quant_value_ref, absmax_ref = quantize_blockwise_bnb4_ref(matrix_float, block_size, quant_type) + quant_value, absmax = quantize_blockwise_bnb4_target(matrix_float, block_size, quant_type) + assert np.allclose(quant_value_ref, quant_value) + assert np.allclose(absmax_ref, absmax) + + +if __name__ == "__main__": + unittest.main() From 538e97cbda5e8c6c7c3a8796ce483da185f2e91c Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Wed, 25 Oct 2023 19:56:16 -0700 Subject: [PATCH 11/36] [DML EP] Add dynamic graph compilation (#17876) Historically, DML was only able to fuse partitions when all sizes are known in advance or when we were overriding them at session creation time. But in practice, it should be possible to compile partitions at compute time if the caller knows that the dimensions won't be changed for every inference (e.g. resizing a webcam window, or padding the input to powers of 2). This graph will be cached and reused until the sizes change. This is an opt-in option gated under the `enable_dynamic_graph_fusion` option, which means that it will only be enabled when the caller requests it since they have more context on how their model will be called between inferences. This PR also adds the option to disable metacommands from the python API, which is an option for the C API but was lacking for python. --- .../core/providers/dml/dml_provider_factory.h | 2 +- .../inc/DmlExecutionProvider.h | 6 +- .../inc/IWinmlExecutionProvider.h | 5 + .../src/AbiCustomRegistry.cpp | 8 +- .../DmlExecutionProvider/src/DmlEdgeShapes.h | 42 ++ .../src/DmlGraphFusionHelper.cpp | 168 +++++++- .../src/DmlGraphFusionHelper.h | 9 + .../src/DmlGraphFusionTransformer.cpp | 63 ++- .../src/DmlRuntimeFusedGraphKernel.cpp | 369 ++++++++++++++++++ .../src/DmlRuntimeFusedGraphKernel.h | 21 + .../src/DmlRuntimeGraphFusionTransformer.cpp | 161 ++++++++ .../src/DmlRuntimeGraphFusionTransformer.h | 42 ++ .../src/ExecutionProvider.cpp | 21 +- .../src/ExecutionProvider.h | 13 +- .../src/GraphDescBuilder.cpp | 67 +++- .../src/GraphDescBuilder.h | 15 +- .../src/GraphPartitioner.cpp | 84 ++-- .../src/GraphPartitioner.h | 4 +- .../src/IExecutionProvider.h | 7 + .../src/MLOperatorAuthorImpl.cpp | 3 +- .../src/MLOperatorAuthorImpl.h | 38 +- .../providers/dml/dml_provider_factory.cc | 84 ++-- .../dml/dml_provider_factory_creator.h | 17 +- onnxruntime/core/session/inference_session.cc | 17 +- .../python/onnxruntime_pybind_schema.cc | 2 +- onnxruntime/test/util/default_providers.cc | 2 +- 26 files changed, 1127 insertions(+), 143 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index dd4ffb835d51..cf3ddc3f125f 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -128,7 +128,7 @@ struct OrtDmlApi { /** * SessionOptionsAppendExecutionProvider_DML2 * Creates a DirectML Execution Provider given the supplied device options that contain a performance preference - * (high power, low power, or defult) and a device filter (None, GPU, or NPU). + * (high power, low power, or default) and a device filter (None, GPU, or NPU). */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts); }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h index 52018500b134..cdb033815756 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h @@ -3,6 +3,9 @@ #pragma once interface IMLOperatorRegistry; +interface IDMLDevice; +interface ID3D12CommandQueue; +interface ID3D12Resource; #include "core/common/status.h" #include "core/framework/data_transfer.h" @@ -28,7 +31,8 @@ namespace Dml std::unique_ptr CreateExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands = true); + bool enableMetacommands, + bool enableDynamicGraphFusion); ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr); void FlushContext(onnxruntime::IExecutionProvider* provider); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 04381b6ce355..074f13b30918 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -7,11 +7,14 @@ #include #include #include +#include #include "core/framework/op_kernel.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" struct AbstractOperatorDesc; interface IMLOperatorTensor; +interface IDMLOperator; struct DML_INPUT_GRAPH_EDGE_DESC; struct DML_OUTPUT_GRAPH_EDGE_DESC; struct DML_INTERMEDIATE_GRAPH_EDGE_DESC; @@ -92,6 +95,8 @@ namespace Windows::AI::MachineLearning::Adapter const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, + const EdgeShapes* inputShapesOverrides, + /*out*/ EdgeShapes* outputShapes, /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo )>; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index ede3e7f2c225..eb068087de4a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -491,6 +491,8 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, + const EdgeShapes* inputShapesOverrides, + /*out*/ EdgeShapes* outputShapes, /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo ) { @@ -498,15 +500,15 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( onnxruntime::OpNodeProtoHelper protoHelper(&nodeContext); // Use the same list of required constant inputs for the shape inferrer and the kernel. - EdgeShapes outputShapes; - InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes); + InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, *outputShapes); // Create the kernel while allowing input shape and output shape queries according to options ComPtr kernelInfoWrapper = wil::MakeOrThrow( &protoHelper, executionHandle, true, - &outputShapes, + inputShapesOverrides, + outputShapes, &defaultAttributesCapture, graphNodeCreateInfo, constantCpuInputCapture, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h new file mode 100644 index 000000000000..5ff70493252b --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace Windows::AI::MachineLearning::Adapter +{ + // edges and unused edges have an empty array of dimensions. + class EdgeShapes + { + public: + EdgeShapes() = default; + + EdgeShapes(size_t count) : m_shapes(count) {} + + const std::vector& GetShape(size_t edgeIndex) const + { + return m_shapes[edgeIndex]; + } + + std::vector& GetMutableShape(size_t edgeIndex) + { + return m_shapes[edgeIndex]; + } + + size_t EdgeCount() const { return m_shapes.size(); } + + void Reset(size_t edge_count) + { + m_shapes.clear(); + m_shapes.resize(edge_count); + } + + bool operator!=(const EdgeShapes& other) const noexcept + { + return (m_shapes != other.m_shapes); + } + + private: + std::vector> m_shapes; + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 51b93efb3a64..cd74e7fa9294 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -1,7 +1,7 @@ #pragma once #include "DmlGraphFusionHelper.h" - +#include "DmlRuntimeFusedGraphKernel.h" namespace Dml { @@ -501,5 +501,171 @@ namespace DmlGraphFusionHelper graph.FinalizeFuseSubGraph(indexedSubGraph, fusedNode); } + + void RegisterDynamicKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable) + { + struct NodeInfo + { + std::string name; + std::string opType; + std::string description; + std::string domain; + onnxruntime::NodeAttributes attributes; + std::vector inputDefPointers; + std::vector outputDefPointers; + }; + + auto partitionNodePropsMap = DmlGraphFusionHelper::CreatePartitionNodePropsMap( + graph, + *indexedSubGraph, + std::move(graphNodePropertyMap)); + + auto modelPath = graph.ModelPath(); + + const gsl::span subGraphInputArgNames = indexedSubGraph->GetMetaDef()->inputs; + const gsl::span subGraphOutputArgNames = indexedSubGraph->GetMetaDef()->outputs; + + std::vector nodesInfo; + nodesInfo.reserve(indexedSubGraph->nodes.size()); + + std::vector subgraphInputs; + subgraphInputs.reserve(subGraphInputArgNames.size()); + + std::vector subgraphOutputs; + subgraphOutputs.reserve(subGraphOutputArgNames.size()); + + std::vector nodeAttributes; + nodeAttributes.reserve(indexedSubGraph->nodes.size()); + + std::vector> intermediateNodeArgs; + + for (size_t sortedNodeIndex : indexedSubGraph->nodes) + { + auto node = graph.GetNode(sortedNodeIndex); + + nodeAttributes.push_back(node->GetAttributes()); + + NodeInfo nodeInfo{}; + nodeInfo.name = node->Name(); + nodeInfo.opType = node->OpType(); + nodeInfo.description = node->Description(); + nodeInfo.domain = node->Domain(); + nodeInfo.attributes = node->GetAttributes(); + nodeInfo.inputDefPointers.reserve(node->InputDefs().size()); + nodeInfo.outputDefPointers.reserve(node->OutputDefs().size()); + + for (const onnxruntime::NodeArg* inputDef : node->InputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(inputDef->Name(), inputDef->TypeAsProto())); + nodeInfo.inputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + for (const onnxruntime::NodeArg* outputDef : node->OutputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(outputDef->Name(), outputDef->TypeAsProto())); + nodeInfo.outputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + nodesInfo.push_back(std::move(nodeInfo)); + } + + for (const std::string& graphInputName : subGraphInputArgNames) + { + subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); + } + + for (const std::string& graphOutputName : subGraphOutputArgNames) + { + subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); + } + + // We need to keep the initializers alive since they will be freed once the nodes are removed from the graph + std::vector ownedInitializers; + ownedInitializers.reserve(isInitializerTransferable.size()); + + for (auto& kvp : isInitializerTransferable) + { + ONNX_NAMESPACE::TensorProto tensorProto; + tensorProto.set_data_type(kvp.second.first->data_type()); + tensorProto.set_raw_data(kvp.second.first->raw_data()); + tensorProto.set_name(kvp.second.first->name()); + + for (int i = 0; i < kvp.second.first->dims_size(); ++i) + { + tensorProto.add_dims(kvp.second.first->dims(i)); + } + ownedInitializers.push_back(std::move(tensorProto)); + kvp.second.first = &ownedInitializers.back(); + } + + // lamda captures for the kernel registration + auto fused_kernel_func = [ + indexedSubGraph, + &modelPath, + nodesInfo = std::move(nodesInfo), + intermediateNodeArgs = std::move(intermediateNodeArgs), + subgraphInputs = std::move(subgraphInputs), + subgraphOutputs = std::move(subgraphOutputs), + partitionNodePropsMap = std::move(partitionNodePropsMap), + ownedInitializers = std::move(ownedInitializers)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status + { + std::vector> subgraphNodes; + subgraphNodes.reserve(nodesInfo.size()); + + for (const NodeInfo& nodeInfo : nodesInfo) + { + subgraphNodes.emplace_back(std::make_shared( + nodeInfo.name, + nodeInfo.opType, + nodeInfo.description, + nodeInfo.inputDefPointers, + nodeInfo.outputDefPointers, + &nodeInfo.attributes, + nodeInfo.domain)); + } + + out.reset(CreateRuntimeFusedGraphKernel( + info, + indexedSubGraph, + modelPath, + std::move(subgraphNodes), + std::move(subgraphInputs), + std::move(subgraphOutputs), + std::move(intermediateNodeArgs), + std::move(partitionNodePropsMap), + std::move(ownedInitializers))); + return Status::OK(); + }; + + // build the kernel definition on the fly, and register it to the fused_kernel_regisitry. + onnxruntime::KernelDefBuilder builder; + builder.SetName(indexedSubGraph->GetMetaDef()->name) + .SetDomain(indexedSubGraph->GetMetaDef()->domain) + .SinceVersion(indexedSubGraph->GetMetaDef()->since_version) + .Provider(onnxruntime::kDmlExecutionProvider); + + // Force the CPU inputs to be allocated on the CPU + for (int i = 0; i < subGraphInputArgNames.size(); ++i) + { + if (dynamicCpuInputMap.find(subGraphInputArgNames[i]) != dynamicCpuInputMap.end()) + { + builder.InputMemoryType(OrtMemTypeCPUInput, i); + } + } + + ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func)); + + auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name); + fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); + + graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode); + } } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index 030cffc2a879..f8f6162aaa1e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -80,5 +80,14 @@ namespace DmlGraphFusionHelper std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, Microsoft::WRL::ComPtr compiledExecutionPlanOperator); + + void RegisterDynamicKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 4813707cdf50..679738b639ec 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -15,6 +15,18 @@ namespace Dml { + namespace + { + struct CompiledPartitionInfo + { + Microsoft::WRL::ComPtr compiledOperator; + onnxruntime::IndexedSubGraph indexedSubGraph; + std::vector isInputsUploadedByDmlEP; + GraphDescBuilder::GraphDesc graphDesc; + std::unordered_map> isInitializerTransferable; + }; + } + DmlGraphFusionTransformer::DmlGraphFusionTransformer( const std::string& name, const onnxruntime::IExecutionProvider* provider @@ -24,15 +36,6 @@ namespace Dml { } - struct CompiledPartitionInfo - { - Microsoft::WRL::ComPtr compiledOperator; - onnxruntime::IndexedSubGraph indexedSubGraph; - std::vector isInputsUploadedByDmlEP; - GraphDescBuilder::GraphDesc graphDesc; - std::unordered_map> isInitializerTransferable; - }; - onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImpl( onnxruntime::Graph& graph, bool& modified, @@ -87,6 +90,7 @@ namespace Dml { // Initializers needed by any graph partition std::unordered_set requiredInitializerMap; + std::unordered_set dynamicCpuInputMap; std::unordered_map graphNodePropertyMap; onnxruntime::GraphViewer graphViewer(graph); std::vector> partitions = BuildPartitions( @@ -96,8 +100,10 @@ namespace Dml m_providerImpl->GetSupportedDeviceDataTypeMask(), graphNodePropertyMap, requiredInitializerMap, + dynamicCpuInputMap, additionalSplittingNodes, - implicitInputDefs); + implicitInputDefs, + false); // Reset the splitting nodes for the current iteration additionalSplittingNodes.clear(); @@ -190,17 +196,48 @@ namespace Dml std::move(graphNodePropertyMap)); // Convert partitionONNXGraph into DML EP GraphDesc + auto modelPath = graph.ModelPath(); + + const gsl::span subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs; + const gsl::span subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs; + + std::vector subgraphNodes; + subgraphNodes.reserve(indexedSubGraph.nodes.size()); + + std::vector subgraphInputs; + subgraphInputs.reserve(subGraphInputArgNames.size()); + + std::vector subgraphOutputs; + subgraphOutputs.reserve(subGraphOutputArgNames.size()); + + for (size_t sortedNodeIndex : indexedSubGraph.nodes) + { + subgraphNodes.push_back(graph.GetNode(sortedNodeIndex)); + } + + for (const std::string& graphInputName : subGraphInputArgNames) + { + subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); + } + + for (const std::string& graphOutputName : subGraphOutputArgNames) + { + subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); + } + ComPtr device; ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf())); GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), isInitializerTransferable, - graph, - indexedSubGraph, partitionNodePropsMap, device.Get(), - m_providerImpl); + m_providerImpl, + modelPath, + subgraphNodes, + subgraphInputs, + subgraphOutputs); // Compile the operator auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp new file mode 100644 index 000000000000..1db22ac92e52 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -0,0 +1,369 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h" + +using namespace Windows::AI::MachineLearning::Adapter; + +namespace Dml +{ + class DmlRuntimeFusedGraphKernel : public onnxruntime::OpKernel + { + public: + DmlRuntimeFusedGraphKernel() = delete; + + DmlRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& kernelInfo, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers) + : OpKernel(kernelInfo), + m_indexedSubGraph(std::move(indexedSubGraph)), + m_modelPath(modelPath), + m_subgraphNodes(std::move(subgraphNodes)), + m_subgraphInputs(std::move(subgraphInputs)), + m_subgraphOutputs(std::move(subgraphOutputs)), + m_intermediateNodeArgs(std::move(intermediateNodeArgs)), + m_partitionNodePropsMap(std::move(partitionNodePropsMap)), + m_ownedInitializers(std::move(ownedInitializers)) + { + for (const auto& initializer : m_ownedInitializers) + { + m_isInitializerTransferable[initializer.name()] = std::make_pair(&initializer, false); + } + + // Get the execution provider interfaces + auto executionHandle = kernelInfo.GetExecutionProvider()->GetExecutionHandle(); + if (executionHandle) + { + // We assume the execution object inherits IUnknown as its first base + ComPtr providerExecutionObject = const_cast(static_cast(executionHandle)); + + // Get the WinML-specific execution provider interface from the execution object. + ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_provider)); + ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider)); + } + + m_subgraphNodePointers.reserve(m_subgraphNodes.size()); + + for (auto& subgraphNode : m_subgraphNodes) + { + m_subgraphNodePointers.push_back(subgraphNode.get()); + } + } + + void TranslateAndCompileGraph( + const onnxruntime::OpKernelInfo& kernelInfo, + std::vector>& initializeResourceRefs, + std::vector initInputBindings) const + { + // Allocate a persistent resource and initialize the operator + UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize; + if (persistentResourceSize > 0) + { + ORT_THROW_IF_FAILED(m_provider->AllocatePooledResource( + static_cast(persistentResourceSize), + AllocatorRoundingMode::Disabled, + m_persistentResource.ReleaseAndGetAddressOf(), + m_persistentResourceAllocatorUnk.ReleaseAndGetAddressOf())); + + m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize }; + } + + ORT_THROW_IF_FAILED(m_provider->InitializeOperator( + m_compiledExecutionPlanOperator.Get(), + m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, + gsl::make_span(initInputBindings))); + + std::for_each( + initializeResourceRefs.begin(), + initializeResourceRefs.end(), + [&](ComPtr& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); } + ); + } + + onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override + { + ORT_THROW_HR_IF(E_UNEXPECTED, m_subgraphInputs.size() != kernelContext->InputCount()); + + bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr; + + for (int inputIndex = 0; inputIndex < kernelContext->InputCount(); ++inputIndex) + { + const auto& input = kernelContext->RequiredInput(inputIndex); + const std::string& inputName = m_subgraphInputs[inputIndex]->Name(); + auto shapeIter = m_inferredInputShapes.find(inputName); + + if (shapeIter == m_inferredInputShapes.end()) + { + m_inferredInputShapes[inputName] = input.Shape(); + recompileNeeded = true; + } + else if (shapeIter->second != input.Shape()) + { + shapeIter->second = input.Shape(); + recompileNeeded = true; + } + + // If we have CPU inputs that are not initializers (i.e. they were computed at runtime), add them to the initializer list + if (input.Location().device.Type() == OrtDevice::CPU) + { + auto inputProto = onnxruntime::utils::TensorToTensorProto(input, inputName); + + // We can only avoid recompiling the graph when all CPU inputs are identical + auto initializerIter = m_isInitializerTransferable.find(inputName); + + if (initializerIter != m_isInitializerTransferable.end()) + { + if (initializerIter->second.first->raw_data().length() == inputProto.raw_data().length()) + { + for (int i = 0; i < inputProto.raw_data().length(); ++i) + { + if (initializerIter->second.first->raw_data()[i] != inputProto.raw_data()[i]) + { + recompileNeeded = true; + break; + } + } + } + else + { + recompileNeeded = true; + } + } + else + { + recompileNeeded = true; + } + + m_ownedCpuInputs.push_back(std::make_unique(std::move(inputProto))); + m_isInitializerTransferable[inputName] = std::make_pair(m_ownedCpuInputs.back().get(), false); + } + } + + if (recompileNeeded) + { + // Go through all the node args and replace their shapes with the real ones + for (auto& nodeArg : m_intermediateNodeArgs) + { + auto iter = m_inferredInputShapes.find(nodeArg->Name()); + if (iter != m_inferredInputShapes.end()) + { + auto tensorShape = *nodeArg->Shape(); + ORT_THROW_HR_IF(E_UNEXPECTED, tensorShape.dim_size() != iter->second.NumDimensions()); + + for (int i = 0; i < tensorShape.dim_size(); ++i) + { + tensorShape.mutable_dim(i)->set_dim_value(iter->second.GetDims()[i]); + } + + nodeArg->SetShape(tensorShape); + } + } + + // Populate input bindings for operator initialization + const uint32_t fusedNodeInputCount = gsl::narrow_cast(m_indexedSubGraph->GetMetaDef()->inputs.size()); + std::vector> initializeResourceRefs; // For lifetime control + std::vector initInputBindings(fusedNodeInputCount); + std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); + auto providerImpl = static_cast(Info().GetExecutionProvider())->GetImpl(); + + // Convert partitionONNXGraph into DML EP GraphDesc + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( + isInputsUploadedByDmlEP.data(), + isInputsUploadedByDmlEP.size(), + m_isInitializerTransferable, + m_partitionNodePropsMap, + device.Get(), + providerImpl, + m_modelPath, + m_subgraphNodePointers, + m_subgraphInputs, + m_subgraphOutputs); + + m_outputShapes = graphDesc.outputShapes; + + // Walk through each graph edge and mark used inputs + m_inputsUsed.resize(fusedNodeInputCount, false); + for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) + { + m_inputsUsed[edge.GraphInputIndex] = true; + } + + // Compile the operator + m_compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( + graphDesc, + *m_indexedSubGraph, + providerImpl); + + // Queue references to objects which must be kept alive until resulting GPU work completes + m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); + + TranslateAndCompileGraph( + Info(), + initializeResourceRefs, + initInputBindings); + } + + // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator + OpKernelContextWrapper contextWrapper( + kernelContext, + Info().GetExecutionProvider(), + true, + nullptr); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + // Get input resources for execution, excluding those which were specified as owned by DML and provided + // at initialization instead. + std::vector> inputTensors(kernelContext->InputCount()); + std::vector inputPtrs(kernelContext->InputCount()); + + for (int i = 0; i < kernelContext->InputCount(); ++i) + { + if (!m_inputsUsed[i]) + { + continue; + } + + ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf())); + inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get()); + } + + auto outputTensors = contextWrapper.GetOutputTensors(m_outputShapes); + ExecuteOperator( + m_compiledExecutionPlanOperator.Get(), + m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, + inputPtrs, + outputTensors); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + return onnxruntime::Status::OK(); + } + + void ExecuteOperator( + IDMLCompiledOperator* op, + _In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding, + gsl::span inputTensors, + gsl::span outputTensors) const + { + auto FillBindingsFromTensors = [this](auto& bufferBindings, auto& bindingDescs, gsl::span& tensors) + { + for (IMLOperatorTensor* tensor : tensors) + { + if (tensor) + { + assert(tensor->IsDataInterface()); + ID3D12Resource* resource = m_provider->DecodeResource(MLOperatorTensor(tensor).GetDataInterface().Get()); + D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); + bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); + } + else + { + bufferBindings.push_back({ nullptr, 0, 0 }); + bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + } + } + }; + + auto FillBindingsFromBuffers = [](auto& bufferBindings, auto& bindingDescs, gsl::span& resources) + { + for (ID3D12Resource* resource : resources) + { + if (resource) + { + D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); + bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); + } + else + { + bufferBindings.push_back({ nullptr, 0, 0 }); + bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + } + } + }; + + std::vector inputBufferBindings; + inputBufferBindings.reserve(inputTensors.size()); + std::vector inputBindings; + inputBindings.reserve(inputTensors.size()); + FillBindingsFromBuffers(inputBufferBindings, inputBindings, inputTensors); + + std::vector outputBufferBindings; + outputBufferBindings.reserve(outputTensors.size()); + std::vector outputBindings; + outputBindings.reserve(outputTensors.size()); + FillBindingsFromTensors(outputBufferBindings, outputBindings, outputTensors); + + ORT_THROW_IF_FAILED(m_provider->ExecuteOperator( + op, + persistentResourceBinding, + inputBindings, + outputBindings)); + } + + private: + ComPtr m_winmlProvider; + ComPtr m_provider; + + mutable std::optional m_persistentResourceBinding; + std::shared_ptr m_indexedSubGraph; + const onnxruntime::Path& m_modelPath; + + std::vector> m_subgraphNodes; + std::vector m_subgraphInputs; + std::vector m_subgraphOutputs; + mutable std::vector> m_intermediateNodeArgs; + std::unordered_map m_partitionNodePropsMap; + std::vector m_ownedInitializers; + mutable std::unordered_map> m_isInitializerTransferable; + std::vector m_subgraphNodePointers; + + // Bindings from previous executions of a re-used command list + mutable std::vector> m_ownedCpuInputs; + mutable ComPtr m_compiledExecutionPlanOperator; + mutable std::vector m_inputsUsed; + mutable ComPtr m_persistentResource; + mutable ComPtr m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator + mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes; + mutable std::unordered_map m_inferredInputShapes; + }; + + onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& info, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers) + { + return new DmlRuntimeFusedGraphKernel( + info, + std::move(indexedSubGraph), + modelPath, + std::move(subgraphNodes), + std::move(subgraphInputs), + std::move(subgraphOutputs), + std::move(intermediateNodeArgs), + std::move(partitionNodePropsMap), + std::move(ownedInitializers) + ); + } +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h new file mode 100644 index 000000000000..d679c5aa5667 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/op_kernel.h" +#include "GraphDescBuilder.h" +#include "DmlRuntimeGraphFusionTransformer.h" + +namespace Dml +{ + onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& info, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers + ); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp new file mode 100644 index 000000000000..6318b0d5e286 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -0,0 +1,161 @@ +#pragma once + +#include "precomp.h" +#include "GraphDescBuilder.h" +#include "ExecutionProvider.h" +#include "DmlRuntimeGraphFusionTransformer.h" +#include "GraphPartitioner.h" +#include "core/framework/kernel_type_str_resolver.h" +#include "core/framework/kernel_lookup.h" +#include "core/optimizer/constant_sharing.h" +#include "DmlRuntimeFusedGraphKernel.h" +#include "MLOperatorAuthorImpl.h" +#include "DmlGraphFusionHelper.h" + +namespace Dml +{ + namespace + { + struct CompiledPartitionInfo + { + std::shared_ptr indexedSubGraph; + std::unordered_map> isInitializerTransferable; + }; + } + + DmlRuntimeGraphFusionTransformer::DmlRuntimeGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ) + :onnxruntime::GraphTransformer(name), + m_providerImpl(static_cast(provider)->GetImpl()) + { + } + + onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImpl( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger) const + { + return ApplyImplHelper(graph, modified, graphLevel, logger, {}); + } + + onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const + { + onnxruntime::ProviderType providerType = onnxruntime::kDmlExecutionProvider; + const gsl::not_null registry = m_providerImpl->GetKernelRegistry().get(); + const auto kernelTypeStrResolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; + const auto kernelLookup = onnxruntime::KernelLookup( + providerType, + gsl::make_span(®istry, 1), + kernelTypeStrResolver); + + onnxruntime::GraphViewer graphViewer(graph); + const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder(); + + for (auto nodeIndex : nodeTopologyList) + { + auto* node = graph.GetNode(nodeIndex); + if (!node) + { + continue; // node was removed + } + + std::unordered_map subgraphImplicitInputDefs; + for (const onnxruntime::NodeArg* inputDef : node->ImplicitInputDefs()) + { + subgraphImplicitInputDefs[inputDef->Name()] = inputDef; + } + + for (auto& entry : node->GetAttributeNameToMutableSubgraphMap()) + { + auto& subgraph = *entry.second; + ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graphLevel + 1, logger, subgraphImplicitInputDefs)); + } + } + + // Initializers needed by any graph partition + std::vector additionalSplittingNodes; + std::unordered_map graphNodePropertyMap; + std::unordered_set requiredInitializerMap; + std::unordered_set dynamicCpuInputMap; + std::vector> partitions = BuildPartitions( + graphViewer, + *m_providerImpl->GetInternalRegistrationInfoMap(), + kernelLookup, + m_providerImpl->GetSupportedDeviceDataTypeMask(), + graphNodePropertyMap, + requiredInitializerMap, + dynamicCpuInputMap, + additionalSplittingNodes, + implicitInputDefs, + true); + + // Reset the splitting nodes for the current iteration + additionalSplittingNodes.clear(); + + // Reset the compiled operators for the current iteration + std::vector> compiledPartitionInfos(partitions.size()); + + // Create a map between each initialized tensor and the partition(s) it is part of. + auto initializerPartitionMap = DmlGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions); + + for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex) + { + auto& partition = partitions[partitionIndex]; + + if (partition->GetRootMergedPartition() != partition.get() || + !partition->IsDmlPartition()) + { + continue; + } + + if (partition->IsDmlGraphPartition()) + { + std::unordered_map> isInitializerTransferable; + + std::string partitionKernelPrefix = std::to_string(m_providerImpl->GetPartitionKernelPrefixVal()) + "_"; + m_providerImpl->IncreasePartitionKernelPrefixVal(); + + // populate isInitializerTransferable + for (const auto& input : partition->GetInputs()) + { + const onnx::TensorProto* tensor = nullptr; + if (graph.GetInitializedTensor(input, tensor) && requiredInitializerMap.find(input) != requiredInitializerMap.end()) + { + isInitializerTransferable[input] = {tensor, false}; + } + } + + compiledPartitionInfos[partitionIndex] = std::make_shared(); + compiledPartitionInfos[partitionIndex]->indexedSubGraph = std::make_shared( + DmlGraphFusionHelper::CreateIndexedSubGraph(partition.get(), partitionIndex, partitionKernelPrefix)); + compiledPartitionInfos[partitionIndex]->isInitializerTransferable = std::move(isInitializerTransferable); + } + } + + for (auto&& compiledPartitionInfo : compiledPartitionInfos) + { + // Null compiled operators were not DML partitions + if (compiledPartitionInfo) + { + DmlGraphFusionHelper::RegisterDynamicKernel( + graph, + m_providerImpl->GetKernelRegistry().get(), + m_providerImpl, + graphNodePropertyMap, + dynamicCpuInputMap, + std::move(compiledPartitionInfo->indexedSubGraph), + std::move(compiledPartitionInfo->isInitializerTransferable)); + } + } + + return onnxruntime::common::Status::OK(); + } +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h new file mode 100644 index 000000000000..cfa743e1f2b8 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include "core/optimizer/graph_transformer.h" +#include "core/framework/execution_providers.h" + +namespace Dml +{ +class ExecutionProviderImpl; + +class DmlRuntimeGraphFusionTransformer : public onnxruntime::GraphTransformer +{ +public: + DmlRuntimeGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ); + +public: + static inline const char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlRuntimeFusedNode_"; + static inline const char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlRuntimeFusedNodeDomain"; + +private: + onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger) const final; + + onnxruntime::common::Status ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const; + +private: + const ExecutionProviderImpl* m_providerImpl = nullptr; +}; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 5f6bd178aaa1..8644b8d56a42 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -67,7 +67,8 @@ namespace Dml ExecutionProvider::ExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands) : + bool enableMetacommands, + bool enableDynamicGraphFusion) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)) { D3D12_COMMAND_LIST_TYPE queueType = commandQueue->GetDesc().Type; @@ -80,7 +81,7 @@ namespace Dml ComPtr device; GRAPHICS_THROW_IF_FAILED(commandQueue->GetDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf()))); - m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), commandQueue, enableMetacommands); + m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), commandQueue, enableMetacommands, enableDynamicGraphFusion); } std::vector> @@ -147,12 +148,12 @@ namespace Dml // Task 24384515: Update ORT AIInfra release agent pool to install 19H1 SDK on VM bootstrap #define D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE ((D3D_FEATURE_LEVEL)0x1000) - ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands) + ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands, bool enableDynamicGraphFusion) : m_d3d12Device(d3d12Device), m_dmlDevice(dmlDevice), - m_areMetacommandsEnabled(enableMetacommands) + m_areMetacommandsEnabled(enableMetacommands), + m_dynamicGraphFusionEnabled(enableDynamicGraphFusion) { - D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {}; D3D_FEATURE_LEVEL featureLevelsList[] = { @@ -1093,6 +1094,11 @@ namespace Dml return m_areMetacommandsEnabled; } + bool ExecutionProviderImpl::DynamicGraphFusionEnabled() const noexcept + { + return m_dynamicGraphFusionEnabled; + } + std::shared_ptr ExecutionProviderImpl::GetInternalRegistrationInfoMap() const { @@ -1129,9 +1135,10 @@ namespace Dml std::unique_ptr CreateExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands) + bool enableMetacommands, + bool enableDynamicGraphFusion) { - return std::make_unique(dmlDevice, commandQueue, enableMetacommands); + return std::make_unique(dmlDevice, commandQueue, enableMetacommands, enableDynamicGraphFusion); } ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 31b893a2f25d..3aaa11cdee47 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -5,6 +5,7 @@ #include "GraphTransformer.h" #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" +#include "core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h" #include #include @@ -34,7 +35,8 @@ namespace Dml IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, - bool enableMetacommands = true); + bool enableMetacommands, + bool enableDynamicGraphFusion); void ReleaseCompletedReferences(); @@ -150,6 +152,7 @@ namespace Dml STDMETHOD_(bool, IsMcdmDevice)() const noexcept final; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final; + bool DynamicGraphFusionEnabled() const noexcept; std::shared_ptr GetGpuAllocator(); std::shared_ptr GetCpuInputAllocator(); @@ -184,6 +187,7 @@ namespace Dml ComPtr m_dmlDevice; bool m_isMcdmDevice = false; bool m_areMetacommandsEnabled = true; + bool m_dynamicGraphFusionEnabled = false; bool m_native16BitShaderOpsSupported = false; std::shared_ptr m_context; std::unique_ptr m_uploadHeap; @@ -236,7 +240,8 @@ namespace Dml explicit ExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands = true + bool enableMetacommands, + bool enableDynamicGraphFusion ); std::unique_ptr GetDataTransfer() const final override @@ -299,9 +304,9 @@ namespace Dml return m_impl.Get(); } - void MetacommandsEnabled() + bool DynamicGraphFusionEnabled() const { - m_impl->MetacommandsEnabled(); + return m_impl->DynamicGraphFusionEnabled(); } virtual std::vector CreatePreferredAllocators() override diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 636f46428ce9..3fc8f415e5a5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -147,14 +147,14 @@ namespace Dml::GraphDescBuilder const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle) + const void* executionHandle, + const onnxruntime::Path& modelPath, + gsl::span subgraphNodes, + gsl::span subgraphInputs, + gsl::span subgraphOutputs) { - const gsl::span subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs; - const gsl::span subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs; struct NodeAndIndex { uint32_t nodeIndex; // The index of the node itself @@ -164,12 +164,14 @@ namespace Dml::GraphDescBuilder // Map from Lotus node argument names to the new node and index where it will be produced std::unordered_map nameToNodeAndIndexMap; + std::unordered_map nodeOutputShapes; + // Map from Lotus node argument names to input indices of the fused kernel node. std::unordered_map nameToDmlFusedNodeInputIndex; - for (size_t inputIndex = 0; inputIndex < subGraphInputArgNames.size(); ++inputIndex) + for (size_t inputIndex = 0; inputIndex < subgraphInputs.size(); ++inputIndex) { - const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(subGraphInputArgNames[inputIndex]); + const onnxruntime::NodeArg* graphInput = subgraphInputs[inputIndex]; if (!graphInput) { @@ -196,13 +198,11 @@ namespace Dml::GraphDescBuilder const uint32_t minNodeCountToReuseCommandList = 5; bool reuseCommandList = false; - if (indexedSubGraph.nodes.size() >= minNodeCountToReuseCommandList) + if (subgraphNodes.size() >= minNodeCountToReuseCommandList) { reuseCommandList = true; } - auto modelPath = graph.ModelPath(); - auto constantCpuGraphInputGetter = [&isInitializerTransferable, &modelPath](const std::string& argName) { ComPtr tensorWrapper; @@ -219,9 +219,11 @@ namespace Dml::GraphDescBuilder // Iterate through each node and create a corresponding node in the new graph // We can iterate the nodes in any order because the edge connectivity will take care of the topological order - for (size_t sortedNodeIndex : indexedSubGraph.nodes) + std::unordered_map> inferredOutputShapes; + + for (const onnxruntime::Node* subgraphNode : subgraphNodes) { - const onnxruntime::Node& node = *graph.GetNode(sortedNodeIndex); + const onnxruntime::Node& node = *subgraphNode; const GraphNodeProperties& graphNodeProps = graphNodePropertyMap.find(GetUniqueNodeName(node))->second; const auto& requiredConstantCpuInputs = graphNodeProps.internalRegInfo->requiredConstantCpuInputs; @@ -244,14 +246,45 @@ namespace Dml::GraphDescBuilder return tensor; }; + EdgeShapes inputShapesOverrides(node.InputDefs().size()); + + // Override the input shapes with shapes that were previously inferred + for (int inputIndex = 0; inputIndex < node.InputDefs().size(); ++inputIndex) + { + auto inputDef = node.InputDefs()[inputIndex]; + + auto outputShapesIter = inferredOutputShapes.find(inputDef->Name()); + if (outputShapesIter != inferredOutputShapes.end()) + { + inputShapesOverrides.GetMutableShape(inputIndex) = outputShapesIter->second; + } + else if (inputDef->HasTensorOrScalarShape()) + { + for (int i = 0; i < inputDef->Shape()->dim_size(); ++i) + { + ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->Shape()->dim(i).has_dim_value()); + inputShapesOverrides.GetMutableShape(inputIndex).push_back(gsl::narrow_cast(inputDef->Shape()->dim(i).dim_value())); + } + } + } + + EdgeShapes outputShapes; DmlGraphNodeCreateInfo graphNodeCreateInfo; graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory( node, constantCpuNodeInputGetter, executionHandle, + &inputShapesOverrides, + /*out*/ &outputShapes, /*out*/ &graphNodeCreateInfo ); + ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size()); + for (int i = 0; i < node.OutputDefs().size(); ++i) + { + inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i); + } + // Create a map between operatorGraphNodeIndex to mainGraphNodeIndex. std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); @@ -347,6 +380,8 @@ namespace Dml::GraphDescBuilder operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex], operatorGraphOutputEdge.FromNodeOutputIndex }; + + nodeOutputShapes[arg->Name()] = outputShapes; } } @@ -367,10 +402,12 @@ namespace Dml::GraphDescBuilder } } + EdgeShapes graphOutputShapes(subgraphOutputs.size()); + // Add graph output nodes, which might be in a different order from the encapsulating node - for (size_t outputIndex = 0; outputIndex < subGraphOutputArgNames.size(); ++outputIndex) + for (size_t outputIndex = 0; outputIndex < subgraphOutputs.size(); ++outputIndex) { - const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(subGraphOutputArgNames[outputIndex]); + const onnxruntime::NodeArg* graphOutput = subgraphOutputs[outputIndex]; ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg"); const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); @@ -380,6 +417,7 @@ namespace Dml::GraphDescBuilder edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex; edge.GraphOutputIndex = gsl::narrow_cast(outputIndex); graphOutputEdges.push_back(edge); + graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex); } RemoveUnconnectedNodes(graphNodes, graphInputEdges, graphIntermediateEdges, graphOutputEdges); @@ -390,6 +428,7 @@ namespace Dml::GraphDescBuilder graphDesc.outputEdges = std::move(graphOutputEdges); graphDesc.intermediateEdges = std::move(graphIntermediateEdges); graphDesc.reuseCommandList = reuseCommandList; + graphDesc.outputShapes = std::move(graphOutputShapes); return graphDesc; } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 5c04962e5555..0039678c00e5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -9,10 +9,10 @@ namespace Dml { struct GraphNodeProperties { - std::shared_ptr + std::shared_ptr internalRegInfo; - // These are currently passed from the partitioning step since the only DML operators current + // These are currently passed from the partitioning step since the only DML operators current // supporting graph nodes don't customize the order of edges or shapes, other than coercing // dimension count. This will change as the supported set of operators as graph nodes increases. Windows::AI::MachineLearning::Adapter::EdgeShapes inputShapes; @@ -38,16 +38,19 @@ namespace Dml std::vector outputEdges; std::vector intermediateEdges; bool reuseCommandList; + Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes; }; GraphDesc BuildGraphDesc( const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle); + const void* executionHandle, + const onnxruntime::Path& modelPath, + gsl::span subgraphNodes, + gsl::span subgraphInputs, + gsl::span subgraphOutputs); } -} \ No newline at end of file +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 18943878cced..f7a4743801d8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -151,6 +151,8 @@ namespace Dml _In_opt_ const std::unordered_map* nodeNameToPartitionMap, _Inout_ std::unordered_map& dmlNodePropertyMap, _Inout_ std::unordered_set& requiredInitializerMap, + _Inout_ std::unordered_set& dynamicCpuInputMap, + bool allowDmlGraphDynamicShapes, _Out_ bool* isDmlGraphNode ) { @@ -172,36 +174,68 @@ namespace Dml if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration) { - bool requiredCpuInputsConstant = true; - for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) + if (allowDmlGraphDynamicShapes) { - if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) { - continue; - } + if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + { + continue; + } - const onnx::TensorProto* tensor = nullptr; - const std::string& inputName = node.InputDefs()[inputIndex]->Name(); + const onnx::TensorProto* tensor = nullptr; + const std::string& inputName = node.InputDefs()[inputIndex]->Name(); - if (!graph.GetInitializedTensor(inputName, tensor)) - { - requiredCpuInputsConstant = false; - break; + if (graph.GetInitializedTensor(inputName, tensor)) + { + requiredInitializerMap.insert(inputName); + } + else + { + dynamicCpuInputMap.insert(inputName); + } } - requiredInitializerMap.insert(inputName); + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; + if (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size()) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } } - - std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; - if (requiredCpuInputsConstant && - TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && - TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && - (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + else { - *isDmlGraphNode = true; - graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + bool requiredCpuInputsConstant = true; + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) + { + if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + { + continue; + } + + const onnx::TensorProto* tensor = nullptr; + const std::string& inputName = node.InputDefs()[inputIndex]->Name(); + + if (!graph.GetInitializedTensor(inputName, tensor)) + { + requiredCpuInputsConstant = false; + break; + } + + requiredInitializerMap.insert(inputName); + } + + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; + if (requiredCpuInputsConstant && + TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && + TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && + (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } } } } @@ -379,8 +413,10 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, + std::unordered_set& dynamicCpuInputMap, gsl::span additionalSplittingNodes, - const std::unordered_map& implicitInputs) + const std::unordered_map& implicitInputs, + bool allowDmlGraphDynamicShapes) { // Nodes are uniquely identified by the name of their first output argument std::vector> partitions; @@ -443,6 +479,8 @@ namespace Dml &nodeNameToPartitionMap, graphNodePropertyMap, requiredInitializerMap, + dynamicCpuInputMap, + allowDmlGraphDynamicShapes, /*out*/ &isDmlGraphNode ); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index 37d577f647fb..3bddb5ae1608 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -50,6 +50,8 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, + std::unordered_set& dynamicCpuInputMap, gsl::span additionalSplittingNodes, - const std::unordered_map& implicitInputs); + const std::unordered_map& implicitInputs, + bool allowDmlGraphDynamicShapes); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h index d7a0a607cdec..a8a6d6745e90 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h @@ -2,8 +2,15 @@ // Licensed under the MIT License. #pragma once + +#include + #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" +interface IDMLCompiledOperator; +struct DML_BUFFER_BINDING; +struct DML_BINDING_DESC; + namespace Dml { struct Binding diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 6cd10e14e08d..4deec620fe5f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1356,13 +1356,14 @@ namespace Windows::AI::MachineLearning::Adapter const onnxruntime::OpNodeProtoHelper* protoHelper, const void* executionHandle, bool isInternalOperator, + const EdgeShapes* inputShapesOverrides, const EdgeShapes* inferredOutputShapes, const AttributeMap* defaultAttributes, DmlGraphNodeCreateInfo* graphNodeCreateInfo, gsl::span requiredConstantCpuInputs, MLOperatorTensorGetter& constantInputGetter ) - : OpNodeInfoWrapper(protoHelper, nullptr, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr), + : OpNodeInfoWrapper(protoHelper, inputShapesOverrides, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr), m_inferredOutputShapes(inferredOutputShapes), m_internalOperator(isInternalOperator), m_graphNodeCreateInfo(graphNodeCreateInfo) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index a7f8bebb2de7..913997ff4ad4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -4,6 +4,7 @@ #pragma once #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" #include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" #include "core/framework/op_kernel.h" #include "core/framework/customregistry.h" #include "core/framework/tensorprotoutils.h" @@ -93,42 +94,6 @@ struct AttributeValue using AttributeMap = std::map; -// Encapsulation of shapes across different edges of an operator. Non-tensor -// edges and unused edges have an empty array of dimensions. -class EdgeShapes -{ -public: - EdgeShapes() = default; - - EdgeShapes(size_t count) : m_shapes(count) {} - - const std::vector& GetShape(size_t edgeIndex) const - { - return m_shapes[edgeIndex]; - } - - std::vector& GetMutableShape(size_t edgeIndex) - { - return m_shapes[edgeIndex]; - } - - size_t EdgeCount() const { return m_shapes.size(); } - - void Reset(size_t edge_count) - { - m_shapes.clear(); - m_shapes.resize(edge_count); - } - - bool operator!=(const EdgeShapes& other) const noexcept - { - return (m_shapes != other.m_shapes); - } - - private: - std::vector> m_shapes; -}; - // Base class for ABI objects which may be "Closed", at which point calls will predictably // fail or return a dummy value. This is used for transient ABI context objects which // are passed to methods on kernel or inferencers, and which wrap Lotus objects whose lifetimes @@ -434,6 +399,7 @@ class DmlGraphOpKernelInfoWrapper : public OpNodeInfoWrapper< const onnxruntime::OpNodeProtoHelper * protoHelper, const void* executionHandle, bool isInternalOperator, + const EdgeShapes* inputShapesOverrides, const EdgeShapes* inferredOutputShapes, const AttributeMap* defaultAttributes, DmlGraphNodeCreateInfo* graphNodeCreateInfo, diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index cd8bc8fe909d..d587424fe01f 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -30,8 +30,12 @@ namespace onnxruntime { struct DMLProviderFactory : IExecutionProviderFactory { DMLProviderFactory(IDMLDevice* dml_device, - ID3D12CommandQueue* cmd_queue) : dml_device_(dml_device), - cmd_queue_(cmd_queue) {} + ID3D12CommandQueue* cmd_queue, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) : dml_device_(dml_device), + cmd_queue_(cmd_queue), + metacommands_enabled_(!disable_metacommands), + dynamic_graph_fusion_enabled_(enable_dynamic_graph_fusion) {} ~DMLProviderFactory() override {} std::unique_ptr CreateProvider() override; @@ -42,10 +46,11 @@ struct DMLProviderFactory : IExecutionProviderFactory { ComPtr dml_device_{}; ComPtr cmd_queue_{}; bool metacommands_enabled_ = true; + bool dynamic_graph_fusion_enabled_ = false; }; std::unique_ptr DMLProviderFactory::CreateProvider() { - auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get(), metacommands_enabled_); + auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get(), metacommands_enabled_, dynamic_graph_fusion_enabled_); return provider; } @@ -54,7 +59,9 @@ void DMLProviderFactory::SetMetacommandsEnabled(bool metacommands_enabled) { } std::shared_ptr CreateExecutionProviderFactory_DML(IDMLDevice* dml_device, - ID3D12CommandQueue* cmd_queue) { + ID3D12CommandQueue* cmd_queue, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { #ifndef _GAMING_XBOX // Validate that the D3D12 devices match between DML and the command queue. This specifically asks for IUnknown in // order to be able to compare the pointers for COM object identity. @@ -73,7 +80,7 @@ std::shared_ptr CreateExecutionProviderFactory_DML(ID const Env& env = Env::Default(); auto luid = d3d12_device->GetAdapterLuid(); env.GetTelemetryProvider().LogExecutionProviderEvent(&luid); - return std::make_shared(dml_device, cmd_queue); + return std::make_shared(dml_device, cmd_queue, disable_metacommands, enable_dynamic_graph_fusion); } void DmlConfigureProviderFactoryMetacommandsEnabled(IExecutionProviderFactory* factory, bool metacommandsEnabled) { @@ -234,12 +241,10 @@ static void SortHeterogenousDXCoreAdapterList( std::sort(adapter_infos.begin(), adapter_infos.end(), policy); } -std::shared_ptr DMLProviderFactoryCreator::Create(int device_id) { - return Create(device_id, /*skip_software_device_check*/ false); -} - std::shared_ptr DMLProviderFactoryCreator::CreateFromOptions( - OrtDmlDeviceOptions* device_options) { + OrtDmlDeviceOptions* device_options, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { auto default_device_options = OrtDmlDeviceOptions { Default, Gpu }; if (device_options == nullptr) { device_options = &default_device_options; @@ -286,7 +291,7 @@ std::shared_ptr DMLProviderFactoryCreator::CreateFrom adapters.begin(), [](auto& a){ return a.Adapter; }); - return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters)); + return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters), disable_metacommands, enable_dynamic_graph_fusion); } static std::optional ParsePerformancePreference(const ProviderOptions& provider_options) { @@ -354,12 +359,32 @@ static std::optional ParseDeviceId(const ProviderOptions& provider_options) return {}; } +static bool ParseBoolean(const ProviderOptions& provider_options, const std::string& key) { + auto preference_it = provider_options.find(key); + if (preference_it != provider_options.end() && !preference_it->second.empty()) { + if (preference_it->second == "True" || preference_it->second == "true") { + return true; + } else if (preference_it->second == "False" || preference_it->second == "false") { + return false; + } else { + ORT_THROW("[ERROR] [DirectML] The value for the key '" + key + "' should be 'True' or 'False'. Default value is 'False'.\n"); + } + } + + return false; +} + std::shared_ptr DMLProviderFactoryCreator::CreateFromProviderOptions( - const ProviderOptions& provider_options) { + const ProviderOptions& provider_options) { + + bool disable_metacommands = ParseBoolean(provider_options, "disable_metacommands"); + bool enable_dynamic_graph_fusion = ParseBoolean(provider_options, "enable_dynamic_graph_fusion"); + bool skip_software_device_check = false; auto device_id = ParseDeviceId(provider_options); + if (device_id.has_value()) { - return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value()); + return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value(), skip_software_device_check, disable_metacommands, enable_dynamic_graph_fusion); } auto preference = ParsePerformancePreference(provider_options); @@ -367,7 +392,7 @@ std::shared_ptr DMLProviderFactoryCreator::CreateFrom // If no preference/filters are specified then create with default preference/filters. if (!preference.has_value() && !filter.has_value()) { - return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr); + return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr, disable_metacommands, enable_dynamic_graph_fusion); } if (!preference.has_value()) { @@ -381,7 +406,7 @@ std::shared_ptr DMLProviderFactoryCreator::CreateFrom OrtDmlDeviceOptions device_options; device_options.Preference = preference.value(); device_options.Filter = filter.value(); - return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options); + return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options, disable_metacommands, enable_dynamic_graph_fusion); } Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateD3D12Device( @@ -441,7 +466,10 @@ Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID return dml_device; } -std::shared_ptr CreateDMLDeviceAndProviderFactory(ID3D12Device* d3d12_device) { +std::shared_ptr CreateDMLDeviceAndProviderFactory( + ID3D12Device* d3d12_device, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; @@ -450,16 +478,22 @@ std::shared_ptr CreateDMLDeviceAndProviderFactory(ID3 ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf()))); auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device); - return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get()); + return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get(), disable_metacommands, enable_dynamic_graph_fusion); } -std::shared_ptr DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) { +std::shared_ptr DMLProviderFactoryCreator::Create( + int device_id, + bool skip_software_device_check, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { ComPtr d3d12_device = CreateD3D12Device(device_id, skip_software_device_check); - return CreateDMLDeviceAndProviderFactory(d3d12_device.Get()); + return CreateDMLDeviceAndProviderFactory(d3d12_device.Get(), disable_metacommands, enable_dynamic_graph_fusion); } std::shared_ptr DMLProviderFactoryCreator::CreateFromAdapterList( - std::vector>&& dxcore_devices) { + std::vector>&& dxcore_devices, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { // Choose the first device from the list since it's the highest priority auto dxcore_device = dxcore_devices[0]; @@ -467,7 +501,7 @@ std::shared_ptr DMLProviderFactoryCreator::CreateFrom ComPtr d3d12_device; ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); - return CreateDMLDeviceAndProviderFactory(d3d12_device.Get()); + return CreateDMLDeviceAndProviderFactory(d3d12_device.Get(), disable_metacommands, enable_dynamic_graph_fusion); } } // namespace onnxruntime @@ -477,7 +511,7 @@ std::shared_ptr DMLProviderFactoryCreator::CreateFrom // The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead. ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id) { API_IMPL_BEGIN - options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::Create(device_id)); + options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::Create(device_id, false, false, false)); API_IMPL_END return nullptr; } @@ -489,7 +523,9 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSess _In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue) { API_IMPL_BEGIN options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_DML(dml_device, - cmd_queue)); + cmd_queue, + false, + false)); API_IMPL_END return nullptr; } @@ -517,7 +553,7 @@ ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) { ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_options) { API_IMPL_BEGIN #ifdef USE_DML - auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options); + auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options, false, false); // return the create function for a dxcore device options->provider_factories.push_back(factory); #endif // USE_DML diff --git a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h index 4e13330a4cd7..0fab9fe90252 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h +++ b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h @@ -17,15 +17,24 @@ namespace onnxruntime { struct DMLProviderFactoryCreator { - static std::shared_ptr Create(int device_id); - static std::shared_ptr Create(int device_id, bool skip_software_device_check); + static std::shared_ptr Create( + int device_id, + bool skip_software_device_check, + bool disable_metacommands, + bool enable_dynamic_graph_fusion); static std::shared_ptr CreateFromProviderOptions( const ProviderOptions& provider_options_map); - static std::shared_ptr CreateFromOptions(OrtDmlDeviceOptions* device_options); + + static std::shared_ptr CreateFromOptions( + OrtDmlDeviceOptions* device_options, + bool disable_metacommands, + bool enable_dynamic_graph_fusion); static std::shared_ptr CreateFromAdapterList( - std::vector>&& dxcore_devices); + std::vector>&& dxcore_devices, + bool disable_metacommands, + bool enable_dynamic_graph_fusion); static Microsoft::WRL::ComPtr CreateD3D12Device(int device_id, bool skip_software_device_check); static Microsoft::WRL::ComPtr CreateDMLDevice(ID3D12Device* d3d12_device); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 077b10ffc552..1163be27b168 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -52,8 +52,10 @@ #include "core/providers/cpu/cpu_execution_provider.h" #ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph #include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h" #include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h" #include "core/providers/dml/dml_session_options_config_keys.h" +#include "core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h" #endif #include "core/session/environment.h" #include "core/session/user_logging_sink.h" @@ -1598,7 +1600,9 @@ common::Status InferenceSession::Initialize() { record_runtime_optimization_produced_op_schema)); #ifdef USE_DML - if (execution_providers_.Get(kDmlExecutionProvider)) { + const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider); + + if (dmlExecutionProvider) { // DML graph fusion is an important runtime optimization that cannot be done ahead of time; it must be disabled // when running in "offline mode" and saving an optimized model to disk. To support users that want to optimize // models offline, and then disable graph optimizations when running "online", this transformer ignores the ORT @@ -1608,11 +1612,20 @@ common::Status InferenceSession::Initialize() { if (dml_graph_fusion_enabled) { std::unique_ptr dmlGraphFusionTransformer = std::make_unique("DmlGraphFusionTransformer", - execution_providers_.Get(kDmlExecutionProvider)); + dmlExecutionProvider); if (dmlGraphFusionTransformer == nullptr) { return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr"); } ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3)); + + if (static_cast(dmlExecutionProvider)->DynamicGraphFusionEnabled()) { + std::unique_ptr dmlRuntimeGraphFusionTransformer = std::make_unique("DmlRuntimeGraphFusionTransformer", + dmlExecutionProvider); + if (dmlRuntimeGraphFusionTransformer == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "DmlRuntimeGraphFusionTransformer is nullptr"); + } + ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlRuntimeGraphFusionTransformer), onnxruntime::TransformerLevel::Level3)); + } } // This transformer applies DML-specific fusions that go beyond what ORT offers by default diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index a8c217b0ff1f..3a977772873f 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -59,7 +59,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) { onnxruntime::ArmNNProviderFactoryCreator::Create(0), #endif #ifdef USE_DML - onnxruntime::DMLProviderFactoryCreator::Create(0, /*skip_software_device_check*/ true), + onnxruntime::DMLProviderFactoryCreator::Create(0, false, false, false), #endif #ifdef USE_NNAPI onnxruntime::NnapiProviderFactoryCreator::Create(0, std::optional()), diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 28af61e15b2b..e224507bc740 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -268,7 +268,7 @@ std::unique_ptr DefaultCannExecutionProvider() { std::unique_ptr DefaultDmlExecutionProvider() { #ifdef USE_DML - if (auto factory = DMLProviderFactoryCreator::Create(0)) + if (auto factory = DMLProviderFactoryCreator::Create(0, false, false, false)) return factory->CreateProvider(); #endif return nullptr; From 0f72739b6db129373d221483d61d6637ec11fb28 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 26 Oct 2023 04:03:01 -0700 Subject: [PATCH 12/36] Disable ccache for WinML build (#18104) ### Description It seems would resolve the timeout issue. ### Motivation and Context --- tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index 2a5622faf290..ed010b5619db 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -132,7 +132,7 @@ stages: isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false - WITH_CACHE: true + WITH_CACHE: false MachinePool: 'onnxruntime-Win-CPU-2022' - stage: x86_release From 64de71c5e2af8ee67ce79a9bb8d7244a8238064c Mon Sep 17 00:00:00 2001 From: Caroline Zhu Date: Thu, 26 Oct 2023 09:22:10 -0700 Subject: [PATCH 13/36] [js/web/training] Add CreateTrainingSession (#17891) ### Description * Adds TrainingSession.create() functionality following the web bindings for training design doc * Added 2 new training APIs to wasm/api.h: * OrtTrainingGetInputOutputName * OrtTrainingGetInputOutputCount * Moved isOrtEnvInitialized boolean to the wasm-core-impl and added a method that references it ### Motivation and Context * Adding web bindings for training #### Related work * #16521 allowed for training artifacts to be built * #17333 added interfaces for training * #17474 allows for training package to be built + adds training backend to web package **[MUST BE MERGED IN BEFORE THIS ONE]** --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Co-authored-by: Ashwini Khade --- js/common/lib/training-session-impl.ts | 21 ++- js/web/lib/backend-wasm-training.ts | 12 +- js/web/lib/wasm/binding/ort-wasm.d.ts | 5 + js/web/lib/wasm/proxy-messages.ts | 7 +- js/web/lib/wasm/proxy-worker/main.ts | 10 +- js/web/lib/wasm/proxy-wrapper.ts | 21 +++ .../lib/wasm/session-handler-for-training.ts | 73 ++++++++ js/web/lib/wasm/session-handler.ts | 6 +- js/web/lib/wasm/wasm-core-impl.ts | 6 + js/web/lib/wasm/wasm-training-core-impl.ts | 162 ++++++++++++++++++ onnxruntime/wasm/api.cc | 59 +++++++ onnxruntime/wasm/api.h | 29 ++++ 12 files changed, 399 insertions(+), 12 deletions(-) create mode 100644 js/web/lib/wasm/session-handler-for-training.ts create mode 100644 js/web/lib/wasm/wasm-training-core-impl.ts diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index f06d06bda035..47e67879e66c 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -1,11 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {resolveBackend} from './backend-impl.js'; import {TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; type SessionOptions = InferenceSession.SessionOptions; +const noBackendErrMsg: string = 'Training backend could not be resolved. ' + + 'Make sure you\'re using the correct configuration & WebAssembly files.'; export class TrainingSession implements TrainingSessionInterface { private constructor(handler: TrainingSessionHandler) { @@ -20,9 +23,23 @@ export class TrainingSession implements TrainingSessionInterface { return this.handler.outputNames; } - static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions): + static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): Promise { - throw new Error('Method not implemented'); + const evalModel: string|Uint8Array = trainingOptions.evalModel || ''; + const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || ''; + const options: SessionOptions = sessionOptions || {}; + + // get backend hints + const eps = options.executionProviders || []; + const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); + const backend = await resolveBackend(backendHints); + if (backend.createTrainingSessionHandler) { + const handler = await backend.createTrainingSessionHandler( + trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options); + return new TrainingSession(handler); + } else { + throw new Error(noBackendErrMsg); + } } async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts index af5b575c87a7..98e40807aa29 100644 --- a/js/web/lib/backend-wasm-training.ts +++ b/js/web/lib/backend-wasm-training.ts @@ -4,13 +4,17 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; +import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { async createTrainingSessionHandler( - _checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array, - _evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array, - _options: InferenceSession.SessionOptions): Promise { - throw new Error('Method not implemented yet.'); + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions): Promise { + const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); + await handler.createTrainingSession( + checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options); + return Promise.resolve(handler); } } diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index b7b2ff453709..00431a4e86d5 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,6 +102,11 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + _OrtTrainingGetModelInputOutputCount? + (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; + _OrtTrainingGetModelInputOutputName? + (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; + _OrtTrainingReleaseSession?(trainingHandle: number): void; // #endregion diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index 7aa866773bcb..efeb086256cf 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -73,5 +73,10 @@ interface MesssageEndProfiling extends MessageError { in ?: number; } +interface MessageIsOrtEnvInitialized extends MessageError { + type: 'is-ort-env-initialized'; + out?: boolean; +} + export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize| - MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling; + MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized; diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index fe8bd9b11b19..1f4595818e5c 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -4,7 +4,7 @@ /// import {OrtWasmMessage} from '../proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl'; import {initializeWebAssembly} from '../wasm-factory'; self.onmessage = (ev: MessageEvent): void => { @@ -89,6 +89,14 @@ self.onmessage = (ev: MessageEvent): void => { postMessage({type: 'end-profiling', err} as OrtWasmMessage); } break; + case 'is-ort-env-initialized': + try { + const ortEnvInitialized = isOrtEnvInitialized(); + postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage); + } catch (err) { + postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage); + } + break; default: } }; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index a3e4a1ef1fc7..069a1fa452db 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -24,6 +24,7 @@ const createSessionCallbacks: Array> = []; const runCallbacks: Array> = []; const endProfilingCallbacks: Array> = []; +const isOrtEnvInitializedCallbacks: Array> = []; const ensureWorker = (): void => { if (initializing || !initialized || aborted || !proxyWorker) { @@ -92,6 +93,13 @@ const onProxyWorkerMessage = (ev: MessageEvent): void => { endProfilingCallbacks.shift()![0](); } break; + case 'is-ort-env-initialized': + if (ev.data.err) { + isOrtEnvInitializedCallbacks.shift()![1](ev.data.err); + } else { + isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!); + } + break; default: } }; @@ -251,3 +259,16 @@ export const endProfiling = async(sessionId: number): Promise => { core.endProfiling(sessionId); } }; + +export const isOrtEnvInitialized = async(): Promise => { + if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + ensureWorker(); + return new Promise((resolve, reject) => { + isOrtEnvInitializedCallbacks.push([resolve, reject]); + const message: OrtWasmMessage = {type: 'is-ort-env-initialized'}; + proxyWorker!.postMessage(message); + }); + } else { + return core.isOrtEnvInitialized(); + } +}; diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-for-training.ts new file mode 100644 index 000000000000..83d133b9a515 --- /dev/null +++ b/js/web/lib/wasm/session-handler-for-training.ts @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common'; + +import {SerializableModeldata} from './proxy-messages'; +import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl'; + +export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + async getContiguousParameters(_trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + private sessionId: number; + private checkpointId: number; + + inputNames: string[]; + outputNames: string[]; + + inputEncodedNames: number[]; + outputEncodedNames: number[]; + + async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { + let buffer: Uint8Array; + if (typeof uriOrBuffer === 'string') { + const response = await fetch(uriOrBuffer); + const arrayBuffer = await response.arrayBuffer(); + buffer = new Uint8Array(arrayBuffer); + } else { + buffer = uriOrBuffer; + } + return createSessionAllocate(buffer); + } + + async createTrainingSession( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions) { + if (!isOrtEnvInitialized()) { + await initRuntime(env); + } + const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); + const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); + // 0 is supposed to be the nullptr + let evalModelData: SerializableModeldata = [0, 0]; + let optimizerModelData: SerializableModeldata = [0, 0]; + + if (evalModelUriOrBuffer !== '') { + evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); + } + if (optimizerModelUriOrBuffer !== '') { + optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); + } + + this.checkpointId = createCheckpointHandle(checkpointData); + [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = + createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); + } + + async dispose(): Promise { + return releaseTrainingSessionAndCheckpoint( + this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + } + + async runTrainStep( + _feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType, + _options: InferenceSession.RunOptions): Promise { + throw new Error('Method not implemented yet.'); + } +} diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index d1760e37c93f..a5017a920f38 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -5,10 +5,9 @@ import {readFile} from 'node:fs/promises'; import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper'; import {isGpuBufferSupportedType} from './wasm-common'; -let runtimeInitialized: boolean; let runtimeInitializationPromise: Promise|undefined; const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { @@ -57,13 +56,12 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan } async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { - if (!runtimeInitialized) { + if (!(await isOrtEnvInitialized())) { if (!runtimeInitializationPromise) { runtimeInitializationPromise = initializeRuntime(env); } await runtimeInitializationPromise; runtimeInitializationPromise = undefined; - runtimeInitialized = true; } if (typeof pathOrBuffer === 'string') { diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5b49a1d4202e..947242945c66 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -10,6 +10,8 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; +let ortEnvInitialized = false; + /** * get the input/output count of the session. * @param sessionHandle the handle representing the session. should be non-zero. @@ -57,6 +59,8 @@ export const initRuntime = async(env: Env): Promise => { const initJsep = require('./jsep/init').init; await initJsep(getInstance(), env); } + + ortEnvInitialized = true; }; /** @@ -93,6 +97,8 @@ type SessionMetadata = [ const activeSessions = new Map(); +export const isOrtEnvInitialized = (): boolean => ortEnvInitialized; + /** * allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession. * @returns a 2-elements tuple - the pointer and size of the allocated buffer diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts new file mode 100644 index 000000000000..4830b5d2b5e8 --- /dev/null +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession} from 'onnxruntime-common'; + +import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages'; +import {setSessionOptions} from './session-options'; +import {getInstance} from './wasm-factory'; +import {checkLastError} from './wasm-utils'; + +const NO_TRAIN_FUNCS_MSG = + 'Built without training API\'s enabled. Use the onnxruntime-web/training import for training ' + + 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + + 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; + +export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { + const wasm = getInstance(); + + const [checkpointDataOffset, checkpointDataLength] = checkpointData; + let checkpointHandle = 0; + + try { + if (wasm._OrtTrainingLoadCheckpoint) { + checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + if (checkpointHandle === 0) { + checkLastError('Error occurred when trying to create a CheckpointState.'); + } + return checkpointHandle; + } catch (e) { + if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { + wasm._OrtTrainingReleaseCheckpoint(checkpointHandle); + } + throw e; + } finally { + // free buffer from wasm heap + wasm._OrtFree(checkpointData[0]); + } +}; + +const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + try { + const dataOffset = wasm.stackAlloc(8); + if (wasm._OrtTrainingGetModelInputOutputCount) { + const errorCode = + wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel); + if (errorCode !== 0) { + checkLastError('Can\'t get session input/output count.'); + } + return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } +}; + +const getModelInputOutputNamesLoop = + (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => { + const names = []; + const wasm = getInstance(); + + const namesUTF8Encoded = []; + + for (let i = 0; i < count; i++) { + if (wasm._OrtTrainingGetModelInputOutputName) { + const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); + if (name === 0) { + checkLastError('Can\'t get input or output name'); + } + + namesUTF8Encoded.push(name); + names.push(wasm.UTF8ToString(name)); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } + return [names, namesUTF8Encoded]; + }; + +const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { + const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false); + + const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false); + const [outputNames, outputNamesUTF8Encoded] = + getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false); + + return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; +}; + +export const createTrainingSessionHandle = + (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata, + optimizerModelData: SerializableModeldata, + options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => { + const wasm = getInstance(); + + let trainingSessionHandle = 0; + let sessionOptionsHandle = 0; + let allocs: number[] = []; + let inputNamesUTF8Encoded: number[] = []; + let outputNamesUTF8Encoded: number[] = []; + + let inputNames: string[] = []; + let outputNames: string[] = []; + + try { + [sessionOptionsHandle, allocs] = setSessionOptions(options); + if (wasm._OrtTrainingCreateSession) { + trainingSessionHandle = wasm._OrtTrainingCreateSession( + sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0], + evalModelData[1], optimizerModelData[0], optimizerModelData[1]); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + if (trainingSessionHandle === 0) { + checkLastError('Error occurred when trying to create a TrainingSession.'); + } + + [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = + getTrainingModelInputOutputNames(trainingSessionHandle); + return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded]; + + } catch (e) { + if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { + wasm._OrtTrainingReleaseSession(trainingSessionHandle); + } + throw e; + } finally { + wasm._free(trainModelData[0]); + wasm._free(evalModelData[0]); + wasm._free(optimizerModelData[0]); + + if (sessionOptionsHandle !== 0) { + wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + } + allocs.forEach(alloc => wasm._free(alloc)); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + } + }; + +export const releaseTrainingSessionAndCheckpoint = + (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): + void => { + const wasm = getInstance(); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } + }; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 968eece36172..0e58bb4f93f7 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -493,6 +493,14 @@ char* OrtEndProfiling(ort_session_handle_t session) { #define CHECK_TRAINING_STATUS(ORT_API_NAME, ...) \ CheckStatus(Ort::GetTrainingApi().ORT_API_NAME(__VA_ARGS__)) +#define RETURN_TRAINING_ERROR_CODE_IF_ERROR(ORT_API_NAME, ...) \ + do { \ + int error_code = CHECK_TRAINING_STATUS(ORT_API_NAME, __VA_ARGS__); \ + if (error_code != ORT_OK) { \ + return error_code; \ + } \ + } while (false) + ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint(void* checkpoint_data_buffer, size_t checkpoint_size) { OrtCheckpointState* checkpoint_state = nullptr; @@ -571,6 +579,57 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio return CHECK_TRAINING_STATUS(CopyBufferToParameters, training_handle, parameters_buffer, trainable_only); } +int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle, + size_t* input_count, + size_t* output_count, + bool isEvalModel) { + if (isEvalModel) { + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelInputCount, training_handle, input_count); + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelOutputCount, training_handle, output_count); + return ORT_OK; + } else { + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelInputCount, training_handle, input_count); + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelOutputCount, training_handle, output_count); + return ORT_OK; + } +} + +char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle, + size_t index, + bool isInput, + bool isEvalModel) { + OrtAllocator* allocator = nullptr; + RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); + + char* name = nullptr; + + if (isEvalModel) { + if (isInput) { + return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelInputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } else { + return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelOutputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } + } else { + if (isInput) { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelInputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } else { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelOutputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } + } +} + void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_handle) { Ort::GetTrainingApi().ReleaseTrainingSession(training_handle); } diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 9a0664697f0f..2cd1515d191c 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -432,6 +432,35 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio size_t parameter_count, bool trainable_only); +/** + * Gets the input count and output count of the training or eval model associated with the given training handle. + * @param traning_handle handle of the traning session + * @param input_count [out] a pointer to a size_t variable to accept input_count + * @param output_count [out] a pointer to a size_t variable to accept output_count + * @param isEvalModel when false, returns input & output count of the training model. When true, returns input & output + * count of the eval model. + * @returns ORT error code. If not zero, call OrtGetLastError() to get a detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle, + size_t* input_count, + size_t* output_count, + bool isEvalModel); + +/** + * Gets the input or output name at the specified index associated with the training or eval model from the + * given training session. + * @param traning_handle handle of the traning session + * @param index the input or output index + * @param isInput if true, this method retrieves an input name. If false, this method retrieves an output name. + * @param isEvalModel when false, returns input & output names of the training model. When true, returns input & output + * names of the eval model. + * @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by + */ +char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle, + size_t index, + bool isInput, + bool isEvalModel); + /** * @brief Release the specified ORT training session. * From b023de0bfc7acb2404dfdcc4adc060b7b72fdaa1 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Thu, 26 Oct 2023 10:12:46 -0700 Subject: [PATCH 14/36] Redo #18044 Install CUDA 12.2 on Windows (#18093) --- .../azure-pipelines/c-api-noopenmp-packaging-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 129dbc833a0a..3696c41c196d 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -376,7 +376,7 @@ stages: - task: BatchScript@1 displayName: 'setup env' inputs: - filename: '$(Build.SourcesDirectory)\tools\ci_build\github\windows\setup_env_cuda_11.bat' + filename: '$(Build.SourcesDirectory)\tools\ci_build\github\windows\setup_env_cuda.bat' modifyEnvironment: true workingFolder: '$(Build.BinariesDirectory)' From 455a9ce61418cca1674fd20222c6bf4c6bf4e3b4 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Thu, 26 Oct 2023 20:55:12 +0000 Subject: [PATCH 15/36] [TensorRT EP] Use latest onnx-tensorrt parser (#18067) Use latest onnx-tensorrt to fix compile error. Please see the issue https://github.com/microsoft/onnxruntime/issues/18029 --- cgmanifests/generated/cgmanifest.json | 2 +- cmake/deps.txt | 2 +- .../github/azure-pipelines/templates/download-deps.yml | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index f9f2fbdab7b1..f9501253661a 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -206,7 +206,7 @@ "component": { "type": "git", "git": { - "commitHash": "0462dc31ae78f48744b6141ae376df1f96d3f459", + "commitHash": "a43ce67187bab219520fd80f21af8bbd4354bc8c", "repositoryUrl": "https://github.com/onnx/onnx-tensorrt.git" }, "comments": "onnx_tensorrt" diff --git a/cmake/deps.txt b/cmake/deps.txt index 26fd35075c4b..631d326e2ba5 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -26,7 +26,7 @@ mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 onnx;https://github.com/onnx/onnx/archive/6a20ba82b439ea1fd650da4d389e96b60a1dd828.zip;179a22ad4cd67109c60031ae4b6cf2f434d8bd7e #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/0462dc31ae78f48744b6141ae376df1f96d3f459.zip;5ff086361956cceb81ed17453a1fd8db2aa4328d +onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 9ca4a45ffcec..1373381e4c83 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.95 + version: 1.0.97 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.95 + version: 1.0.97 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. From a514a6877015bddaa6b98e272720ac9927629e14 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 26 Oct 2023 14:47:16 -0700 Subject: [PATCH 16/36] Support per-tensor device mesh at op level (#18025) Since Reshape may change device mesh from, e.g., [0, 1] to [0, 1, 0, 1], we can't assume same device mesh per op. At each operator, we replace a single operator-level device mesh - `device_mesh_shapes` - `device_mesh_elements` with per-tensor device meshes - `input_device_mesh_shapes` (input_device_mesh_shapes[i] is the device mesh's shape for the i-th input, e.g., "[3]" for 1-D mesh with 3 devices) - `input_device_mesh_elements` (input_device_mesh_elements[i] is the flattened device mesh elements for the i-th input; e.g., "[0, 1, 2]" if you have 3 devices in that mesh) - `output_device_mesh_shapes` - `output_device_mesh_elements` Check out the change in onnxruntime_test_distributed.py for examples. It's also heavily used in #18068's `onnxruntime_test_distributed.py` change. --- .../contrib_ops/cuda/collective/sharding.cc | 49 +++++++++- .../cuda/collective/sharding_spec.cc | 22 +++++ .../cuda/collective/sharding_spec.h | 3 + .../core/graph/contrib_ops/collective_defs.cc | 50 ++++++---- .../python/onnxruntime_test_distributed.py | 92 +++++++++++++++---- 5 files changed, 176 insertions(+), 40 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.cc b/onnxruntime/contrib_ops/cuda/collective/sharding.cc index 7d106fd75e2d..dfd5f589355d 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.cc @@ -5,6 +5,8 @@ #include "mpi_include.h" #include "sharding_spec.h" +#include +#include #include "core/providers/cpu/tensor/slice.h" #include "core/providers/cuda/tensor/slice.h" #include "core/providers/cuda/math/matmul.h" @@ -237,16 +239,55 @@ void ReshardTensor( } DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info) { - std::vector device_mesh_elements = info.GetAttrsOrDefault("device_mesh_elements"); - std::vector device_mesh_shape = info.GetAttrsOrDefault("device_mesh_shape"); - std::vector input_shard_specs = info.GetAttrsOrDefault("input_shard_specs"); - std::vector output_shard_specs = info.GetAttrsOrDefault("output_shard_specs"); + // input_device_mesh_shapes[i] is the shape of device mesh for the i-th input. + // E.g., device_mesh_shapes = ["[2]", "[1]"] means the first input is + // stored on a 1-D mesh with 2 devices and the second input on another 1-D + // mesh with 1 device. + std::vector attr_input_device_mesh_shapes; + ORT_ENFORCE(info.GetAttrs("input_device_mesh_shapes", attr_input_device_mesh_shapes).IsOK()); + // input_device_mesh_elements[i] is the flattened device mesh for the i-th input. + // Note that its actual shape is input_device_mesh_shapes[i]. + // Example: + // Assume + // device_mesh_shapes = ["[2]", "[1]"] + // device_mesh_elements = ["[0,1]", "[0]"] + // Then the first input is stored on a 1-D mesh with 2 devices and the second + // input on another 1-D mesh with 1 device. + std::vector attr_input_device_mesh_elements; + ORT_ENFORCE(info.GetAttrs("input_device_mesh_elements", attr_input_device_mesh_elements).IsOK()); + + // input_shard_specs[i] is the sharding spec of the i-th input; e.g., + // "RR" if the i-th input is not sharded. + std::vector input_shard_specs; + ORT_ENFORCE(info.GetAttrs("input_shard_specs", input_shard_specs).IsOK()); + + ORT_ENFORCE(attr_input_device_mesh_shapes.size() == attr_input_device_mesh_elements.size()); + ORT_ENFORCE(attr_input_device_mesh_shapes.size() == input_shard_specs.size()); + + // Begin parsing sharding metadata for inputs. for (size_t i = 0; i < input_shard_specs.size(); ++i) { + auto device_mesh_shape = ParseStringAsInt64Vector(attr_input_device_mesh_shapes[i]); + auto device_mesh_elements = ParseStringAsInt64Vector(attr_input_device_mesh_elements[i]); auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements); input_shard_specs_.push_back(spec); } + + std::vector attr_output_device_mesh_shapes; + ORT_ENFORCE(info.GetAttrs("output_device_mesh_shapes", attr_output_device_mesh_shapes).IsOK()); + + std::vector attr_output_device_mesh_elements; + ORT_ENFORCE(info.GetAttrs("output_device_mesh_elements", attr_output_device_mesh_elements).IsOK()); + + std::vector output_shard_specs; + ORT_ENFORCE(info.GetAttrs("output_shard_specs", output_shard_specs).IsOK()); + + ORT_ENFORCE(attr_output_device_mesh_shapes.size() == attr_output_device_mesh_elements.size()); + ORT_ENFORCE(attr_output_device_mesh_shapes.size() == output_shard_specs.size()); + for (size_t i = 0; i < output_shard_specs.size(); ++i) { + auto device_mesh_shape = ParseStringAsInt64Vector(attr_output_device_mesh_shapes[i]); + auto device_mesh_elements = ParseStringAsInt64Vector(attr_output_device_mesh_elements[i]); auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements); output_shard_specs_.push_back(spec); } diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc index f1d399077e37..220938f3ceae 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc @@ -27,6 +27,28 @@ void ValidateAxisIndex(const int64_t axis, const int64_t rank) { ORT_ENFORCE(adjusted_axis >= 0 && adjusted_axis < rank, "axis,", axis, ", should be in [", -rank, ",", rank, ")."); } +std::vector ParseStringAsInt64Vector(const std::string& str) { + if (str.empty() || str.front() != '[' || str.back() != ']') { + throw std::invalid_argument("Invalid input string format"); + } + // Parsed vector. + // If input is "[0, 1, 2]", result should be {0, 1, 2}. + std::vector result; + // Skip '[' and ']' + std::istringstream iss(str.substr(1, str.size() - 2)); + + // Extract integers separated by ',' or whitespaces. + int64_t num = -1; + while (/* Read one number at a time */ iss >> num) { + result.push_back(num); + // Skip the comma + if (iss.peek() == ',') { + iss.ignore(); + } + } + return result; +} + DeviceMesh CreateDeviceMesh( std::vector device_mesh_shape, std::vector device_mesh_elements) { diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h index 0f5ef6927a54..451d44b4bd43 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h @@ -314,6 +314,9 @@ class TensorPartitionSpec { } }; +// Parse "[0, 1, 2, 3]" as std::vector{0, 1, 2, 3}. +std::vector ParseStringAsInt64Vector(const std::string& str); + DeviceMesh CreateDeviceMesh( std::vector device_mesh_shape, std::vector device_mesh_elements); diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 7cdd71014c02..97befe2a5830 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -83,17 +83,26 @@ void RegisterCollectiveOps() { ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedMatMul) .SetDomain(kMSDomain) .SinceVersion(1) - .Attr("device_mesh_elements", - "", - AttributeProto::INTS) - .Attr("device_mesh_shape", - "", - AttributeProto::INTS) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) .Attr("input_shard_specs", - "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", AttributeProto::STRINGS) .Attr("output_shard_specs", - "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + "Similar to input_shard_specs but for outputs.", AttributeProto::STRINGS) .Input(0, "A", "N-dimensional matrix A", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .Input(1, "B", "N-dimensional matrix B", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) @@ -109,17 +118,26 @@ void RegisterCollectiveOps() { ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedSlice) .SetDomain(kMSDomain) .SinceVersion(1) - .Attr("device_mesh_elements", - "", - AttributeProto::INTS) - .Attr("device_mesh_shape", - "", - AttributeProto::INTS) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) .Attr("input_shard_specs", - "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", AttributeProto::STRINGS) .Attr("output_shard_specs", - "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + "Similar to input_shard_specs but for outputs.", AttributeProto::STRINGS) .Input( 0, diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index 1baec80cb7c4..a9b55122c680 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -20,15 +20,22 @@ def shard_tensor(X, rank, axis, num_shards): class TestDistributed(unittest.TestCase): def test_matmul_rs_sr_rr(self): + # It means 1-D tensor with single element: [2]. + device_mesh_shape = "[2]" + # It means 1-D tensor with two elements: [0, 1]. + device_mesh_elements = "[0,1]" + @onnxscript.script() def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: return MICROSOFT_OPSET.DistributedMatMul( tensor_x, tensor_w, - device_mesh_shape=[2], - device_mesh_elements=[0, 1], input_shard_specs=["RS[0]", "S[0]R"], output_shard_specs=["RR"], + input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape], + input_device_mesh_elements=[device_mesh_elements, device_mesh_elements], + output_device_mesh_shapes=[device_mesh_shape], + output_device_mesh_elements=[device_mesh_elements], ) rank = comm.Get_rank() @@ -55,15 +62,20 @@ def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) def test_matmul2d_rs_rs_rr(self): + device_mesh_shape = "[2]" + device_mesh_elements = "[0, 1]" + @onnxscript.script() def matmul_rs_rs_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: return MICROSOFT_OPSET.DistributedMatMul( tensor_x, tensor_w, - device_mesh_shape=[2], - device_mesh_elements=[0, 1], input_shard_specs=["RS[0]", "RS[0]"], output_shard_specs=["RR"], + input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape], + input_device_mesh_elements=[device_mesh_elements, device_mesh_elements], + output_device_mesh_shapes=[device_mesh_shape], + output_device_mesh_elements=[device_mesh_elements], ) rank = comm.Get_rank() @@ -93,15 +105,20 @@ def matmul_rs_rs_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) def test_matmul2d_rs_rs_rs(self): + device_mesh_shape = "[2]" + device_mesh_elements = "[0, 1]" + @onnxscript.script() def matmul2d_rs_rs_rs(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: return MICROSOFT_OPSET.DistributedMatMul( tensor_x, tensor_w, - device_mesh_shape=[2], - device_mesh_elements=[0, 1], input_shard_specs=["RS[0]", "RS[0]"], output_shard_specs=["RS[0]"], + input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape], + input_device_mesh_elements=[device_mesh_elements, device_mesh_elements], + output_device_mesh_shapes=[device_mesh_shape], + output_device_mesh_elements=[device_mesh_elements], ) rank = comm.Get_rank() @@ -128,15 +145,20 @@ def matmul2d_rs_rs_rs(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) def test_matmul_srr_rr_srr(self): + device_mesh_shape = "[2]" + device_mesh_elements = "[0, 1]" + @onnxscript.script() def matmul_srr_rr_srr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: return MICROSOFT_OPSET.DistributedMatMul( tensor_x, tensor_w, - device_mesh_shape=[2], - device_mesh_elements=[0, 1], input_shard_specs=["S[0]RR", "RR"], output_shard_specs=["S[0]RR"], + input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape], + input_device_mesh_elements=[device_mesh_elements, device_mesh_elements], + output_device_mesh_shapes=[device_mesh_shape], + output_device_mesh_elements=[device_mesh_elements], ) rank = comm.Get_rank() @@ -165,15 +187,20 @@ def matmul_srr_rr_srr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) def test_matmul_srr_rrrr_rsrr(self): + device_mesh_shape = "[2]" + device_mesh_elements = "[0, 1]" + @onnxscript.script() def matmul_srr_rrrr_rsrr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: return MICROSOFT_OPSET.DistributedMatMul( tensor_x, tensor_w, - device_mesh_shape=[2], - device_mesh_elements=[0, 1], input_shard_specs=["S[0]RR", "RRRR"], output_shard_specs=["RS[0]RR"], + input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape], + input_device_mesh_elements=[device_mesh_elements, device_mesh_elements], + output_device_mesh_shapes=[device_mesh_shape], + output_device_mesh_elements=[device_mesh_elements], ) rank = comm.Get_rank() @@ -202,15 +229,20 @@ def matmul_srr_rrrr_rsrr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) def test_matmul_sr_rs_rr(self): + device_mesh_shape = "[2]" + device_mesh_elements = "[0, 1]" + @onnxscript.script() def matmul_sr_rs_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: return MICROSOFT_OPSET.DistributedMatMul( tensor_x, tensor_w, - device_mesh_shape=[2], - device_mesh_elements=[0, 1], input_shard_specs=["S[0]R", "RS[0]"], output_shard_specs=["RR"], + input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape], + input_device_mesh_elements=[device_mesh_elements, device_mesh_elements], + output_device_mesh_shapes=[device_mesh_shape], + output_device_mesh_elements=[device_mesh_elements], ) rank = comm.Get_rank() @@ -239,15 +271,20 @@ def matmul_sr_rs_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) def test_matmul_rr_rs_rs(self): + device_mesh_shape = "[2]" + device_mesh_elements = "[0, 1]" + @onnxscript.script() def matmul_rr_rs_rs(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: return MICROSOFT_OPSET.DistributedMatMul( tensor_x, tensor_w, - device_mesh_shape=[2], - device_mesh_elements=[0, 1], input_shard_specs=["RR", "RS[0]"], output_shard_specs=["RS[0]"], + input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape], + input_device_mesh_elements=[device_mesh_elements, device_mesh_elements], + output_device_mesh_shapes=[device_mesh_shape], + output_device_mesh_elements=[device_mesh_elements], ) rank = comm.Get_rank() @@ -276,15 +313,20 @@ def matmul_rr_rs_rs(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) def test_matmul_rr_sr_rr(self): + device_mesh_shape = "[2]" + device_mesh_elements = "[0, 1]" + @onnxscript.script() def matmul_rr_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: return MICROSOFT_OPSET.DistributedMatMul( tensor_x, tensor_w, - device_mesh_shape=[2], - device_mesh_elements=[0, 1], input_shard_specs=["RR", "S[0]R"], output_shard_specs=["RR"], + input_device_mesh_shapes=[device_mesh_shape, device_mesh_shape], + input_device_mesh_elements=[device_mesh_elements, device_mesh_elements], + output_device_mesh_shapes=[device_mesh_shape], + output_device_mesh_elements=[device_mesh_elements], ) rank = comm.Get_rank() @@ -313,6 +355,9 @@ def matmul_rr_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) def test_slice_sr_axis1(self): + device_mesh_shape = "[2]" + device_mesh_elements = "[0, 1]" + @onnxscript.script() def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, tensor_axes: INT64) -> FLOAT: return MICROSOFT_OPSET.DistributedSlice( @@ -320,10 +365,12 @@ def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, te tensor_starts, tensor_ends, tensor_axes, - device_mesh_shape=[2], - device_mesh_elements=[0, 1], input_shard_specs=["S[0]R", "R", "R", "R", "R"], output_shard_specs=["S[0]R"], + input_device_mesh_shapes=[device_mesh_shape] * 5, + input_device_mesh_elements=[device_mesh_elements] * 5, + output_device_mesh_shapes=[device_mesh_shape], + output_device_mesh_elements=[device_mesh_elements], ) rank = comm.Get_rank() @@ -360,6 +407,9 @@ def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, te np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) def test_slice_rs_axis1(self): + device_mesh_shape = "[2]" + device_mesh_elements = "[0, 1]" + @onnxscript.script() def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, tensor_axes: INT64) -> FLOAT: return MICROSOFT_OPSET.DistributedSlice( @@ -367,10 +417,12 @@ def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, te tensor_starts, tensor_ends, tensor_axes, - device_mesh_shape=[2], - device_mesh_elements=[0, 1], input_shard_specs=["RS[0]", "R", "R", "R", "R"], output_shard_specs=["RS[0]"], + input_device_mesh_shapes=[device_mesh_shape] * 5, + input_device_mesh_elements=[device_mesh_elements] * 5, + output_device_mesh_shapes=[device_mesh_shape], + output_device_mesh_elements=[device_mesh_elements], ) rank = comm.Get_rank() From f2e19a8ccf90e089b47b01ac59a83e6991d53776 Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Thu, 26 Oct 2023 14:58:57 -0700 Subject: [PATCH 17/36] Updates to training pipelines to reduce CI time (#18116) ### Description Motivation for this PR is reducing CI test time by removing unnecessary tests from the pipelines. Following changes are for reducing test time in pipelines: - Skip CPU model tests in GPU builds. Training CIs run these tests as a sanity check. There is no direct training code being tested in these pipelines, furthermore, CPU tests are being run in CPU pipelines so no need to run them again in GPU builds and block the GPU VM. This change reduces testing time by 20-25 mins in all training GPU pipelines. - Delete debug package building pipeline for linux training packages. This was required by compiler team at some point but there have been 0 downloads of these packages. ### Motivation and Context --- onnxruntime/test/providers/cpu/model_tests.cc | 7 +++++++ .../orttraining-py-packaging-pipeline-cpu.yml | 2 +- .../orttraining-py-packaging-pipeline-cuda.yml | 13 ------------- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 13dcded6f3b8..c2e7577a7ca5 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -443,6 +443,13 @@ ::std::vector<::std::basic_string> GetParameterStrings() { #ifdef USE_DML provider_names[provider_name_dml] = {opset7, opset8, opset9, opset10, opset11, opset12, opset13, opset14, opset15, opset16, opset17, opset18}; #endif + +#if defined(ENABLE_TRAINING_CORE) && defined(USE_CUDA) + // Removing the CPU EP tests from CUDA build for training as these tests are already run in the CPU pipelines. + // Note: These are inference tests, we run these in training builds as an extra check. Therefore reducing + // the number of times these are run to reduce the CI time. + provider_names.erase(provider_name_cpu); +#endif std::vector> v; // Permanently exclude following tests because ORT support only opset starting from 7, // Please make no more changes to the list diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml index 983143df3f04..9755e1f0771b 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml @@ -9,7 +9,7 @@ resources: ref: 5eda9aded5462201e6310105728d33016e637ea7 stages: -- stage: Python_Packaging_Linux_Trainin_CPU +- stage: Python_Packaging_Linux_Training_CPU jobs: - job: Linux_Training_CPU_Wheels diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml index b8dfb7f3c90a..f244851f8cc3 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml @@ -20,16 +20,3 @@ stages: agent_pool: Onnxruntime-Linux-GPU upload_wheel: 'yes' debug_build: false - -# Added for triton compiler team. Can be potentially removed. -- template: templates/py-packaging-training-cuda-stage.yml - parameters: - build_py_parameters: --enable_training --update --build - torch_version: '2.0.0' - opset_version: '15' - cuda_version: '11.8' - cmake_cuda_architectures: 70;75;80;86 - docker_file: Dockerfile.manylinux2_28_training_cuda11_8 - agent_pool: Onnxruntime-Linux-GPU - upload_wheel: 'no' - debug_build: true From 7c18c60bc20dcf2006b9e546e3162bdc7e18823c Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Thu, 26 Oct 2023 16:28:57 -0700 Subject: [PATCH 18/36] Change cuda image for tensorRT to the one with cudnn8 (#18102) ### Description copilot:summary ### Motivation and Context copliot::walkthrough --- .../c-api-noopenmp-packaging-pipelines.yml | 2 +- .../azure-pipelines/linux-gpu-ci-pipeline.yml | 4 +--- .../linux-gpu-tensorrt-ci-pipeline.yml | 1 - .../linux-gpu-tensorrt-packaging-pipeline.yml | 3 +-- .../azure-pipelines/templates/py-linux-gpu.yml | 3 +-- .../templates/py-packaging-linux-test-cuda.yml | 3 +-- .../linux/docker/Dockerfile.manylinux2_28_cuda | 14 ++------------ 7 files changed, 7 insertions(+), 23 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 3696c41c196d..14a9bbedf09a 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -490,7 +490,7 @@ stages: tools/ci_build/get_docker_image.py \ --dockerfile tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda \ --context tools/ci_build/github/linux/docker \ - --docker-build-args "--network=host --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 --build-arg INSTALL_CUDNN=true --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 --build-arg BUILD_UID=$( id -u )" \ + --docker-build-args "--network=host --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 --build-arg BUILD_UID=$( id -u )" \ --container-registry onnxruntimebuildcache \ --multiple_repos \ --repository onnxruntimecuda118xtrt86build diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 1d4681d06438..9e1fae343c84 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -62,9 +62,8 @@ jobs: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 - --build-arg INSTALL_CUDNN=true --build-arg BUILD_UID=$( id -u ) " Repository: onnxruntimecuda11build @@ -166,7 +165,6 @@ jobs: --network=host --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 - --build-arg INSTALL_CUDNN=true --build-arg BUILD_UID=$( id -u ) " Repository: onnxruntimecuda11build diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml index 16d4457c45eb..517c8d638c93 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml @@ -63,7 +63,6 @@ jobs: --network=host --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 - --build-arg INSTALL_CUDNN=true --build-arg BUILD_UID=$( id -u ) " Repository: onnxruntimetensorrt86gpubuild diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml index 0d58f6cee400..85562d7758ab 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml @@ -48,9 +48,8 @@ stages: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 - --build-arg INSTALL_CUDNN=true --build-arg BUILD_UID=$( id -u ) " Repository: onnxruntimecuda118xtrt86build diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml index 33c82b5e8965..f68847afff37 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml @@ -40,9 +40,8 @@ jobs: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 - --build-arg INSTALL_CUDNN=true --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }} " diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml index a70e0c01e52f..5dad3ad1f59a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -85,9 +85,8 @@ jobs: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 - --build-arg INSTALL_CUDNN=true --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }} " diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index 7b2cada73648..d4aa9b269095 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -1,14 +1,13 @@ -# The default ARGs are for cuda 12.2 with TRT v = 8.6.1.6-1.cuda12.0 +# The default ARGs are for cuda 11.8 with cudnn8,TensorRT is optional # Please overwirete BASEIMAGE, TRT_VERSION and other arguments with # --docker-build-args ' --build-arg BASEIMAGE=other_base_image --build-arg TRT_VERSION=other_trt_version etc...' # for other cuda version and TRT version ARG POLICY=manylinux_2_28 ARG PLATFORM=x86_64 -ARG BASEIMAGE=nvidia/cuda:12.2.0-devel-ubi8 +ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 ARG DEVTOOLSET_ROOTPATH=/usr ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64 ARG PREPEND_PATH=/usr/local/cuda/binet -ARG INSTALL_CUDNN=false #Build manylinux docker image begin FROM $BASEIMAGE AS runtime_base @@ -154,15 +153,6 @@ CMD ["/bin/bash"] #Build manylinux docker image end - -#Install optinal Cudnn -RUN if [ "$INSTALL_CUDNN" = true ]; then \ - CUDA_VERSION=$(nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p') && \ - dnf -y install \ - libcudnn8-devel-*cuda${CUDA_VERSION}* \ - libcudnn8-*cuda${CUDA_VERSION}* ; \ -fi - #Install TensorRT only if TRT_VERSION is not empty RUN if [ -n "$TRT_VERSION" ]; then \ echo "TRT_VERSION is $TRT_VERSION" && \ From b7bee621cd983926ed801b420c253cd26ea8e6ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Fri, 27 Oct 2023 01:32:01 +0200 Subject: [PATCH 19/36] [CUDA] Remove shape warnings in NHWC <> NCHW unit tests (#17992) There were some warning in https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1170770 e.g. ``` [ RUN ] CudaNhwcTypedTest/1.AveragePoolNhwcPad @ /home/administrator/onnxruntime/onnxruntime/test/providers/cuda/nhwc/pool_test.cc:84 [W:onnxruntime:Default, graph.cc:108 MergeShapeInfo] Error merging shape info for output. 'Y' source:{1,16,66,66} target:{1,16,67,67}. Falling back to lenient merge. ``` These warnings where not specific to NHWC or NCHW but were just a miscalculation of output shape in some tests. --- .../test/providers/cuda/nhwc/conv_transpose_test.cc | 8 ++++++-- onnxruntime/test/providers/cuda/nhwc/pool_test.cc | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc index d45323190c51..06da2a530471 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc @@ -42,12 +42,16 @@ struct ConvTransposeOp { test->AddAttribute("pads", padding); if (!output_padding.empty()) { test->AddAttribute("output_padding", output_padding); + } else { + output_padding = {0, 0, 0, 0}; } std::vector output_dims = { input_dims[0], channels, - (kernel_shape[1] - 1) * dilations[1] + (input_dims[2] - 1) * strides[1] - (padding[1] + padding[0]) + 1, - (kernel_shape[0] - 1) * dilations[0] + (input_dims[3] - 1) * strides[0] - (padding[3] + padding[2]) + 1}; + (kernel_shape[1] - 1) * dilations[1] + (input_dims[2] - 1) * strides[1] - (padding[1] + padding[0]) + 1 + + output_padding[2], + (kernel_shape[0] - 1) * dilations[0] + (input_dims[3] - 1) * strides[0] - (padding[3] + padding[2]) + 1 + + output_padding[3]}; std::vector output_data = FillZeros(output_dims); test->AddOutput("Y", output_dims, output_data); diff --git a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc index 3d1f81e6bc28..e0d59901da80 100644 --- a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc @@ -31,8 +31,8 @@ struct PoolOp { std::vector output_dims = { input_dims[0], channels, - (kernel_shape[1] - 1) + (input_dims[2] - 1) * strides[1] - (padding[1] + padding[0]) + 1, - (kernel_shape[0] - 1) + (input_dims[3] - 1) * strides[0] - (padding[3] + padding[2]) + 1}; + (input_dims[2] - (kernel_shape[0] - 1) + padding[1] + padding[0] - 1) / strides[0] + 1, + (input_dims[3] - (kernel_shape[1] - 1) + padding[3] + padding[2] - 1) / strides[1] + 1}; std::vector output_data = FillZeros(output_dims); test->AddOutput("Y", output_dims, output_data); From 52f496835949537fe359a544ce14ec4810860b6d Mon Sep 17 00:00:00 2001 From: Yang Gu Date: Fri, 27 Oct 2023 07:33:03 +0800 Subject: [PATCH 20/36] [js/webgpu] Change timestamp-query-in-passes to timestamp-query (#18108) Timestamp-query has a broader support than timestamp-query-in-passes on all the platforms, including macOS. Note that to enable timestamp-query, you still need to add switch "--enable-dawn-features=allow_unsafe_apis" to Chrome. By default, the lowest 16 bits are masked with 0 (at a granularity about 0.1ms) for privacy. To get the highest precision, you need to add another switch "--enable-webgpu-developer-features". --- js/web/lib/wasm/jsep/backend-webgpu.ts | 60 ++++++++++++------- .../lib/wasm/jsep/webgpu/program-manager.ts | 38 ++++-------- 2 files changed, 50 insertions(+), 48 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 5d66caf77f08..eb40da048835 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -126,14 +126,14 @@ export class WebGpuBackend { */ kernels: Map unknown) | undefined, unknown]]>; - commandEncoder: GPUCommandEncoder|null = null; - computePassEncoder: GPUComputePassEncoder|null = null; + private commandEncoder: GPUCommandEncoder|null = null; + private computePassEncoder: GPUComputePassEncoder|null = null; pendingDispatchNumber = 0; - supportTimestampQuery = false; - profilingQuerySet: GPUQuerySet; - profilingQueryData: GpuData; - profilingTimeBase?: bigint; + queryData?: GpuData; + querySet?: GPUQuerySet; + querySetCount = 2; + queryTimeBase?: bigint; env: Env; @@ -168,11 +168,9 @@ export class WebGpuBackend { }, requiredFeatures, }; - // WebGPU Spec: Timestamp Queries Inside Passes - // https://github.com/gpuweb/gpuweb/blob/main/proposals/timestamp-query-inside-passes.md - if (adapter.features.has('timestamp-query-inside-passes')) { - this.supportTimestampQuery = true; - requiredFeatures.push('timestamp-query-inside-passes' as GPUFeatureName); + + if (adapter.features.has('timestamp-query')) { + requiredFeatures.push('timestamp-query'); } if (adapter.features.has('shader-f16')) { requiredFeatures.push('shader-f16'); @@ -197,21 +195,14 @@ export class WebGpuBackend { } }; - if (this.supportTimestampQuery) { - this.profilingQuerySet = this.device.createQuerySet({ - type: 'timestamp', - count: 2, - }); - } - Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); } dispose(): void { - // currently, we do not do anything in this function. In all known use cases, we don't have the requirement to - // actually dispose the WebGpuBackend instance, because it's always used as a singleton. - // - // revisit this place if we get real requirement to dispose the instance. + if (typeof this.querySet !== 'undefined') { + this.querySet.destroy(); + } + this.gpuDataManager.dispose(); } getCommandEncoder(): GPUCommandEncoder { @@ -223,7 +214,22 @@ export class WebGpuBackend { getComputePassEncoder(): GPUComputePassEncoder { if (!this.computePassEncoder) { - this.computePassEncoder = this.getCommandEncoder().beginComputePass(); + const computePassDescriptor: GPUComputePassDescriptor = {}; + if (this.isQueryEnabled()) { + if (typeof this.querySet === 'undefined') { + this.querySet = this.device.createQuerySet({ + type: 'timestamp', + count: this.querySetCount, + }); + } + computePassDescriptor.timestampWrites = { + querySet: this.querySet, + beginningOfPassWriteIndex: 0, + endOfPassWriteIndex: 1, + }; + } + + this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor); } return this.computePassEncoder; } @@ -245,6 +251,14 @@ export class WebGpuBackend { } } + isQueryEnabled(): boolean { + if (this.device.features.has('timestamp-query') && this.env.webgpu.profilingMode === 'default') { + return true; + } else { + return false; + } + } + /** * run a WebGPU program. * @param program a ProgramInfo instance diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 5c5a07d90d34..341e6edf26cc 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -38,14 +38,6 @@ export class ProgramManager { const device = this.backend.device; const computePassEncoder = this.backend.getComputePassEncoder(); - const profilingEnabled = this.backend.supportTimestampQuery && this.backend.env.webgpu.profilingMode === 'default'; - if (profilingEnabled) { - // profiling write start timestamp - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (computePassEncoder as any).writeTimestamp(this.backend.profilingQuerySet, 0); - } - computePassEncoder.setPipeline(buildArtifact.computePipeline); const entries = []; for (const input of inputs) { @@ -65,24 +57,20 @@ export class ProgramManager { this.backend.pendingDispatchNumber++; - if (profilingEnabled) { - // profiling write end timestamp - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (computePassEncoder as any).writeTimestamp(this.backend.profilingQuerySet, 1); - if (this.backend.profilingQueryData == null) { - this.backend.profilingQueryData = + if (this.backend.isQueryEnabled()) { + if (typeof this.backend.queryData === 'undefined') { + this.backend.queryData = this.backend.gpuDataManager.create( // eslint-disable-next-line no-bitwise - this.backend.gpuDataManager.create(16, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE); + this.backend.querySetCount * 8, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE); } - // eslint-disable-next-line no-bitwise - const syncData = this.backend.gpuDataManager.create(16, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST); + const syncData = this.backend.gpuDataManager.create( + // eslint-disable-next-line no-bitwise + this.backend.querySetCount * 8, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST); this.backend.endComputePass(); - this.backend.getCommandEncoder().resolveQuerySet( - this.backend.profilingQuerySet, 0, 2, this.backend.profilingQueryData.buffer, 0); + this.backend.getCommandEncoder().resolveQuerySet(this.backend.querySet, 0, 2, this.backend.queryData.buffer, 0); this.backend.getCommandEncoder().copyBufferToBuffer( - this.backend.profilingQueryData.buffer, 0, syncData.buffer, 0, 16); + this.backend.queryData.buffer, 0, syncData.buffer, 0, this.backend.querySetCount * 8); this.backend.flush(); const kernelId = this.backend.currentKernelId!; @@ -96,12 +84,12 @@ export class ProgramManager { syncData.buffer.unmap(); - if (typeof this.backend.profilingTimeBase === 'undefined') { - this.backend.profilingTimeBase = startTimeU64; + if (typeof this.backend.queryTimeBase === 'undefined') { + this.backend.queryTimeBase = startTimeU64; } - const startTime = Number(startTimeU64 - this.backend.profilingTimeBase); - const endTime = Number(endTimeU64 - this.backend.profilingTimeBase); + const startTime = Number(startTimeU64 - this.backend.queryTimeBase); + const endTime = Number(endTimeU64 - this.backend.queryTimeBase); if (!Number.isSafeInteger(startTime) || !Number.isSafeInteger(endTime)) { throw new RangeError('incorrect timestamp range'); From 37873be86df6715256b79818c1e7759848bf5c35 Mon Sep 17 00:00:00 2001 From: "Tang, Cheng" Date: Thu, 26 Oct 2023 16:57:21 -0700 Subject: [PATCH 21/36] enable reduce ops on opset18 (#18053) ### Description Opset 18 apply the "axes as input" change from ReduceSum to all the other reduce ops. Our cuda kernel actually support it, but we didn't enable it for opset18. This PR update the reduce ops' kernel registration to enable the "axes as input" behavior for opset18. As part of the fix, I also simplify the reduce op kernel registration part. ORT doesn't require the kernel definition need to be exactly the same as onnx op definition. For our case, which we share the same kernel for all the reduce ops (from version 1 to version 18), we don't need to maintain different version of kernel definitions. we can simplify it by just using a single kernel definition for multiple versions. Although for some cases, we might register more types for legacy versions, but it is harmless. Framework is using schema to validate the graph, not kernel definition. --------- Co-authored-by: Cheng Tang Co-authored-by: Cheng Tang --- docs/OperatorKernels.md | 57 +-- .../providers/cuda/cuda_execution_provider.cc | 475 +++++++----------- .../providers/cuda/reduction/reduction_ops.cc | 270 +++------- 3 files changed, 275 insertions(+), 527 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 84249df92231..d047096cb8c8 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -522,10 +522,8 @@ Do not modify directly.* |||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|11|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|11|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| |AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |||10|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -696,39 +694,26 @@ Do not modify directly.* |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* output:**T**|11+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| |Reciprocal|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||13|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceL1|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceL2|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceLogSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| +|ReduceLogSumExp|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| +|ReduceMax|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|ReduceMean|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceMin|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|ReduceProd|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| |ReduceSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|ReduceSumSquare|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|ReduceSumSquare|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 17]|**T** = tensor(double), tensor(float), tensor(float16)| |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 93e18d2940fc..4f5469ad8de3 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -645,51 +645,54 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, GlobalMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalMaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, uint8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, float, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, double, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, MLFloat16, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, int32_t, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 12, int64_t, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, float, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, double, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 8, MLFloat16, Cast); @@ -824,12 +827,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod); // opset 11 -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Concat); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Flatten); @@ -843,45 +840,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, NonMaxSuppression); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Range); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 15, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, ScatterElements); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, int32_t, Slice); @@ -959,22 +917,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Pow); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int8_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, uint8_t, ReduceMax); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int64_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, uint8_t, ReduceMin); - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, int64_t, GatherND); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout); @@ -1128,50 +1070,36 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int64_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Dropout); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Resize); @@ -1270,13 +1198,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, Trilu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub); @@ -1329,6 +1257,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, // Opset 18 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, Split); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); + // Opset 19 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, double, Cast); @@ -1594,51 +1528,51 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1777,12 +1711,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // opset 11 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1796,45 +1727,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1908,22 +1800,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2077,50 +1953,36 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2219,13 +2081,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2277,6 +2139,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { // Opset 18 BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 2f057d53d560..d46ed9c245a8 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -16,140 +16,29 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace cuda { -// opset 11 explicitly added support for negative axis. implementation already allowed it. -#define REGISTER_KERNEL_TYPED(name, T) \ +#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -#define REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 11, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 12, 12, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// Register those with changes in OpSet12. -#define REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(name, T) \ - REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -#define REGISTER_KERNEL_VERSIONED_TYPED_13(name, T) \ - REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// Register ReduceMin int64_t support in OpSet14. -#define REGISTER_KERNEL_TYPED_14(name, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 14, \ + 1, end, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -// CUDA ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet -#define REGISTER_KERNEL_VERSIONED_TYPED_11(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 11, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + name, \ + kOnnxDomain, \ + version, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()).InputMemoryType(OrtMemTypeCPUInput, 1), \ name); -// Register with the latest version 13 -#define REGISTER_KERNEL_TYPED_13(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .InputMemoryType(OrtMemTypeCPUInput, 1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); +#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \ + REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \ + REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur) // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored. template @@ -917,69 +806,76 @@ template std::unique_ptr ReduceCompute Date: Fri, 27 Oct 2023 10:29:27 +0800 Subject: [PATCH 22/36] [ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959) This PR is to support efficient attention and flash attention in ORTModule, including: - Use ATen to call efficient attention, which requires PyTorch 2.2.0 dev or newer. ORTMODULE_USE_EFFICIENT_ATTENTION=1 to enable. - Integrate Triton Flash attention, which requires triton==2.0.0.dev20221202. Need A100 or H100. ORTMODULE_USE_FLASH_ATTENTION=1 to enable. - A python transformer tool to match sub-graph by config and write transformer quickly. Current transformers supports attention mask for both efficient attn and flash attn, and dropout for efficient attn only. To support more training scenarios (such as causal mask in GPT2), more transformers need to be added. The feature is guarded by system environment variables, it won't effect any current behavior if not enabled. Since it requires specific PyTorch/Triton versions, related tests is not added for now. --- cmake/onnxruntime_python.cmake | 7 + .../contrib_ops/cpu/aten_ops/aten_op.cc | 6 +- .../cpu/aten_ops/aten_op_executor.h | 16 +- .../core/framework/fallback_cpu_capability.cc | 3 +- onnxruntime/core/framework/utils.cc | 27 +- onnxruntime/core/framework/utils.h | 1 + .../core/optimizer/transformer_memcpy.cc | 8 +- .../python/onnxruntime_pybind_state.cc | 10 +- .../aten_op_executor/__init__.py | 2 +- .../aten_op_executor/aten_op_executor.cc | 38 +- .../ort_torch_ext/__init__.py | 4 +- .../core/graph/training_op_defs.cc | 1 + .../training/ort_triton/kernel/__init__.py | 21 +- .../training/ort_triton/kernel/_flash_attn.py | 1244 +++++++++++++++++ .../training/ort_triton/kernel/_slice_scel.py | 6 +- .../python/training/ortmodule/__init__.py | 3 +- .../training/ortmodule/_training_manager.py | 4 +- .../ortmodule/graph_optimizer_registry.py | 47 + .../ortmodule/graph_optimizers/__init__.py | 15 + .../ortmodule/graph_optimizers/_aten_attn.py | 414 ++++++ .../ortmodule/graph_optimizers/utils.py | 178 +++ .../ortmodule/graph_transformer_registry.py | 47 - .../training_ops/cpu/triton/triton_op.cc | 10 +- .../training_ops/cpu/triton/triton_op.h | 16 + pyproject.toml | 1 + setup.py | 1 + 26 files changed, 2037 insertions(+), 93 deletions(-) create mode 100644 orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py create mode 100644 orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py create mode 100644 orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py create mode 100644 orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py create mode 100644 orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py delete mode 100644 orttraining/orttraining/python/training/ortmodule/graph_transformer_registry.py diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index bf9adbaefabc..a9a78668b481 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -387,6 +387,9 @@ if (onnxruntime_ENABLE_TRAINING) file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/*" ) + file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*" + ) file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py" ) @@ -741,6 +744,7 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/graph_optimizers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ort_triton COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ort_triton/kernel COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils @@ -794,6 +798,9 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs} $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_ortmodule_graph_optimizers_srcs} + $/onnxruntime/training/ortmodule/graph_optimizers/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ort_triton_srcs} $/onnxruntime/training/ort_triton/ diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc index 945c3aebce57..d0abf58922f8 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc @@ -32,8 +32,10 @@ Status ATen::Compute(OpKernelContext* p_ctx) const { aten_ops::ATenOperatorExecutor::Instance()(op_name_, overload_name_, input_size, dlpack_inputs.get(), output_size, dlpack_outputs.get()); for (size_t i = 0; i < output_size; ++i) { - ORT_RETURN_IF_ERROR( - p_ctx_internal->SetOutputMLValue(static_cast(i), dlpack::DlpackToOrtValue(dlpack_outputs[i]))); + if (dlpack_outputs[i]) { + ORT_RETURN_IF_ERROR( + p_ctx_internal->SetOutputMLValue(static_cast(i), dlpack::DlpackToOrtValue(dlpack_outputs[i]))); + } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h index be9650d96b00..d72868cd8fa9 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { namespace aten_ops { -typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index); +typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size, DLManagedTensor** dlpack_inputs, size_t output_size, DLManagedTensor** dlpack_outputs); @@ -22,17 +22,17 @@ class ATenOperatorExecutor { return instance; } - void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) { - ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw); - p_is_tensor_argument_func_ = reinterpret_cast(p_is_tensor_argument_func_raw); + void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) { + ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw); + p_is_cpu_argument_func_ = reinterpret_cast(p_is_cpu_argument_func_raw); p_execute_aten_op_func_ = reinterpret_cast(p_execute_aten_op_func_raw); } bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; } - bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index) { - ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized."); - return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index); + bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { + ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized."); + return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); } void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size, @@ -43,7 +43,7 @@ class ATenOperatorExecutor { } private: - IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr; + IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr; ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr; }; diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index 3d971e6aa29a..ef68b88187e0 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -9,6 +9,7 @@ #include "onnx/defs/data_type_utils.h" #include "core/framework/op_kernel.h" +#include "core/framework/utils.h" using namespace ONNX_NAMESPACE::Utils; @@ -77,7 +78,7 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe ORT_THROW_IF_ERROR(node->ForEachWithIndex( node->OutputDefs(), [&](const NodeArg& node_arg, size_t out_index) { - if (kernel_info->kernel_def->IsOutputOnCpu(out_index)) { + if (utils::IsOutputOnCpu(*node, kernel_info, out_index)) { cpu_output_args.insert(&node_arg); auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name()); for (auto& consumer_node : consumer_nodes) { diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index d63881ab4ff0..23fe5e1cd3d9 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -1025,7 +1025,32 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) overload_name = attrs.at("overload_name").s(); } - return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index); + return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true); + } +#else + ORT_UNUSED_PARAMETER(node); +#endif + + return false; +} + +bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) { + if (p_kci && p_kci->kernel_def->IsOutputOnCpu(index)) { + return true; + } + +#ifdef ENABLE_ATEN + if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && + node.Domain() == kPytorchAtenDomain) { + const auto& attrs = node.GetAttributes(); + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); + std::string op_name = attrs.at("operator").s(); + std::string overload_name = ""; + if (attrs.find("overload_name") != attrs.end() && utils::HasString(attrs.at("overload_name"))) { + overload_name = attrs.at("overload_name").s(); + } + + return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false); } #else ORT_UNUSED_PARAMETER(node); diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index ea6a629f87cb..f0b1b9109d40 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -121,6 +121,7 @@ common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet bool sync_subgraph_fetches = false); bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index); +bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index); template constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index ed3e35706b68..0d7ab70eba61 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -249,7 +249,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg if (!arg->Exists()) continue; - if (kci && kci->kernel_def->IsOutputOnCpu(i)) + if (utils::IsOutputOnCpu(node, kci, i)) non_provider_output_defs_.insert(arg); else provider_output_defs_.insert(arg); @@ -308,7 +308,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it); } if (arg_output_index != -1) { - if (!kci || !kci->kernel_def->IsOutputOnCpu(arg_output_index)) provider_output_nodes_[arg].insert(&it); + if (!kci || !utils::IsOutputOnCpu(it, kci, arg_output_index)) provider_output_nodes_[arg].insert(&it); } } } @@ -404,8 +404,8 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker // normally initializers are only inputs, but things may change with ops like assign ORT_THROW_IF_ERROR(Node::ForEachWithIndex( p_node->OutputDefs(), - [kci, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) { - if (kci->kernel_def->IsOutputOnCpu(index)) { + [kci, &p_node, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) { + if (utils::IsOutputOnCpu(*p_node, kci, index)) { ORT_ENFORCE(dup_replacements.find(&arg) == dup_replacements.end()); } return Status::OK(); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index a72f56360151..90271b545839 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1214,14 +1214,14 @@ void addGlobalMethods(py::module& m) { #ifdef ENABLE_ATEN m.def("register_aten_op_executor", - [](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void { - size_t is_tensor_argument_address_int, aten_op_executor_address_int; + [](const std::string& is_cpu_argument_address_str, const std::string& aten_op_executor_address_str) -> void { + size_t is_cpu_argument_address_int, aten_op_executor_address_int; ORT_THROW_IF_ERROR( - ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int)); + ParseStringWithClassicLocale(is_cpu_argument_address_str, is_cpu_argument_address_int)); ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int)); - void* p_is_tensor_argument = reinterpret_cast(is_tensor_argument_address_int); + void* p_is_cpu_argument = reinterpret_cast(is_cpu_argument_address_int); void* p_aten_op_executor = reinterpret_cast(aten_op_executor_address_int); - contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor); + contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_cpu_argument, p_aten_op_executor); }); #endif } diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py index 9dee6564509d..8bf7cbf80eb3 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py @@ -29,5 +29,5 @@ def load_aten_op_executor_cpp_extension(): from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor _C.register_aten_op_executor( - str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address()) + str(aten_op_executor.is_cpu_argument_address()), str(aten_op_executor.execute_aten_operator_address()) ) diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc index 182f2368f5b4..903a394a06ef 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc @@ -154,11 +154,32 @@ class ATenOperatorCache { std::unordered_map, ATenOperator, PairHash> ops_; }; -// Backend uses this function to check if an argument is CPU input (non-tensor argument) or not. -bool IsTensorArgument(const char* op_name, const char* overload_name, size_t index) { - const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name); - TORCH_INTERNAL_ASSERT(index < aten_op.argument_size); - return aten_op.elem_kinds[index] == c10::TypeKind::TensorType; +const std::unordered_map> kCpuTensorInputsMap = { + {"_efficient_attention_forward", {4, 5, 11, 12}}, {"_efficient_attention_backward", {6, 7, 12, 13}}}; + +const std::unordered_map> kCpuTensorOutputsMap = { + {"_efficient_attention_forward", {2, 3}}}; + +// Backend uses this function to check if an argument is CPU input or not. +bool IsCpuArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) { + if (is_input) { + // If the argument is non-tensor type, it's CPU argument. + const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name); + TORCH_INTERNAL_ASSERT(index < aten_op.argument_size); + if (aten_op.elem_kinds[index] != c10::TypeKind::TensorType) { + return true; + } + } + + std::string full_name = std::string(op_name); + std::string overload_name_str = std::string(overload_name); + if (overload_name_str != "") { + full_name += ("." + overload_name_str); + } + + const auto& cpu_tensors_map = is_input ? kCpuTensorInputsMap : kCpuTensorOutputsMap; + return cpu_tensors_map.find(full_name) != cpu_tensors_map.end() && + cpu_tensors_map.at(full_name).find(index) != cpu_tensors_map.at(full_name).end(); } void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t input_size, @@ -196,14 +217,15 @@ void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t size_t output_index = 0; for (const auto& ret : torch::jit::pop(stack, output_size)) { const auto& tensor = ret.toTensor(); - dlpack_outputs[output_index++] = at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()); + dlpack_outputs[output_index++] = + tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr; } } -size_t is_tensor_argument_address() { return reinterpret_cast(&IsTensorArgument); } +size_t is_cpu_argument_address() { return reinterpret_cast(&IsCpuArgument); } size_t execute_aten_operator_address() { return reinterpret_cast(&ExecuteATenOperator); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("is_tensor_argument_address", &is_tensor_argument_address, "Address of tensor argument check."); + m.def("is_cpu_argument_address", &is_cpu_argument_address, "Address of tensor argument check."); m.def("execute_aten_operator_address", &execute_aten_operator_address, "Address of Aten operator executor"); } diff --git a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py index 7d5716b85db3..329fba5aa670 100644 --- a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py +++ b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py @@ -5,7 +5,7 @@ from onnxruntime.capi import _pybind_state as _C -from .aten_op_executor import execute_aten_operator_address, is_tensor_argument_address +from .aten_op_executor import execute_aten_operator_address, is_cpu_argument_address def run_once_aten_op_executor(f): @@ -30,7 +30,7 @@ def aten_op_executor_wrapper(*args, **kwargs): @run_once_aten_op_executor def load_aten_op_executor_cpp_extension(): - _C.register_aten_op_executor(str(is_tensor_argument_address()), str(execute_aten_operator_address())) + _C.register_aten_op_executor(str(is_cpu_argument_address()), str(execute_aten_operator_address())) def init_aten_op_executor(): diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index c90acfdb7bb7..80d937fa163e 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -4180,6 +4180,7 @@ Return true if all elements are true and false otherwise. .Attr("func_name", "Function name of the Python Triton kernel.", AttributeProto::STRING, std::string("")) .Attr("onnx_key", "The hash key for the ONNX graph.", AttributeProto::INT, static_cast(0)) .Attr("onnx_string", "The onnx string of the triton kernel.", AttributeProto::STRING, std::string("")) + .AllowUncheckedAttributes() .Input(0, "inputs", "Input tensors. If to call an existing Python Triton kernel, " "the input count and order should match the arguments of the function. If to compute an ONNX graph, " diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py index 97318ea2e53a..c1b99e4859db 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py @@ -3,15 +3,28 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out -from ._slice_scel import slice_scel, slice_scel_backward, transform_slice_scel +import os -__all__ = [ +from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out # noqa: F401 +from ._slice_scel import optimize_graph_for_slice_scel, slice_scel, slice_scel_backward # noqa: F401 + +_all_kernels = [ "triton_gemm", "triton_gemm_out", "triton_matmul", "triton_matmul_out", "slice_scel", "slice_scel_backward", - "transform_slice_scel", ] + +_all_optimizers = [ + "optimize_graph_for_slice_scel", +] + +if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1: + from ._flash_attn import flash_attn_backward, flash_attn_forward, optimize_graph_for_flash_attention # noqa: F401 + + _all_kernels.extend(["flash_attn_forward", "flash_attn_backward"]) + _all_optimizers.append("optimize_graph_for_flash_attention") + +__all__ = _all_kernels + _all_optimizers # noqa: PLE0605 diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py new file mode 100644 index 000000000000..40398b33d8f0 --- /dev/null +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py @@ -0,0 +1,1244 @@ +""" +*Experimental* implementation of FlashAttention in Triton. +Tested with triton==2.0.0.dev20221202. +Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions +other than 64: +https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 +We'll update this implementation with the new Triton backend once this is fixed. + +We use the FlashAttention implementation from Phil Tillet a starting point. +https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py + +Changes: +- Implement both causal and non-causal attention. +- Implement both self-attention and cross-attention. +- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. +- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. +- Support attention bias. +- Speed up the forward pass a bit, and only store the LSE instead of m and l. +- Make the backward for d=128 much faster by reducing register spilling. +- Optionally parallelize the backward pass across seqlen_k, to deal with the case of +small batch size * nheads. + +Caution: +- This is an *experimental* implementation. The forward pass should be quite robust but +I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). +- This implementation has only been tested on A100. +- If you plan to use headdim other than 64 and 128, you should test for race conditions +(due to the Triton compiler), as done in tests/test_flash_attn.py +"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions +for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident +that there are none left for other head dimensions. + +Differences between this Triton version and the CUDA version: +- Triton version doesn't support dropout. +- Triton forward is generally faster than CUDA forward, while Triton backward is +generally slower than CUDA backward. Overall Triton forward + backward is slightly slower +than CUDA forward + backward. +- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). +- Triton version supports attention bias, while CUDA version doesn't. +""" + +import math +from typing import List, Tuple + +import torch +import triton +import triton.language as tl +from onnx import GraphProto, NodeProto, TensorProto, helper + +from onnxruntime.training.ortmodule import register_graph_optimizer +from onnxruntime.training.ortmodule.graph_optimizers.utils import GraphMatcher, check_attribute_value, update_graph + + +# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), +# # This config has a race condition when EVEN_M == False, disabling it for now. +# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), +# ], +# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] +# ) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel( + Q, + K, + V, + Bias, + Out, + Lse, + TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_bb, + stride_bh, + stride_bm, + stride_ob, + stride_oh, + stride_om, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to Q, K, V + # Adding parenthesis around indexing might use int32 math instead of int64 math? + # https://github.com/openai/triton/issues/741 + # I'm seeing a tiny bit of difference (5-7us) + q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + if BIAS_TYPE == "vector": + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n + elif BIAS_TYPE == "matrix": + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) + # initialize pointer to m and l + t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m + lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + # load q: it will stay in SRAM throughout + # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call + # tl.load(q_ptrs), we get the wrong output! + if EVEN_M & EVEN_N: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0) + # loop over k, v and update accumulator + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + if BIAS_TYPE != "none": + if BIAS_TYPE == "vector": + if EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == "matrix": + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load( + b_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler + # can then fuse the mult and add into an fma instruction. But if we have bias we need to + # to multiply with softmax_scale here. + qk = qk * softmax_scale + bias + m_ij = tl.maximum(tl.max(qk, 1), lse_i) + p = tl.exp(qk - m_ij[:, None]) + else: + m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) + p = tl.exp(qk * softmax_scale - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # # -- update output accumulator -- + # BUG: have to store and immediately load + tl.store(t_ptrs, acc_o_scale) + acc_o_scale = tl.load(t_ptrs) + acc_o = acc_o * acc_o_scale[:, None] + # update acc_o + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + + # -- update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + o_scale = tl.exp(m_i - lse_i) + # BUG: have to store and immediately load + tl.store(t_ptrs, o_scale) + o_scale = tl.load(t_ptrs) + acc_o = acc_o * o_scale[:, None] + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m + tl.store(lse_ptrs, lse_i) + # initialize pointers to output + offs_d = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, + DO, + Delta, + stride_ob, + stride_oh, + stride_om, + stride_dob, + stride_doh, + stride_dom, + nheads, + seqlen_q, + seqlen_q_rounded, + headdim, + BLOCK_M: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + o = tl.load( + Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + do = tl.load( + DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + +@triton.jit +def _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, +): + # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.store(dv_ptrs), there's a race condition + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD: tl.constexpr, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) + begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) + do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) + dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) + if BIAS_TYPE == "vector": + b_ptrs = Bias + offs_n + elif BIAS_TYPE == "matrix": + b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + # There seems to be some problem with Triton pipelining that makes results wrong for + # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop + # may have zero step, and pipelining with the bias matrix could screw it up. + # So we just exit early. + if begin_m >= seqlen_q: + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ) + return + # k and v stay in SRAM throughout + # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.load(k_ptrs), we get the wrong output! + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + else: + k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + else: + k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0) + v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0) + # loop over rows + num_block_m = tl.cdiv(seqlen_q, BLOCK_M) + for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + # recompute p = softmax(qk, dim=-1).T + qk = tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) + if IS_CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + if BIAS_TYPE != "none": + tl.debug_barrier() # Race condition otherwise + if BIAS_TYPE == "vector": + if EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == "matrix": + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load( + b_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + qk = qk * softmax_scale + bias + # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. + # Also wrong for headdim=64. + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + lse_i = tl.load(LSE + offs_m_curr) + if BIAS_TYPE == "none": + p = tl.exp(qk * softmax_scale - lse_i[:, None]) + else: + p = tl.exp(qk - lse_i[:, None]) + # compute dv + # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs + # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, + # the output is correct. + if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. + do = tl.load( + do_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + dv += tl.dot(p.to(do.dtype), do, trans_a=True) + # compute dp = dot(v, do) + # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. + # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True + # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + dp = tl.dot(do, v, trans_b=True) + # There's a race condition for headdim=48 + if not EVEN_HEADDIM: + tl.debug_barrier() + # compute ds = p * (dp - delta[:, None]) + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) + # compute dk = dot(ds.T, q) + dk += tl.dot(ds, q, trans_a=True) + # compute dq + if not (EVEN_M & EVEN_HEADDIM): # Otherewise there's a race condition when BIAS_TYPE='matrix' + tl.debug_barrier() + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load( + dq_ptrs, + mask=offs_m_curr[:, None] < seqlen_q, + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last", + ) + else: + dq = tl.load( + dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last", + ) + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + ) + # increment pointers + dq_ptrs += BLOCK_M * stride_dqm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_dom + if BIAS_TYPE == "matrix": + b_ptrs += BLOCK_M * stride_bm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ) + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now + # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + ], + key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"], +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _bwd_kernel( + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_bb, + stride_bh, + stride_bm, + stride_dob, + stride_doh, + stride_dom, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dvb, + stride_dvh, + stride_dvn, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + K += off_b * stride_kb + off_h * stride_kh + V += off_b * stride_vb + off_h * stride_vh + DO += off_b * stride_dob + off_h * stride_doh + DQ += off_b * stride_dqb + off_h * stride_dqh + DK += off_b * stride_dkb + off_h * stride_dkh + DV += off_b * stride_dvb + off_h * stride_dvh + if BIAS_TYPE != "none": + Bias += off_b * stride_bb + off_h * stride_bh + # pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(seqlen_k, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=False, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=True, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + +def flash_attn_forward(q, k, v, bias=None, **kwargs): + # shape constraints + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + assert k.shape == (batch, seqlen_k, nheads, d) + assert v.shape == (batch, seqlen_k, nheads, d) + assert d <= 128, "FlashAttention only support head dimensions up to 128" + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" + assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" + assert q.is_cuda and k.is_cuda and v.is_cuda + + causal = kwargs.get("causal", 0) == 1 + softmax_scale = kwargs.get("softmax_scale", 1.0 / math.sqrt(d)) + has_bias = bias is not None + bias_type = "none" + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + if bias.stride(-1) != 1: + bias = bias.contiguous() + if bias.shape[2:] == (1, seqlen_k): + bias_type = "vector" + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = "matrix" + else: + raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)") + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + BLOCK = 128 + num_warps = 4 if d <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _fwd_kernel[grid]( + q, + k, + v, + bias, + o, + lse, + tmp, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + *bias_strides, + o.stride(0), + o.stride(2), + o.stride(1), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, + causal, + BLOCK_HEADDIM, + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return o, lse + + +def flash_attn_backward(do, q, k, v, o, lse, bias=None, **kwargs): + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + # assert d in {16, 32, 64, 128} + assert d <= 128 + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 + assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 + + causal = kwargs.get("causal", 0) == 1 + softmax_scale = kwargs.get("softmax_scale", 1.0 / math.sqrt(d)) + # dq_accum = torch.zeros_like(q, dtype=torch.float32) + dq_accum = torch.empty_like(q, dtype=torch.float32) + delta = torch.empty_like(lse) + # delta = torch.zeros_like(lse) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _bwd_preprocess_do_o_dot[grid]( + o, + do, + delta, + o.stride(0), + o.stride(2), + o.stride(1), + do.stride(0), + do.stride(2), + do.stride(1), + nheads, + seqlen_q, + seqlen_q_rounded, + d, + BLOCK_M=128, + BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + has_bias = bias is not None + bias_type = "none" + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.stride(-1) == 1 + if bias.shape[2:] == (1, seqlen_k): + bias_type = "vector" + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = "matrix" + else: + raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)") + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 + grid = lambda META: ( + triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads, + ) + _bwd_kernel[grid]( + q, + k, + v, + bias, + do, + dq_accum, + dk, + dv, + lse, + delta, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + *bias_strides, + do.stride(0), + do.stride(2), + do.stride(1), + dq_accum.stride(0), + dq_accum.stride(2), + dq_accum.stride(1), + dk.stride(0), + dk.stride(2), + dk.stride(1), + dv.stride(0), + dv.stride(2), + dv.stride(1), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, + causal, + BLOCK_HEADDIM, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq.copy_(dq_accum) + return dq, dk, dv + + +def _make_flash_attention_nodes( + idx: int, + q: str, + k: str, + v: str, + y: str, + dy: str, + dq: str, + dk: str, + dv: str, + bias: str, + scale: float, +): + logsumexp = helper.make_tensor_value_info("logsumexp_" + str(idx), TensorProto.FLOAT, []) + fwd_node = helper.make_node( + "TritonOp", + [q, k, v, bias], + [y, logsumexp.name], + "TritonOp_Flash_Attn_Fwd_" + str(idx), + None, + "com.microsoft", + func_name="flash_attn_forward", + causal=0, + softmax_scale=scale, + ) + bwd_node = helper.make_node( + "TritonOp", + [dy, q, k, v, y, logsumexp.name, bias], + [dq, dk, dv], + "TritonOp_Flash_Attn_Bwd_" + str(idx), + None, + "com.microsoft", + func_name="flash_attn_backward", + causal=0, + softmax_scale=scale, + ) + return [fwd_node, bwd_node], [logsumexp] + + +# Without causal mask, without Dropout. For example, BERT model in HuggingFace. +_PATTERN_0: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 0)]), # 1 + ("Transpose", True, [(0, 0, 1)]), # 2 + ("Div", False, [(0, 0, 0)]), # 3 + ("Add", False, [(3, 0, 0)]), # 4 + ("Softmax", False, [(4, 0, 0)]), # 5 + ("MatMul", False, [(5, 0, 0)]), # 6 + ("Transpose", True, [(6, 0, 1)]), # 7 + ("Transpose", False, [(6, 0, 0)]), # 8 + ("FusedMatMul", False, [(7, 0, 1)]), # 9 + ("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10 + ("Identity", False, [(10, 0, 0)]), # 11 + ("Div", False, [(11, 0, 0)]), # 12 + ("Identity", False, [(12, 0, 0)]), # 13 + ("FusedMatMul", False, [(2, 0, 1), (13, 0, 0)]), # 14 + ("FusedMatMul", False, [(1, 0, 0), (13, 0, 1)]), # 15 + ("FusedMatMul", False, [(5, 0, 0)]), # 16 + ("Transpose", True, [(16, 0, 1)]), # 17 + ("Transpose", False, [(14, 0, 0)]), # 18 + ("Transpose", False, [(15, 0, 0)]), # 19 + ("Transpose", False, [(16, 0, 0)]), # 20 +] + + +def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value = matcher.get_constant_value(nodes[3].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[2], "perm", [0, 2, 3, 1]) + and scale_value is not None + and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) + ): + return [], [], [] + + nodes_to_add, new_value_infos = _make_flash_attention_nodes( + idx, + nodes[1].input[0], + nodes[2].input[0], + nodes[7].input[0], + nodes[8].output[0], + nodes[17].input[0], + nodes[18].output[0], + nodes[19].output[0], + nodes[20].output[0], + nodes[4].input[1], + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + ) + return nodes, nodes_to_add, new_value_infos + + +# llama2+peft, k doesn't require grad. +_PATTERN_1: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 1)]), # 1 + ("Div", False, [(0, 0, 0)]), # 2 + ("Add", False, [(2, 0, 0)]), # 3 + ("Softmax", False, [(3, 0, 0)]), # 4 + ("MatMul", False, [(4, 0, 0)]), # 5 + ("Transpose", True, [(5, 0, 1)]), # 6 + ("Identity", False, [(6, 0, 0)]), # 7 + ("YieldOp", False, [(7, 0, -1)]), # 8 + ("Transpose", False, [(5, 0, 0)]), # 9 + ("FusedMatMul", False, [(6, 0, 1)]), # 10 + ("SoftmaxGrad_13", False, [(10, 0, 0), (4, 0, 1)]), # 11 + ("Identity", False, [(11, 0, 0)]), # 12 + ("Div", False, [(12, 0, 0)]), # 13 + ("Identity", False, [(13, 0, 0)]), # 14 + ("FusedMatMul", False, [(1, 0, 1), (14, 0, 0)]), # 15 + ("FusedMatMul", False, [(4, 0, 0)]), # 16 + ("Transpose", True, [(16, 0, 1)]), # 17 + ("Sum", False, [(16, 0, 0)]), # 18 + ("Transpose", False, [(18, 0, 0)]), # 19 +] + + +def _optimize_for_pattern_1(matcher: GraphProto, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value = matcher.get_constant_value(nodes[2].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 1, 3, 2]) + and scale_value is not None + and check_attribute_value(nodes[6], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[9], "perm", [0, 2, 1, 3]) + and matcher.get_consumer_count(nodes[14].output[0]) == 1 + ): + return [], [], [] + + dtype, _ = matcher.get_type_and_shape(nodes[0].input[0]) + assert dtype is not None + trans_q_tensor = helper.make_tensor_value_info("trans_q_" + str(idx), dtype, None) + trans_q_grad_tensor = helper.make_tensor_value_info("trans_q_grad_" + str(idx), dtype, None) + trans_k_tensor = helper.make_tensor_value_info("trans_k_" + str(idx), dtype, None) + trans_q = helper.make_node( + "Transpose", [nodes[0].input[0]], [trans_q_tensor.name], "Trans_Q_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_q_grad = helper.make_node( + "Transpose", [trans_q_grad_tensor.name], [nodes[15].output[0]], "Trans_Q_Grad_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_k = helper.make_node( + "Transpose", [nodes[1].input[0]], [trans_k_tensor.name], "Trans_K_" + str(idx), perm=[0, 2, 1, 3] + ) + nodes[19].input[0] = nodes[18].input[1] + v_grad = nodes[19].output[0] + nodes[19].output[0] = nodes[18].output[0] + nodes[18].input[1] = nodes[18].output[0] + nodes[18].output[0] = v_grad + nodes_to_add, new_value_infos = _make_flash_attention_nodes( + idx, + trans_q_tensor.name, + trans_k_tensor.name, + nodes[6].input[0], + nodes[9].output[0], + nodes[17].input[0], + trans_q_grad_tensor.name, + "", + nodes[16].output[0], + nodes[3].input[1], + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + ) + nodes_to_remove = nodes[:6] + nodes[9:18] + nodes_to_add.extend([trans_q, trans_q_grad, trans_k]) + new_value_infos.extend([trans_q_tensor, trans_q_grad_tensor, trans_k_tensor]) + return nodes_to_remove, nodes_to_add, new_value_infos + + +# llama2+peft, k requires grad. +_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 1)]), # 1 + ("Div", False, [(0, 0, 0)]), # 2 + ("Add", False, [(2, 0, 0)]), # 3 + ("Softmax", False, [(3, 0, 0)]), # 4 + ("MatMul", False, [(4, 0, 0)]), # 5 + ("Transpose", True, [(5, 0, 1)]), # 6 + ("Identity", False, [(6, 0, 0)]), # 7 + ("YieldOp", False, [(7, 0, -1)]), # 8 + ("Transpose", False, [(5, 0, 0)]), # 9 + ("FusedMatMul", False, [(6, 0, 1)]), # 10 + ("SoftmaxGrad_13", False, [(10, 0, 0), (4, 0, 1)]), # 11 + ("Identity", False, [(11, 0, 0)]), # 12 + ("Div", False, [(12, 0, 0)]), # 13 + ("Identity", False, [(13, 0, 0)]), # 14 + ("FusedMatMul", False, [(1, 0, 1), (14, 0, 0)]), # 15 + ("FusedMatMul", False, [(14, 0, 1)]), # 16 + ("Transpose", False, [(16, 0, 0)]), # 17 + ("FusedMatMul", False, [(4, 0, 0)]), # 18 + ("Transpose", True, [(18, 0, 1)]), # 19 + ("Sum", False, [(18, 0, 0)]), # 20 + ("Transpose", False, [(20, 0, 0)]), # 21 +] + + +def _aptimize_for_pattern_2(matcher: GraphProto, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value = matcher.get_constant_value(nodes[2].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 1, 3, 2]) + and scale_value is not None + and check_attribute_value(nodes[6], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[9], "perm", [0, 2, 1, 3]) + and matcher.get_consumer_count(nodes[14].output[0]) == 2 + ): + return [], [], [] + + dtype, _ = matcher.get_type_and_shape(nodes[0].input[0]) + assert dtype is not None + trans_q_tensor = helper.make_tensor_value_info("trans_q_" + str(idx), dtype, None) + trans_q_grad_tensor = helper.make_tensor_value_info("trans_q_grad_" + str(idx), dtype, None) + trans_k_tensor = helper.make_tensor_value_info("trans_k_" + str(idx), dtype, None) + trans_k_grad_tensor = helper.make_tensor_value_info("trans_k_grad_" + str(idx), dtype, None) + trans_q = helper.make_node( + "Transpose", [nodes[0].input[0]], [trans_q_tensor.name], "Trans_Q_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_q_grad = helper.make_node( + "Transpose", [trans_q_grad_tensor.name], [nodes[15].output[0]], "Trans_Q_Grad_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_k = helper.make_node( + "Transpose", [nodes[1].input[0]], [trans_k_tensor.name], "Trans_K_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_k_grad = helper.make_node( + "Transpose", [trans_k_grad_tensor.name], [nodes[17].output[0]], "Trans_K_Grad_" + str(idx), perm=[0, 2, 1, 3] + ) + nodes[21].input[0] = nodes[20].input[1] + v_grad = nodes[21].output[0] + nodes[21].output[0] = nodes[20].output[0] + nodes[20].input[1] = nodes[20].output[0] + nodes[20].output[0] = v_grad + nodes_to_add, new_value_infos = _make_flash_attention_nodes( + idx, + trans_q_tensor.name, + trans_k_tensor.name, + nodes[6].input[0], + nodes[9].output[0], + nodes[19].input[0], + trans_q_grad_tensor.name, + trans_k_grad_tensor.name, + nodes[18].output[0], + nodes[3].input[1], + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + ) + nodes_to_remove = nodes[:6] + nodes[9:20] + nodes_to_add.extend([trans_q, trans_q_grad, trans_k, trans_k_grad]) + new_value_infos.extend([trans_q_tensor, trans_q_grad_tensor, trans_k_tensor, trans_k_grad_tensor]) + return nodes_to_remove, nodes_to_add, new_value_infos + + +# TODO: add pattern to support attention with causal mask, such as GPT2 in HuggingFace. +_PATTERNS = [ + (_PATTERN_0, _optimize_for_pattern_0), + (_PATTERN_1, _optimize_for_pattern_1), + (_PATTERN_2, _aptimize_for_pattern_2), +] + + +@register_graph_optimizer(devices="cuda") +def optimize_graph_for_flash_attention(graph: GraphProto): + nodes_to_remove = [] + nodes_to_add = [] + new_value_infos = [] + matcher = GraphMatcher(graph) + idx = 0 + for pattern_tuple in _PATTERNS: + for nodes in matcher.match_pattern(pattern_tuple[0]): + remove_nodes, add_nodes, add_value_infos = pattern_tuple[1](matcher, idx, nodes) + if len(add_nodes) > 0: + nodes_to_remove.extend(remove_nodes) + nodes_to_add.extend(add_nodes) + new_value_infos.extend(add_value_infos) + idx += 1 + update_graph(graph, nodes_to_remove, nodes_to_add, new_value_infos) diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py b/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py index 8edcc9b63ef4..fb7ddc68900c 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py @@ -11,7 +11,7 @@ import triton.language as tl from onnx import TensorProto, helper -from onnxruntime.training.ortmodule import register_graph_transformer +from onnxruntime.training.ortmodule import register_graph_optimizer from .._utils import get_attribute, to_numpy_array @@ -246,8 +246,8 @@ def _get_shape_related_nodes(graph, start_arg, sub_graph_nodes): args.append(output) -@register_graph_transformer(devices="cuda") -def transform_slice_scel(graph): +@register_graph_optimizer(devices="cuda") +def optimize_graph_for_slice_scel(graph): remove_nodes = [] triton_nodes = [] value_infos = [] diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index 59cf05bb082f..fbf1b7c2bac4 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -124,7 +124,8 @@ def _are_deterministic_algorithms_enabled(): return ORTMODULE_IS_DETERMINISTIC -from .graph_transformer_registry import register_graph_transformer # noqa: E402, F401 +from .graph_optimizer_registry import register_graph_optimizer # noqa: E402, F401 +from .graph_optimizers import * # noqa: E402, F403 from .options import DebugOptions, LogLevel # noqa: E402, F401 # ORTModule must be loaded only after all validation passes diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 3953d342f189..e0f11e5aa407 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -21,7 +21,7 @@ from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime from ._runtime_inspector import Phase from ._utils import save_tuning_results, set_tuning_results -from .graph_transformer_registry import GraphTransformerRegistry +from .graph_optimizer_registry import GraphOptimizerRegistry from .options import DebugOptions, _SkipCheck @@ -369,7 +369,7 @@ def _build_graph(self, graph_transformer_config): device_type = self._device.type if device_type == "cuda" and self.is_rocm_pytorch: device_type = "rocm" - GraphTransformerRegistry.transform_all( + GraphOptimizerRegistry.optimize_all( type(self._flattened_module._original_module).__name__, device_type, self._onnx_models.optimized_model.graph ) diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py new file mode 100644 index 000000000000..897ecac148bf --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py @@ -0,0 +1,47 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from typing import Callable + +from onnx.onnx_ml_pb2 import GraphProto + + +class GraphOptimizerRegistry: + _OPTIMIZER_FUNCS = {} # noqa: RUF012 + + @classmethod + def register(cls, target_modules: str, devices: str, priority: int, fn: Callable[[GraphProto], None]): + modules = [] + if target_modules == "all": + modules.append("all") + else: + modules = target_modules.split("|") + for module in modules: + if module in cls._OPTIMIZER_FUNCS: + cls._OPTIMIZER_FUNCS[module].append((fn, devices, priority)) + else: + cls._OPTIMIZER_FUNCS[module] = [(fn, devices, priority)] + + @classmethod + def optimize_all(cls, module_name: str, device: str, graph: GraphProto): + optimizers_to_apply = [] + if "all" in cls._OPTIMIZER_FUNCS: + optimizers_to_apply.extend(cls._OPTIMIZER_FUNCS["all"]) + if module_name in cls._OPTIMIZER_FUNCS: + optimizers_to_apply.extend(cls._OPTIMIZER_FUNCS[module_name]) + optimizers_to_apply = [x for x in optimizers_to_apply if x[1] == "all" or device in x[1]] + optimizers_to_apply.sort(key=lambda x: x[2], reverse=True) + for fn, _, _ in optimizers_to_apply: + fn(graph) + + +# target_modules can be multiple module names separated by "|", or "all" means apply to all modules. +# devices can be multiple device types separated by "|" or "all" means apply to all devices. +def register_graph_optimizer(target_modules: str = "all", devices: str = "all", priority: int = 0): + def graph_optimizer_wrapper(fn): + GraphOptimizerRegistry.register(target_modules, devices, priority, fn) + return fn + + return graph_optimizer_wrapper diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py new file mode 100644 index 000000000000..d215e12f8137 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py @@ -0,0 +1,15 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import os + +_all_optimizers = [] + +if "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1: + from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401 + + _all_optimizers.append("optimize_graph_for_aten_efficient_attention") + +__all__ = _all_optimizers # noqa: PLE0605 diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py new file mode 100644 index 000000000000..94bd41293b42 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py @@ -0,0 +1,414 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +""" +PyTorch's _efficient_attention_forward/_efficient_attention_backward APIs is keep changing. Current implementation +is tested well on version 2.2.0.dev20231010+cu121, and should be run well since official version 2.2.0. If may fail to +run is you are using PyTorch with older versions. + +PyTorch also has API for flash attention (currently doesn't support random attention mask or Dropout), we can add +support if we want to try in the future. +""" + +from typing import List, Tuple + +from onnx import GraphProto, NodeProto, TensorProto, helper + +from ..graph_optimizer_registry import register_graph_optimizer +from .utils import GraphMatcher, check_attribute_value, make_constant_node, update_graph + + +def _make_efficient_attention_nodes( + idx: int, + q: str, + k: str, + v: str, + y: str, + dy: str, + dq: str, + dk: str, + dv: str, + bias: str, + expand_bias: bool, + scale: float, + dropout_ratio: float, + causal: bool, +): + nodes_to_add = [] + scale_node = make_constant_node("scale_" + str(idx), TensorProto.FLOAT, [], [scale]) + dropout_ratio_node = make_constant_node("dropout_ratio_" + str(idx), TensorProto.FLOAT, [], [dropout_ratio]) + causal_node = make_constant_node("causal_" + str(idx), TensorProto.INT64, [], [1 if causal else 0]) + int_zero_node = make_constant_node("int_zero_" + str(idx), TensorProto.INT64, [], [0]) + true_node = make_constant_node("true_" + str(idx), TensorProto.BOOL, [], [True]) + false_node = make_constant_node("false_" + str(idx), TensorProto.BOOL, [], [False]) + logsumexp = helper.make_tensor_value_info("logsumexp" + str(idx), TensorProto.FLOAT, []) + seed = helper.make_tensor_value_info("seed" + str(idx), TensorProto.INT64, []) + offset = helper.make_tensor_value_info("offset" + str(idx), TensorProto.INT64, []) + new_value_infos = [logsumexp, seed, offset] + if expand_bias: + shape_0 = helper.make_node("Shape", [q], ["shape_0_" + str(idx)], start=0, end=1) + shape_1 = helper.make_node("Shape", [q], ["shape_1_" + str(idx)], start=2, end=3) + shape_2 = helper.make_node("Shape", [q], ["shape_2_" + str(idx)], start=1, end=2) + shape_3 = helper.make_node("Shape", [k], ["shape_3_" + str(idx)], start=1, end=2) + concat = helper.make_node( + "Concat", + ["shape_0_" + str(idx), "shape_1_" + str(idx), "shape_2_" + str(idx), "shape_3_" + str(idx)], + ["concated_shape_" + str(idx)], + axis=0, + ) + expand = helper.make_node("Expand", [bias, "concated_shape_" + str(idx)], ["expanded_bias_" + str(idx)]) + nodes_to_add.extend([shape_0, shape_1, shape_2, shape_3, concat, expand]) + bias = "expanded_bias_" + str(idx) + fwd_node = helper.make_node( + "ATen", + [ + q, + k, + v, + bias, + "", + "", + "", + dropout_ratio_node.output[0], + causal_node.output[0], + true_node.output[0], + scale_node.output[0], + "", + "", + ], + [y, logsumexp.name, seed.name, offset.name], + "efficient_attention_forward_" + str(idx), + None, + "org.pytorch.aten", + operator="_efficient_attention_forward", + ) + bwd_node = helper.make_node( + "ATen", + [ + dy, + q, + k, + v, + bias, + y, + "", + "", + int_zero_node.output[0], + int_zero_node.output[0], + logsumexp.name, + dropout_ratio_node.output[0], + seed.name, + offset.name, + causal_node.output[0], + false_node.output[0], + scale_node.output[0], + "", + ], + [dq, dk, dv, ""], + "efficient_attention_backward_" + str(idx), + None, + "org.pytorch.aten", + operator="_efficient_attention_backward", + ) + nodes_to_add.extend( + [scale_node, dropout_ratio_node, causal_node, int_zero_node, true_node, false_node, fwd_node, bwd_node] + ) + return nodes_to_add, new_value_infos + + +# Without causal mask, with Dropout. For example, BERT model in HuggingFace. +_PATTERN_0: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 0)]), # 1 + ("Transpose", True, [(0, 0, 1)]), # 2 + ("Div", False, [(0, 0, 0)]), # 3 + ("Add", False, [(3, 0, 0)]), # 4 + ("Softmax", False, [(4, 0, 0)]), # 5 + ("Dropout", False, [(5, 0, 0)]), # 6 + ("MatMul", False, [(6, 0, 0)]), # 7 + ("Transpose", True, [(7, 0, 1)]), # 8 + ("Transpose", False, [(7, 0, 0)]), # 9 + ("FusedMatMul", False, [(8, 0, 1)]), # 10 + ("DropoutGrad", False, [(10, 0, 0), (6, 1, 1)]), # 11 + ("SoftmaxGrad_13", False, [(11, 0, 0), (5, 0, 1)]), # 12 + ("Identity", False, [(12, 0, 0)]), # 13 + ("Div", False, [(13, 0, 0)]), # 14 + ("Identity", False, [(14, 0, 0)]), # 15 + ("FusedMatMul", False, [(2, 0, 1), (15, 0, 0)]), # 16 + ("FusedMatMul", False, [(1, 0, 0), (15, 0, 1)]), # 17 + ("FusedMatMul", False, [(6, 0, 0)]), # 18 + ("Transpose", True, [(18, 0, 1)]), # 19 + ("Transpose", False, [(16, 0, 0)]), # 20 + ("Transpose", False, [(17, 0, 0)]), # 21 + ("Transpose", False, [(18, 0, 0)]), # 22 +] + + +def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value = matcher.get_constant_value(nodes[3].input[1]) + ratio_value = matcher.get_constant_value(nodes[6].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[2], "perm", [0, 2, 3, 1]) + and scale_value is not None + and ratio_value is not None + and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[9], "perm", [0, 2, 1, 3]) + ): + return [], [], [] + + _, add_input_shape_0 = matcher.get_type_and_shape(nodes[4].input[0]) + _, add_input_shape_1 = matcher.get_type_and_shape(nodes[4].input[1]) + nodes_to_add, new_value_infos = _make_efficient_attention_nodes( + idx, + nodes[1].input[0], + nodes[2].input[0], + nodes[8].input[0], + nodes[9].output[0], + nodes[19].input[0], + nodes[20].output[0], + nodes[21].output[0], + nodes[22].output[0], + nodes[4].input[1], + add_input_shape_0 != add_input_shape_1, + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + ratio_value, + False, + ) + return nodes, nodes_to_add, new_value_infos + + +# Without causal mask, without Dropout. For example, BERT model and disabling attention dropout in HuggingFace. +_PATTERN_1: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 0)]), # 1 + ("Transpose", True, [(0, 0, 1)]), # 2 + ("Div", False, [(0, 0, 0)]), # 3 + ("Add", False, [(3, 0, 0)]), # 4 + ("Softmax", False, [(4, 0, 0)]), # 5 + ("MatMul", False, [(5, 0, 0)]), # 6 + ("Transpose", True, [(6, 0, 1)]), # 7 + ("Transpose", False, [(6, 0, 0)]), # 8 + ("FusedMatMul", False, [(7, 0, 1)]), # 9 + ("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10 + ("Identity", False, [(10, 0, 0)]), # 11 + ("Div", False, [(11, 0, 0)]), # 12 + ("Identity", False, [(12, 0, 0)]), # 13 + ("FusedMatMul", False, [(2, 0, 1), (13, 0, 0)]), # 14 + ("FusedMatMul", False, [(1, 0, 0), (13, 0, 1)]), # 15 + ("FusedMatMul", False, [(5, 0, 0)]), # 16 + ("Transpose", True, [(16, 0, 1)]), # 17 + ("Transpose", False, [(14, 0, 0)]), # 18 + ("Transpose", False, [(15, 0, 0)]), # 19 + ("Transpose", False, [(16, 0, 0)]), # 20 +] + + +def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value = matcher.get_constant_value(nodes[3].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[2], "perm", [0, 2, 3, 1]) + and scale_value is not None + and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) + ): + return [], [], [] + + _, add_input_shape_0 = matcher.get_type_and_shape(nodes[4].input[0]) + _, add_input_shape_1 = matcher.get_type_and_shape(nodes[4].input[1]) + nodes_to_add, new_value_infos = _make_efficient_attention_nodes( + idx, + nodes[1].input[0], + nodes[2].input[0], + nodes[7].input[0], + nodes[8].output[0], + nodes[17].input[0], + nodes[18].output[0], + nodes[19].output[0], + nodes[20].output[0], + nodes[4].input[1], + add_input_shape_0 != add_input_shape_1, + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + 0.0, + False, + ) + return nodes, nodes_to_add, new_value_infos + + +# No causal mask, no attention mask, without Dropout. +_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Mul", True, [(0, 0, 0)]), # 1 + ("Mul", True, [(0, 0, 1)]), # 2 + ("Cast", True, [(1, 0, 0)]), # 3 + ("Cast", True, [(2, 0, 0)]), # 4 + ("Transpose", True, [(3, 0, 0)]), # 5 + ("Transpose", True, [(4, 0, 0)]), # 6 + ("Softmax", False, [(0, 0, 0)]), # 7 + ("Cast", False, [(7, 0, 0)]), # 8 + ("MatMul", False, [(8, 0, 0)]), # 9 + ("Transpose", True, [(9, 0, 1)]), # 10 + ("Transpose", False, [(9, 0, 0)]), # 11 + ("FusedMatMul", False, [(10, 0, 1)]), # 12 + ("Cast", False, [(12, 0, 0)]), # 13 + ("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]), # 14 + ("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]), # 15 + ("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]), # 16 + ("Mul", False, [(15, 0, 0)]), # 17 + ("Mul", False, [(16, 0, 0)]), # 18 + ("Identity", False, [(17, 0, 0)]), # 19 + ("Identity", False, [(18, 0, 0)]), # 20 + ("Cast", False, [(19, 0, 0)]), # 21 + ("Cast", False, [(20, 0, 0)]), # 22 + ("Transpose", False, [(21, 0, 0)]), # 23 + ("Transpose", False, [(22, 0, 0)]), # 24 + ("FusedMatMul", False, [(8, 0, 0)]), # 25 + ("Transpose", True, [(25, 0, 1)]), # 26 + ("Transpose", False, [(25, 0, 0)]), # 27 +] + + +def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value_1 = matcher.get_constant_value(nodes[1].input[1]) + scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1 + scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) + scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 + if not ( + check_attribute_value(nodes[3], "to", 1) + and check_attribute_value(nodes[4], "to", 1) + and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]) + and check_attribute_value(nodes[8], "to", 10) + and check_attribute_value(nodes[10], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[11], "perm", [0, 2, 1, 3]) + and scale_value_1 == scale_value_2 + ): + return [], [], [] + + nodes_to_add, new_value_infos = _make_efficient_attention_nodes( + idx, + nodes[5].input[0], + nodes[6].input[0], + nodes[10].input[0], + nodes[11].output[0], + nodes[26].input[0], + nodes[23].output[0], + nodes[24].output[0], + nodes[27].output[0], + "", + False, + scale_value_1, + 0.0, + False, + ) + return nodes, nodes_to_add, new_value_infos + + +# Has causal mask, no attention mask, without Dropout. +_PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Mul", True, [(0, 0, 0)]), # 1 + ("Mul", True, [(0, 0, 1)]), # 2 + ("Cast", True, [(1, 0, 0)]), # 3 + ("Cast", True, [(2, 0, 0)]), # 4 + ("Transpose", True, [(3, 0, 0)]), # 5 + ("Transpose", True, [(4, 0, 0)]), # 6 + ("Add", False, [(0, 0, 0)]), # 7 + ("Cast", True, [(7, 0, 1)]), # 8 + ("Slice", True, [(8, 0, 0)]), # 9 + ("Slice", True, [(9, 0, 0)]), # 10 + ("Unsqueeze", True, [(9, 0, 2)]), # 11 + ("Gather", True, [(11, 0, 0)]), # 12 + ("Shape", True, [(12, 0, 0)]), # 13 + ("Softmax", False, [(7, 0, 0)]), # 14 + ("Cast", False, [(14, 0, 0)]), # 15 + ("MatMul", False, [(15, 0, 0)]), # 16 + ("Transpose", True, [(16, 0, 1)]), # 17 + ("Transpose", False, [(16, 0, 0)]), # 18 + ("FusedMatMul", False, [(17, 0, 1)]), # 19 + ("Cast", False, [(19, 0, 0)]), # 20 + ("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]), # 21 + ("Identity", False, [(21, 0, 0)]), # 22 + ("FusedMatMul", False, [(2, 0, 1), (22, 0, 0)]), # 23 + ("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]), # 24 + ("Mul", False, [(23, 0, 0)]), # 25 + ("Mul", False, [(24, 0, 0)]), # 26 + ("Identity", False, [(25, 0, 0)]), # 27 + ("Identity", False, [(26, 0, 0)]), # 28 + ("Cast", False, [(27, 0, 0)]), # 29 + ("Cast", False, [(28, 0, 0)]), # 30 + ("Transpose", False, [(29, 0, 0)]), # 31 + ("Transpose", False, [(30, 0, 0)]), # 32 + ("FusedMatMul", False, [(15, 0, 0)]), # 33 + ("Transpose", True, [(33, 0, 1)]), # 34 + ("Transpose", False, [(33, 0, 0)]), # 35 +] + + +def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value_1 = matcher.get_constant_value(nodes[1].input[1]) + scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1 + scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) + scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 + if not ( + check_attribute_value(nodes[3], "to", 1) + and check_attribute_value(nodes[4], "to", 1) + and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]) + and check_attribute_value(nodes[15], "to", 10) + and check_attribute_value(nodes[17], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[18], "perm", [0, 2, 1, 3]) + and scale_value_1 == scale_value_2 + ): + return [], [], [] + + nodes_to_add, new_value_infos = _make_efficient_attention_nodes( + idx, + nodes[5].input[0], + nodes[6].input[0], + nodes[17].input[0], + nodes[18].output[0], + nodes[34].input[0], + nodes[31].output[0], + nodes[32].output[0], + nodes[35].output[0], + "", + False, + scale_value_1, + 0.0, + True, + ) + return nodes, nodes_to_add, new_value_infos + + +_PATTERNS = [ + (_PATTERN_0, _optimize_for_pattern_0), + (_PATTERN_1, _optimize_for_pattern_1), + (_PATTERN_2, _optimize_for_pattern_2), + (_PATTERN_3, _optimize_for_pattern_3), +] + + +@register_graph_optimizer(devices="cuda") +def optimize_graph_for_aten_efficient_attention(graph: GraphProto): + nodes_to_remove = [] + nodes_to_add = [] + new_value_infos = [] + matcher = GraphMatcher(graph) + idx = 0 + for pattern_tuple in _PATTERNS: + for nodes in matcher.match_pattern(pattern_tuple[0]): + remove_nodes, add_nodes, add_value_infos = pattern_tuple[1](matcher, idx, nodes) + if len(add_nodes) > 0: + nodes_to_remove.extend(remove_nodes) + nodes_to_add.extend(add_nodes) + new_value_infos.extend(add_value_infos) + idx += 1 + update_graph(graph, nodes_to_remove, nodes_to_add, new_value_infos) diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py new file mode 100644 index 000000000000..e6e5ce56773e --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py @@ -0,0 +1,178 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import itertools +from typing import Any, Dict, List, Sequence, Tuple + +import numpy as np +from onnx import GraphProto, NodeProto, TensorProto, helper, numpy_helper + + +def _get_attribute(node: NodeProto, attr_name: str, default_value: Any = None) -> Any: + """Get attribute value from node by attribute key.""" + found = [attr for attr in node.attribute if attr.name == attr_name] + if found: + return helper.get_attribute_value(found[0]) + return default_value + + +def _to_numpy_array(node: Any) -> np.ndarray: + """Convert Constant node or TensorProto to Python value.""" + tensor = node + if isinstance(node, NodeProto): + tensor = _get_attribute(node, "value") + assert isinstance(tensor, TensorProto) + return numpy_helper.to_array(tensor).tolist() + + +class GraphMatcher: + """Sub-graph matcher with given pattern. + + GraphMatcher takes an ONNX graph to initialize. It tries to match sub-graphs to a given pattern and yield + matched sub-graphs (a list of matched nodes for each sub-graph) one by one. + + Pattern is described by a list. Each entry of the list is a Tuple: + + Tuple[str, bool, List[Tuple[int, int, int]]], e.g., ("FusedMatMul", False, [(2, 0, 1), (15, 0, 0)]) + + * First string is the Op type, e.g., "FusedMatMul". + * Second bool indicates it's producer node or consumer node for source node. + * There is a list to describe the edge infos of this node to other nodes, each edge is a tuple with 3 integers, + first integer is the index of the target node in the list, second integer is the output index of the edge, + and thrid integer is the input index of the edge. + + For each entry, GraphMatcher used the first edge to lookup target node, and try to use make sure the sug-graph also + matches rest edge infos. + + Note that when lookup target node, it will only take the first matched node as target node. For example, if a source + node has multiple "MatMul" consumers nodes comsuming same output, only the first "MatMul" node will be returned. + You need to avoid using such confusing edge info as the first edge info for node lookup. Try to use other edge to + avoid such confusion if possible. + """ + + def __init__(self, graph: GraphProto): + self._graph: GraphProto = graph + self._op_type_to_nodes: Dict[str, List[NodeProto]] = {} + self._consumer_count: Dict[str, int] = {} + for node in graph.node: + if node.op_type not in self._op_type_to_nodes: + self._op_type_to_nodes[node.op_type] = [] + self._op_type_to_nodes[node.op_type].append(node) + for input in node.input: + self._consumer_count[input] = self._consumer_count.get(input, 0) + 1 + + def _get_producer(self, arg: str, op_type: str, output_idx: int): + for node in self._op_type_to_nodes.get(op_type, []): + if (output_idx >= 0 and len(node.output) > output_idx and node.output[output_idx] == arg) or ( + output_idx == -1 and arg in node.output + ): + return node + return None + + def _get_consumer(self, arg: str, op_type: str, input_idx: int): + for node in self._op_type_to_nodes.get(op_type, []): + if (input_idx >= 0 and len(node.input) > input_idx and node.input[input_idx] == arg) or ( + input_idx == -1 and arg in node.input + ): + return node + return None + + def get_consumer_count(self, arg: str): + return self._consumer_count.get(arg, 0) + + def get_constant_value(self, arg: str): + node_or_initializer = None + if "Constant" in self._op_type_to_nodes: + for node in self._op_type_to_nodes["Constant"]: + if arg in node.output: + node_or_initializer = node + break + if node_or_initializer is None: + for initializer in self._graph.initializer: + if arg == initializer.name: + node_or_initializer = initializer + break + if node_or_initializer is None: + return None + return _to_numpy_array(node_or_initializer) + + def get_type_and_shape(self, arg: str): + value_infos = [ + value_info + for value_info in itertools.chain(self._graph.input, self._graph.value_info) + if value_info.name == arg + ] + if len(value_infos) > 0 and value_infos[0].type.tensor_type.HasField("shape"): + shape = [] + for dim in value_infos[0].type.tensor_type.shape.dim: + if dim.dim_param: + shape.append(dim.dim_param) + else: + shape.append(dim.dim_value) + return value_infos[0].type.tensor_type.elem_type, shape + initializers = [initializer for initializer in self._graph.initializer if initializer.name == arg] + if len(initializers) > 0: + return initializers[0].data_type, initializers[0].dims + return None, None + + def _match_pattern(self, node: NodeProto, pattern: List[Tuple[str, bool, List[Tuple[int, int, int]]]]): + nodes = [node] + for i in range(1, len(pattern)): + next_op_type = pattern[i][0] + is_producer = pattern[i][1] + node_idx, output_idx, input_idx = pattern[i][2][0] + next_node = ( + self._get_producer(nodes[node_idx].input[input_idx], next_op_type, output_idx) + if is_producer + else self._get_consumer(nodes[node_idx].output[output_idx], next_op_type, input_idx) + ) + if next_node is None: + return [] + for j in range(1, len(pattern[i][2])): + node_idx, output_idx, input_idx = pattern[i][2][j] + assert output_idx >= 0 and input_idx >= 0 + if (not is_producer and nodes[node_idx].output[output_idx] != next_node.input[input_idx]) or ( + is_producer and next_node.output[output_idx] != nodes[node_idx].input[input_idx] + ): + return [] + nodes.append(next_node) + return nodes + + def match_pattern(self, pattern: List[Tuple[str, bool, List[Tuple[int, int, int]]]]): + for node in self._op_type_to_nodes.get(pattern[0][0], []): + result = self._match_pattern(node, pattern) + if len(result) == len(pattern): + yield result + + +def check_attribute_value(node: NodeProto, attr_name: str, expected_value: Any): + """Check if the attribute of given node has expected value.""" + value = _get_attribute(node, attr_name) + return value == expected_value + + +def make_constant_node(name: str, dtype: TensorProto.DataType, dims: Sequence[int], vals: Any): + """Create a constant node with given constant tensor (data type, shape, and data).""" + return helper.make_node( + "Constant", + inputs=[], + outputs=[name], + value=helper.make_tensor(name=name, data_type=dtype, dims=dims, vals=vals), + ) + + +def update_graph( + graph: GraphProto, + nodes_to_remove: List[NodeProto], + nodes_to_add: List[NodeProto], + new_value_infos: List[TensorProto] = [], # noqa: B006 +): + """Update an ONNX graph by removing some nodes, and adding some new nodes and value infos.""" + nodes = [node for node in graph.node if node not in nodes_to_remove] + nodes.extend(nodes_to_add) + graph.ClearField("node") + graph.node.extend(nodes) + if len(new_value_infos) > 0: + graph.value_info.extend(new_value_infos) diff --git a/orttraining/orttraining/python/training/ortmodule/graph_transformer_registry.py b/orttraining/orttraining/python/training/ortmodule/graph_transformer_registry.py deleted file mode 100644 index 70056179c140..000000000000 --- a/orttraining/orttraining/python/training/ortmodule/graph_transformer_registry.py +++ /dev/null @@ -1,47 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -from typing import Callable - -from onnx.onnx_ml_pb2 import GraphProto - - -class GraphTransformerRegistry: - _TRANSFORMER_FUNCS = {} # noqa: RUF012 - - @classmethod - def register(cls, target_modules: str, devices: str, priority: int, fn: Callable[[GraphProto], None]): - modules = [] - if target_modules == "all": - modules.append("all") - else: - modules = target_modules.split("|") - for module in modules: - if module in cls._TRANSFORMER_FUNCS: - cls._TRANSFORMER_FUNCS[module].append((fn, devices, priority)) - else: - cls._TRANSFORMER_FUNCS[module] = [(fn, devices, priority)] - - @classmethod - def transform_all(cls, module_name: str, device: str, graph: GraphProto): - transformers_to_apply = [] - if "all" in cls._TRANSFORMER_FUNCS: - transformers_to_apply.extend(cls._TRANSFORMER_FUNCS["all"]) - if module_name in cls._TRANSFORMER_FUNCS: - transformers_to_apply.extend(cls._TRANSFORMER_FUNCS[module_name]) - transformers_to_apply = [x for x in transformers_to_apply if x[1] == "all" or device in x[1]] - transformers_to_apply.sort(key=lambda x: x[2], reverse=True) - for fn, _, _ in transformers_to_apply: - fn(graph) - - -# target_modules can be multiple module names separated by "|", or "all" means apply to all modules. -# devices can be multiple device types separated by "|" or "all" means apply to all devices. -def register_graph_transformer(target_modules: str = "all", devices: str = "all", priority: int = 0): - def graph_transformer_wrapper(fn): - GraphTransformerRegistry.register(target_modules, devices, priority, fn) - return fn - - return graph_transformer_wrapper diff --git a/orttraining/orttraining/training_ops/cpu/triton/triton_op.cc b/orttraining/orttraining/training_ops/cpu/triton/triton_op.cc index 28f4ff665f79..c230a0c9a3b1 100644 --- a/orttraining/orttraining/training_ops/cpu/triton/triton_op.cc +++ b/orttraining/orttraining/training_ops/cpu/triton/triton_op.cc @@ -17,8 +17,8 @@ InlinedHashSet TritonOp::GetBoolOutputs(size_t output_size) const { InlinedHashSet bool_outputs; for (size_t i = 0; i < output_size; ++i) { ORT_ENFORCE(i < Node().OutputDefs().size(), "Output index out of range."); - if (Node().OutputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == - ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) { + if (Node().OutputDefs()[i]->Exists() && Node().OutputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) { bool_outputs.insert(i); } } @@ -37,13 +37,15 @@ Status TritonOp::Compute(OpKernelContext* context) const { InlinedHashSet bool_outputs = GetBoolOutputs(output_size); auto& executor = training::framework::triton::TritonOpExecutor::Instance(); if (func_name_ != "") { - executor.ExecuteByFuncName(func_name_, inputs, outputs, bool_outputs); + executor.ExecuteByFuncName(func_name_, inputs, outputs, bool_outputs, kwargs_); } else { executor.ExecuteByOnnx(onnx_key_, onnx_string_, inputs, outputs, bool_outputs); } ORT_ENFORCE(output_size == outputs.size()); for (size_t i = 0; i < output_size; ++i) { - ORT_THROW_IF_ERROR(p_ctx_internal->SetOutputMLValue(static_cast(i), outputs[i])); + if (Node().OutputDefs()[i]->Exists()) { + ORT_THROW_IF_ERROR(p_ctx_internal->SetOutputMLValue(static_cast(i), outputs[i])); + } } return Status::OK(); } diff --git a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h index 25e7b1f15ff6..f226db76f7ed 100644 --- a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h +++ b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h @@ -5,6 +5,8 @@ #pragma once +#include "core/common/inlined_containers.h" + #ifndef SHARED_PROVIDER #include "core/framework/op_kernel.h" #endif @@ -18,6 +20,19 @@ class TritonOp final : public OpKernel { ORT_THROW_IF_ERROR(info.GetAttr("func_name", &func_name_)); ORT_THROW_IF_ERROR(info.GetAttr("onnx_key", &onnx_key_)); ORT_THROW_IF_ERROR(info.GetAttr("onnx_string", &onnx_string_)); + for (const auto& attr : info.node().GetAttributes()) { + if (attr.first.rfind("_", 0) == 0 || attr.first == "func_name" || attr.first == "onnx_key" || + attr.first == "onnx_string") { + continue; + } + // Support int64 and float only for now, skip other types. + if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT) { + kwargs_.insert({attr.first, {std::to_string(attr.second.i()), ONNX_NAMESPACE::TensorProto_DataType_INT64}}); + } else if (attr.second.type() == + ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_FLOAT) { + kwargs_.insert({attr.first, {std::to_string(attr.second.f()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT}}); + } + } } Status Compute(OpKernelContext* context) const override; @@ -28,6 +43,7 @@ class TritonOp final : public OpKernel { std::string func_name_; int64_t onnx_key_; std::string onnx_string_; + InlinedHashMap> kwargs_; }; bool IsTritonOpExecutorInitialized(); diff --git a/pyproject.toml b/pyproject.toml index 89011a7944ab..97515cb9fa62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,3 +92,4 @@ unfixable = [ "tools/nuget/generate_nuspec_for_native_nuget.py" = ["ISC003"] # Too many errors to fix "onnxruntime/test/python/quantization/test_op_gemm.py" = ["N806"] # use of A for a matrix "onnxruntime/test/python/quantization/op_test_utils.py" = ["N806", "PERF203", "RUF012"] # use of A for a matrix +"orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py" = ["N806", "PLW2901", "ISC001", "E731"] # Long triton code from other repo. diff --git a/setup.py b/setup.py index b71836e0ee6e..f6308c56d059 100644 --- a/setup.py +++ b/setup.py @@ -466,6 +466,7 @@ def finalize_options(self): "onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils", "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator", "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops", + "onnxruntime.training.ortmodule.graph_optimizers", "onnxruntime.training.ort_triton", "onnxruntime.training.ort_triton.kernel", "onnxruntime.training.utils", From 0f3a067d3a7d42bbc860b97d7ed39e5d7cdd5d47 Mon Sep 17 00:00:00 2001 From: mindest <30493312+mindest@users.noreply.github.com> Date: Fri, 27 Oct 2023 11:29:55 +0800 Subject: [PATCH 23/36] [FIX] reorder initializer (#18097) ### Description Fix building error when with collective ops: error is thrown because `device_mesh_axis` will be initialized after `cond`. --- onnxruntime/contrib_ops/cuda/collective/sharding_spec.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h index 451d44b4bd43..6bdf5699c268 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h @@ -117,7 +117,7 @@ class AxisPartitionSpec { // A normal ctor. // TODO(wechi): Consider to hide it and revise the `public` members/functions // exposed to the user. - AxisPartitionSpec(Condition cond_, int device_mesh_axis_) : device_mesh_axis(device_mesh_axis_), cond(cond_) {} + AxisPartitionSpec(Condition cond_, int device_mesh_axis_) : cond(cond_), device_mesh_axis(device_mesh_axis_) {} // Helper to debug and generate error message; e.g., // "RS[0]". From b79ea7481930ae09980f61982a8b7b19303a0eca Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Thu, 26 Oct 2023 21:54:23 -0700 Subject: [PATCH 24/36] Add updates to LLaMA scripts (#18076) ### Description This PR adds a few updates to scripts in the LLaMA folder: - Fixes the precision re-naming in the LLaMA export - Adds a "prerequisites" section in the README - Adds IO binding synchronizations during benchmarking for other EPs ### Motivation and Context - With precision re-naming, the LLaMA parity check does not produce errors when creating the FP32 CPU model - The "prerequisites" section shows that there are specific package versions needed - This allows for benchmarking with other EPs besides CPU and CUDA --- .../tools/transformers/convert_generation.py | 2 +- .../tools/transformers/models/llama/README.md | 48 +++++++++---------- .../transformers/models/llama/benchmark.py | 36 ++++++++++---- .../models/llama/convert_to_onnx.py | 3 +- .../transformers/models/llama/llama_parity.py | 4 +- .../models/llama/requirements-cpu.txt | 2 +- .../models/llama/requirements-cuda.txt | 2 +- 7 files changed, 59 insertions(+), 38 deletions(-) diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 4228c892d03a..b32ae64c5b0c 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1275,7 +1275,7 @@ def find_past_seq_len_usage(subg: GraphProto): def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0): past_seq_len = past_seq_len_input if past_seq_len not in model.get_graphs_input_names(): - # Replace model input for past sequence length + # Add model input for past sequence length new_input = onnx.helper.make_tensor_value_info(past_seq_len, onnx.TensorProto.INT64, shape=[1]) model.model.graph.input.append(new_input) diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 6057b46667fe..9619e6cb52a9 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -1,5 +1,18 @@ # LLaMA-2 +## Prerequisites + +Please note the package versions needed for using LLaMA-2 in the `requirements.txt` file that fits your scenario. +- `requirements-cpu.txt` + - For running LLaMA-2 on CPU +- `requirements-cuda.txt` + - For running LLaMA-2 on CUDA + - Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. +- `requirements-quant.txt` + - For running the SmoothQuant algorithm using [Intel's Neural Compressor](https://github.com/intel/neural-compressor) +- `requirements.txt` + - Package versions needed in each of the above files + ## Exporting LLaMA-2 There are several ways to export LLaMA-2 models (using LLaMA-2 7B as an example). @@ -40,7 +53,7 @@ Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onn ### Option 3: from [Hugging Face Optimum](https://github.com/huggingface/optimum) -Note that this will produce two ONNX models whereas the above two options produce one ONNX model. +Note that this may produce two ONNX models with older Optimum versions. The above two options produce one ONNX model and installing Optimum from source will now produce one ONNX model. First, log into the Hugging Face CLI in your terminal: @@ -81,7 +94,7 @@ Export for FP32 CUDA $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32 --precision fp32 --execution_provider cuda +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda ``` Export for FP32 CPU @@ -90,7 +103,7 @@ Export for FP32 CPU $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32 --precision fp32 --execution_provider cpu +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu ``` Export for FP16 CUDA @@ -105,10 +118,10 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama Export for INT8 CPU (SmoothQuant) ``` # From source: -$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu --no_merged # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu --no_merged ``` Note: [Intel's Neural Compressor](https://github.com/intel/neural-compressor) takes time to run the SmoothQuant quantization algorithm on LLMs. On an [Azure Standard_NC24s_v3 VM](https://learn.microsoft.com/en-us/azure/virtual-machines/ncv3-series), it takes about ~30-45 min for each of the exported ONNX models. @@ -128,7 +141,7 @@ Export for INT4 CUDA $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4 --precision int4 --quantization_method blockwise --execution_provider cuda +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda ``` Export for INT4 CPU @@ -137,7 +150,7 @@ Export for INT4 CPU $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4 --precision int4 --quantization_method blockwise --execution_provider cpu +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu ``` ## Benchmark LLaMA-2 @@ -183,20 +196,7 @@ python3 -m models.llama.benchmark \ --auth ``` -4. Optimum + ONNX Runtime, FP16, export via convert_to_onnx -``` -python3 -m models.llama.benchmark \ - --benchmark-type hf-ort \ - --hf-ort-dir-path ./llama2-7b-fp16/ \ - --model-name meta-llama/Llama-2-7b-hf \ - --precision fp16 \ - --batch-sizes "1 2" \ - --sequence-lengths "8 16" \ - --device cuda \ - --auth -``` - -5. ONNX Runtime, FP32, Microsoft custom export +4. ONNX Runtime, FP32, Microsoft custom export ``` python3 -m models.llama.benchmark \ --benchmark-type ort-msft \ @@ -208,7 +208,7 @@ python3 -m models.llama.benchmark \ --device cpu ``` -6. ONNX Runtime, FP16, Microsoft custom export +5. ONNX Runtime, FP16, Microsoft custom export ``` python3 -m models.llama.benchmark \ --benchmark-type ort-msft \ @@ -220,7 +220,7 @@ python3 -m models.llama.benchmark \ --device cuda ``` -7. ONNX Runtime, FP32, convert_to_onnx +6. ONNX Runtime, FP32, convert_to_onnx ``` python3 -m models.llama.benchmark \ --benchmark-type ort-convert-to-onnx \ @@ -232,7 +232,7 @@ python3 -m models.llama.benchmark \ --device cpu ``` -8. ONNX Runtime, FP16, convert_to_onnx +7. ONNX Runtime, FP16, convert_to_onnx ``` python3 -m models.llama.benchmark \ --benchmark-type ort-convert-to-onnx \ diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index 976de2abc7c5..a721979eb0bc 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -286,31 +286,50 @@ def time_fn(args, fn, inputs): outputs = fn(inputs) logger.info(outputs) + input_sync = ( # noqa: E731 + lambda *kwargs: args.io_binding.synchronize_inputs() + if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize + else lambda *kwargs: torch.cuda.synchronize() + if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize + else lambda *kwargs: None # no-op function + ) + + output_sync = ( # noqa: E731 + lambda *kwargs: args.io_binding.synchronize_outputs() + if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize + else lambda *kwargs: torch.cuda.synchronize() + if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize + else lambda *kwargs: None # no-op function + ) + for _ in warmup_range: + input_sync() fn(inputs) + output_sync() # Benchmark - if args.device != "cpu": - torch.cuda.synchronize() - start_time = time.time() - + total_time = 0 bench_range = ( range(args.num_runs) if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} else trange(args.num_runs, file=sys.stdout, desc="Benchmark") ) for _ in bench_range: + input_sync() + start_time = time.time() + fn(inputs) - if args.device != "cpu": - torch.cuda.synchronize() - end_time = time.time() + output_sync() + end_time = time.time() + + total_time += end_time - start_time # Newline print after trange in order to print metrics on new lines without progress bar on same line if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}: logger.info("") - latency = (end_time - start_time) / args.num_runs + latency = total_time / args.num_runs throughput = args.batch_size / latency logger.info(f"Batch Size: {args.batch_size}") @@ -467,6 +486,7 @@ def prepare_ort_inputs(inputs): else: io_binding.bind_output(name, device_type=args.device, device_id=args.device_id) + setattr(args, "io_binding", io_binding) # noqa: B010 return io_binding return inputs diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 61d71bc38f4e..69603fd3ed48 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -817,7 +817,8 @@ def main(): # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models args.precision = ( "fp32" - if args.precision in {"int8", "fp32"} or (args.precision == Precision.INT4 and args.execution_provider == "cpu") + if args.precision in {Precision.INT8, Precision.FLOAT32} + or (args.precision == Precision.INT4 and args.execution_provider == "cpu") else "fp16" ) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 6bfcb9b4f290..4353d0606803 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -113,10 +113,10 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama if args.execution_provider != "cpu": io_binding = add_io_bindings(args, ort_model, inputs) - torch.cuda.synchronize() + io_binding.synchronize_inputs() start_time = time.time() ort_model.run_with_iobinding(io_binding) - torch.cuda.synchronize() + io_binding.synchronize_outputs() end_time = time.time() ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt index e06c3ada834b..3d707fa13e3c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt @@ -1,2 +1,2 @@ -r requirements.txt -onnxruntime>=1.17.0 \ No newline at end of file +onnxruntime>=1.16.2 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt index 773680937bd2..b634bcc50f6e 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt @@ -1,4 +1,4 @@ -r requirements.txt # Please manually install torch>=2.2.0.dev20230920 with CUDA enabled for the CUDA version installed in your system. # Instructions can be found here: https://pytorch.org/get-started/locally/ -onnxruntime-gpu>=1.17.0 \ No newline at end of file +onnxruntime-gpu>=1.16.2 \ No newline at end of file From 9c323106735535b6dab6b476648faac0ad185e21 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 26 Oct 2023 22:33:42 -0700 Subject: [PATCH 25/36] Distributed Reshape Implementation (#18068) This DistributedReshape aims at supporting all sharding patterns encountered in llama 2. All patterns found are tested in `TestDistributedReshape` in `onnxruntime_test_distributed.py`. This PR implements algorithms to compute the categories below. - All inputs and outputs are replica, so it's computed like a normal Reshape. - Two-axis fusion (if any of the inputs and outputs are sharded). This category convers, e.g., `[batch, seq, hidden] -> [batch x seq, hidden]`. - Two-axis decomposition (if any of the inputs and outputs are sharded). This category convers, e.g., `[batch x seq, hidden] -> [batch, seq, hidden]`. Review guideline: - Ignore the changes in sharding_spec.h and sharding_spec.cc since they come from another PR #18025. - First, read onnxruntime_test_distributed.py to get familiar with the input/output of DistributedReshape. - Second, check the new APIs in reshape.h/reshape.cc to expose CUDA Reshape kernel to DistributedReshape. - For DistributedReshape, check its `ComputeInternal` for the 3 categories mentioned above. --- cmake/onnxruntime_providers_cuda.cmake | 3 +- cmake/onnxruntime_rocm_hipify.cmake | 2 + .../cuda/collective/distributed_reshape.cc | 861 ++++++++++++++++++ .../cuda/collective/distributed_reshape.h | 40 + .../contrib_ops/cuda/collective/sharding.cc | 8 +- .../cuda/collective/sharding_spec.cc | 14 +- .../cuda/collective/sharding_spec.h | 108 ++- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 8 + .../core/graph/contrib_ops/collective_defs.cc | 45 + onnxruntime/core/providers/cuda/cuda_kernel.h | 10 +- .../core/providers/cuda/tensor/reshape.cc | 75 ++ .../core/providers/cuda/tensor/reshape.h | 59 +- onnxruntime/core/providers/rocm/rocm_kernel.h | 10 +- .../python/onnxruntime_test_distributed.py | 667 ++++++++++++++ 14 files changed, 1870 insertions(+), 40 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 003012f8da07..02b17ee324f4 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -38,6 +38,7 @@ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio @@ -246,4 +247,4 @@ install(TARGETS onnxruntime_providers_cuda ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) \ No newline at end of file + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index de1458c12001..4ef0584b0273 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -103,6 +103,8 @@ if (NOT onnxruntime_USE_NCCL) list(APPEND contrib_ops_excluded_files "collective/sharding.cc") list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") endif() set(provider_excluded_files diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc new file mode 100644 index 000000000000..a0ac40defbee --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc @@ -0,0 +1,861 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_reshape.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/cuda_check_memory.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +// Return true if src_shape[src_begin:src_end] is the same as +// dst_shape[dst_begin:dst_end]. Otherwise, return false. +// TODO: replace std::vector with gsl::span. +bool CompareSubVectors( + const std::vector& src_shape, + const std::vector& dst_shape, + size_t src_begin, size_t src_end, + size_t dst_begin, size_t dst_end) { + if (src_end - src_begin != dst_end - dst_begin) { + // Sub-vectors have different lengths. + return false; + } + for (size_t src_index = src_begin, dst_index = dst_begin; + src_index < src_end && dst_index < dst_end; + ++src_index, ++dst_index) { + if (src_shape[src_index] != dst_shape[dst_index]) { + // Sub-vectors have different elements. + return false; + } + } + // Sub-vectors have same length and same elements. + return true; +} + +// TODO: replace std::vector with gsl::span. +std::tuple IsTwoAxisFusion( + const std::vector& src_shape, + const std::vector& dst_shape) { + // Return values: + // - bool: whether two consecutive axes are fused. + // - size_t: the axis in destination shape formed by fusing two source axes. + // - size_t: the first axis fused. + // - size_t: the length of fusion. In two-axis fusion considered by this + // function, the length of fusion is always 2. + const size_t src_rank = src_shape.size(); + const size_t dst_rank = dst_shape.size(); + if (src_rank < 2 || dst_rank < 1) { + return std::make_tuple(false, -1, -1, -1); + } + if (src_rank - 1 != dst_rank) { + return std::make_tuple(false, -1, -1, -1); + } + for (size_t i_src = 0; i_src < src_rank; ++i_src) { + if (i_src + 1 > src_rank - 1) { + // We are at src_shape[i] and we need + // src_shape[i + 1] to fuse. + // If we are at the last axis, we cannot fuse. + break; + } + const int64_t prod = src_shape[i_src] * src_shape[i_src + 1]; + + for (size_t i_dst = 0; i_dst < dst_rank; ++i_dst) { + // Check if shape[i_src:i_src+2] (i.e., shape[i_src] and shape[i_src+1]) + // for source tensor are fused into shape[i_dst] for destination tensor. + if (prod != dst_shape[i_dst]) { + continue; + } + // Check if corresponding dimensions before fusion area + // are the same. + const bool prefix_shape_match = CompareSubVectors( + src_shape, + dst_shape, + // Represent src_shape[0:i_src]. + 0, i_src, + // Represent dst_shape[0:i_dst]. + 0, i_dst); + const bool suffix_shape_match = CompareSubVectors( + src_shape, + dst_shape, + // Represent src_shape[i_src+2:]. + i_src + 2, src_rank, + // Represent dst_shape[i_dst+1:]. + i_dst + 1, dst_rank); + if (prefix_shape_match && suffix_shape_match) { + return std::make_tuple( + true, i_dst, i_src, 2); + } + } + } + return std::make_tuple(false, 0, 0, 0); +} + +std::tuple IsTwoAxisDecomposition( + const std::vector& src_shape, + const std::vector& dst_shape) { + // Return values: + // - bool: whether one source axis is decomposed into two consecutive destination axes. + // - size_t: the axis in source shape decomposed into two consecutive destination axes. + // - size_t: the first axis the source axis decomposed into. + // - size_t: the number of decomposed axes. It's always 2 in this function. + return IsTwoAxisFusion(dst_shape, src_shape); +} + +std::vector RepeatVector(const std::vector& vec, int64_t repeat) { + std::vector new_vec; + for (int64_t i = 0; i < repeat; ++i) { + new_vec.insert(new_vec.end(), vec.begin(), vec.end()); + } + return new_vec; +} + +DeviceMesh CreateInterleaveDeviceMesh( + const DeviceMesh& source_mesh, const int64_t repeat) { + // Given a 1-D device mesh [0, 1] and repeat=2, + // return 1-D device mesh [0, 1, 0, 1]. + if (source_mesh.device_mesh_shape.size() != 1) { + throw std::runtime_error("Source mesh shape 1-D."); + } + + // Mesh to return. + DeviceMesh new_mesh; + + std::vector& elements = new_mesh.device_mesh_elements; + for (int64_t i = 0; i < repeat; ++i) { + elements.insert( + elements.end(), + source_mesh.device_mesh_elements.begin(), + source_mesh.device_mesh_elements.end()); + } + + // source mesh must be 1-D so we only care its 1st dimension. + new_mesh.device_mesh_shape.push_back(source_mesh.device_mesh_shape[0] * repeat); + + return new_mesh; +} + +std::tuple ComputeNativeSpecForTwoAxisFusion( + const TensorPartitionSpec& src_spec, + const std::vector& src_shape, + const std::vector& dst_shape, + const int64_t fused_axis_in_src, + const int64_t fusion_axis_in_dst) { + // TODO(wechi): use device mesh stride to support non-1 stride. + // Example: S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1] + // Example: RS[0], shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1, 0, 1] + // Example: S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape = [2, 15], device_mesh=[0, 1] + ORT_ENFORCE(src_spec.CountShardingAxes() == 1, "Tensor to be reshaped has too many sharding axes."); + ORT_ENFORCE(src_spec.device_mesh.device_mesh_shape.size() == 1, "Source device mesh be 1-D."); + + if (src_spec.HasNoShard()) { + return std::make_tuple(true, TensorPartitionSpec::CreateAllReplica(dst_shape.size(), src_spec.device_mesh)); + } else if (src_spec.HasShard() && src_spec.OnlyShardAxis(fused_axis_in_src)) { + // Example: S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1] + // Example 1: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 0, 0, 0, 0, 0, 0], (device assignment) + // [1, 1, 1, 1, 1, 1, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[ 0, 1, 2, 3, 4, 5, 6, 7]] + // - Device 1's local tensor (shape: [2, 4]). + // [[ 8, 9, 10, 11, 12, 13, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from shape [2, 4] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 4, 5, 6, 7] + // - Device 1's local output tensor. + // [ 8, 9, 10, 11, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1] + // + // Example 2: + // - logical input shape: [8, 2] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0], (device assignment) + // [0, 0], + // [0, 0], + // [0, 0], + // [1, 1], + // [1, 1], + // [1, 1], + // [1, 1]] + // [[ 0, 1], (values) + // [ 2, 3], + // [ 4, 5], + // [ 6, 7], + // [ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // - Device 0's local tensor (shape: [4, 2]). + // [[ 0, 1], + // [ 2, 3], + // [ 4, 5], + // [ 6, 7]] + // - Device 1's local tensor (shape: [4, 2]). + // [[ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [4, 2]. + // 3. Run local reshape (reshape from shape [4, 2] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 4, 5, 6, 7] + // - Device 1's local output tensor. + // [ 8, 9, 10, 11, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1] + // + // Example 3: + // - logical input shape: [8, 2] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0], (device assignment) + // [0, 0], + // [1, 1], + // [1, 1], + // [0, 0], + // [0, 0], + // [1, 1], + // [1, 1]] + // [[ 0, 1], (values) + // [ 2, 3], + // [ 4, 5], + // [ 6, 7], + // [ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // - Device 0's local tensor (shape: [4, 2]). + // [[ 0, 1], + // [ 2, 3], + // [ 8, 9], + // [10, 11]] + // - Device 1's local tensor (shape: [4, 2]). + // [[ 4, 5], + // [ 6, 7], + // [12, 13], + // [14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [4, 2]. + // 3. Run local reshape (reshape from shape [4, 2] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 8, 9, 10, 11] + // - Device 1's local output tensor. + // [ 4, 5, 6, 7, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1, 0, 1] + + // Reuse original device mesh but shard the fusion axis in output tensor. + auto dst_spec = TensorPartitionSpec::CreateOneTensorAxisOneDeviceMeshAxisSharding( + dst_shape.size(), src_spec.device_mesh, fusion_axis_in_dst, /* 1-D mesh */ 0); + return std::make_tuple(true, dst_spec); + } else if (src_spec.HasShard() && src_spec.OnlyShardAxis(fused_axis_in_src + 1)) { + // Example 1 of determining native output sharding spec: + // - logical input shape: [3, 4] + // - logical output shape: [12] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 1, 0, 1], (device assignment) + // [0, 1, 0, 1], + // [0, 1, 0, 1]] + // [[0, 1, 2, 3], (values) + // [4, 5, 6, 7], + // [8, 9, 10, 11]], + // - Device 0's local tensor. + // [[0, 0], + // [0, 0], + // [0, 0]] + // [[0, 2], + // [4, 6], + // [8, 10]], + // - Device 1's local tensor. + // [[1, 1], + // [1, 1], + // [1, 1]] + // [[1, 3], + // [5, 7], + // [9, 11]], + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [6] by fusing both axes in shape [3, 2]. + // 3. Run local reshape (reshape from [3, 2] to [6]): + // - Device 0's local output tensor. + // [0, 0, 0, 0, 0, 0] + // [0, 2, 4, 6, 8, 10] + // - Device 1's local output tensor. + // [1, 1, 1, 1, 1, 1] + // [1, 3, 5, 7, 9, 11] + // 4. Determine native output sharding spec by comparing local output tensors and logical tensor. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] = [0, 1, 0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 3 + // + // Example 2 of determining native output sharding spec: + // - logical input shape: [3, 8] + // - logical output shape: [24] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 1, 1, 0, 0, 1, 1], (device assignment) + // [0, 0, 1, 1, 0, 0, 1, 1], + // [0, 0, 1, 1, 0, 0, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15], + // [16, 17, 18, 19, 20, 21, 22, 23]] + // - Device 0's local tensor (shape: [3, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 4, 5], + // [ 8, 9, 12, 13], + // [16, 17, 20, 21]] + // - Device 1's local tensor (shape: [3, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 2, 3, 6, 7], + // [10, 11, 14, 15], + // [18, 19, 22, 23]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [12] by fusing both axes in shape [3, 4]. + // 3. Run local reshape (reshape from [3, 4] to [12]): + // - Device 0's local output tensor . + // [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21] + // - Device 1's local output tensor . + // [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + // - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] = . + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 3 + // + // Example 3: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 1, 1, 0, 0, 1, 1], (device assignment) + // [0, 0, 1, 1, 0, 0, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 4, 5], + // [ 8, 9, 12, 13]] + // - Device 1's local tensor (shape: [2, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 2, 3, 6, 7], + // [10, 11, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from [2, 4] to [8]): + // - Device 0's local output tensor . + // [ 0, 1, 4, 5, 8, 9, 12, 13] + // - Device 1's local output tensor . + // [ 2, 3, 6, 7, 10, 11, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1] = [0, 1, 0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 2 + // + // Example 4: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: RS[0], device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 0, 0, 1, 1, 1, 1], (device assignment) + // [0, 0, 0, 0, 1, 1, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 2, 3], + // [ 8, 9, 10, 11]] + // - Device 1's local tensor (shape: [2, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 4, 5, 6, 7], + // [12, 13, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from [2, 4] to [8]): + // - Device 0's local output tensor . + // [ 0, 1, 2, 3, 8, 9, 10, 11] + // - Device 1's local output tensor . + // [ 4, 5, 6, 7, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1] = [0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1] * (first fused dimension) = [0, 1] * 2 = [0, 1, 0, 1] + + // The output device mesh is the repeats of the original device. + // Let's use Python syntax. If the original device mesh is [0, 1, 0, 1], and + // the first fused dimension is 3, then the output device mesh is [0, 1, 0, 1] * 3. + auto dst_device_mesh = DeviceMesh::Create1D( + src_spec.device_mesh.device_mesh_elements, + src_shape[fused_axis_in_src]); + // Sharding happens in the fusion axis with the new device mesh. + auto dst_spec = TensorPartitionSpec::CreateOneTensorAxisOneDeviceMeshAxisSharding( + dst_shape.size(), dst_device_mesh, fusion_axis_in_dst, /* 1-D mesh */ 0); + return std::make_tuple(true, dst_spec); + } else if (src_spec.HasShard() && (src_spec.GetPartitionAxis() < fused_axis_in_src || src_spec.GetPartitionAxis() > fused_axis_in_src + 1)) { + // It's two-axis fusion but the fused axes is not sharded. + // Example: S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape = [2, 15], device_mesh=[0, 1] + auto dst_spec = TensorPartitionSpec::CreateByDropOneAxis( + src_spec, fused_axis_in_src + 1); + return std::make_tuple(true, dst_spec); + } else { + return std::make_tuple(false, TensorPartitionSpec()); + } +} + +// Arguments: +// - device_elements: a vector of device IDs. +// It should only contain unique device IDs or +// repeats of a list of unique device IDs. Otherwise, +// (0, 0) is returned. +// Returns: +// - count per device ID (all device IDs should have the same count) +// - number of unique device IDs +// Examples: +// - [0, 1] -> (2, 1) +// - [0, 1, 2, 0, 1, 2] -> (2, 3) +std::tuple ComputeRepeatAndRepeatStride( + const std::vector& device_elements) { + int64_t first_device_id = device_elements.at(0); + int64_t first_device_id_count = 0; + for (size_t i = 0; i < device_elements.size(); ++i) { + if (device_elements.at(i) == first_device_id) { + ++first_device_id_count; + } + } + size_t repeat_stride = device_elements.size() / first_device_id_count; + + // Check if the device mesh pattern is supported. + // Supported examples: [0, 1, 2] and [0, 1, 0, 1, 0, 1]. + // Unsupported examples: [0, 1, 2, 1, 2, 0] and [0, 1, 2, 0]. + for (size_t repeat = 0; repeat < first_device_id_count; ++repeat) { + for (size_t device_id = 0; device_id < repeat_stride; ++device_id) { + ORT_ENFORCE( + device_elements.at(repeat * repeat_stride + device_id) == device_elements.at(device_id), + "Unsupported device mesh pattern."); + } + } + + // If device_mesh=[0, 1, 2, 0, 1, 2], returns (2, 3), which means + // - each device repeats twice for "2" in (2, 3). + // - there are 3 unique devices for "3" in (2, 3). + return std::make_tuple(first_device_id_count, repeat_stride); +} + +std::tuple ComputeNativeSpecForTwoAxisDecomposition( + const TensorPartitionSpec& src_spec, + const std::vector& src_shape, + const std::vector& dst_shape, + const int64_t decomposed_axis_in_src, + const int64_t decomposition_axis_in_dst) { + // TODO(wechi): use device mesh stride to support non-1 stride. + // Example: S[0], shape=[8], device_mesh=[0, 1] -> S[0]R + // Example: S[0], shape=[8], device_mesh=[0, 1] -> RS[0] + // Example: S[0], shape=[8], device_mesh=[0, 1, 0, 1] -> S[0]R + // Example: S[0], shape=[8], device_mesh=[0, 1, 0, 1] -> RS[0] + // Example: RS[0]R, shape=[8], device_mesh=[0, 1] -> RS[0]RR + // Example: RS[0]R, shape=[8], device_mesh=[0, 1] -> RRS[0]R + if (src_spec.CountShardingAxes() != 1) { + throw std::runtime_error("Too many sharding axes."); + } + if (src_spec.device_mesh.device_mesh_shape.size() != 1) { + throw std::runtime_error("Source device mesh be 1-D."); + } + + if (src_spec.HasNoShard()) { + return std::make_tuple(true, TensorPartitionSpec::CreateAllReplica(dst_shape.size(), src_spec.device_mesh)); + } else if (src_spec.OnlyShardAxis(decomposed_axis_in_src)) { + const int64_t device_stride = src_shape[decomposed_axis_in_src] / src_spec.device_mesh.device_mesh_shape[0]; + if (device_stride >= dst_shape[decomposition_axis_in_dst + 1] && device_stride % dst_shape[decomposition_axis_in_dst + 1] == 0) { + // Since 2nd decomposition dimension is a factor of device stride, + // Sharding happens at 1st decomposition axis in dst. + // device_stride = 10 + // S[0], shape=[20], device=[0, 1] -> S[0]R, shape=[2, 10], device=[0, 1] + // + // device_stride = 8 + // S[0], shape=[16], device=[0, 1] -> RS[0], shape=[1, 16], device=[0, 1] + // + // device_stride = 8 + // S[0], shape=[16], device=[0, 1] -> S[0]R, shape=[4, 4], device=[0, 1] + std::vector dst_axis_specs; + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + // Sharding spec is copied if the axis is not decomposed. + // E.g, shape [5, 6] -> Reshape -> shape [5, 3, 2] + // The spec for "5" is copied. + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else if (dst_shape[decomposition_axis_in_dst] == 1) { + // S[0] -> RS[0] + // E.g., shape [5] -> Reshape -> shape [1, 5] + // The spec for "5" is copied and "1" is replica. + // This reshape only adds a dummy new axis without affecting + // the underlying sharding status. + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + } else { + // S[0] -> S[0]R + // E.g., shape [5] -> Reshape -> shape [5, 1] + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + } + } + // Now, we know sharding happens at decomposed_axis_in_src axis in destination tensor. + // - effective_device_stride along decomposed_axis_in_src: device_stride / dst_shape[decomposed_axis_in_src + 1] + // - The original device patterns repeats: dst_shape[decomposed_axis_in_src] / effective_device_stride times. + const int64_t effective_device_stride = device_stride / dst_shape[decomposed_axis_in_src + 1]; + // How many times a device ID changes along decomposed_axis_in_src axis in destination tensor. + const int64_t number_of_device_changes = dst_shape[decomposed_axis_in_src] / effective_device_stride; + if ((size_t)number_of_device_changes != src_spec.device_mesh.device_mesh_elements.size()) { + throw std::runtime_error("Not supported. Resharding is required."); + } + auto dst_device_mesh = CreateInterleaveDeviceMesh( + src_spec.device_mesh, 1); + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, dst_device_mesh)); + } else if (dst_shape[decomposition_axis_in_dst + 1] > device_stride && dst_shape[decomposition_axis_in_dst + 1] % device_stride == 0) { + // Since 2nd decomposition dimension is a multiple of device stride, + // sharding happens at 2nd decomposition axis in dst. + // stride = 4 + // S[0], shape=[8], device=[0, 1] -> S[0]R, shape=[4, 2], device=[0, 1] + // + // stride = 8 + // S[0], shape=[32], device=[0, 1, 0, 1] -> RS[0], shape=[2, 16], device=[0, 1] + std::vector dst_axis_specs; + // How many times a device ID appears. + // E.g., [0, 1, 0, 1, 0, 1] -> 3 + int64_t repeats = 0; + // Number of unique devices. + // E.g., [0, 1, 0, 1, 0, 1] -> 2 + int64_t repeat_stride = 0; + DeviceMesh dst_device_mesh; + std::tie(repeats, repeat_stride) = ComputeRepeatAndRepeatStride(src_spec.device_mesh.device_mesh_elements); + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else if (dst_shape[decomposition_axis_in_dst] == 1) { + // S[0] -> RS[0] + // E.g., shape [5] -> Reshape -> shape [1, 5] + // In this case "1" is added as a dummy axis without affecting + // the underlying sharding status, so we just copy the spec + // for input "5" to output "5". + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_device_mesh = src_spec.device_mesh; + } else if (dst_shape[decomposition_axis_in_dst + 1] == 1) { + // S[0] -> S[0]R + // E.g., shape [5] -> Reshape -> shape [5, 1] + // In this case "1" is added as a dummy axis without affecting + // the underlying sharding status, so we just copy the spec + // for input "5" to output "5". + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_device_mesh = src_spec.device_mesh; + } else if (repeats == 1 && dst_shape[decomposition_axis_in_dst + 1] == device_stride * repeat_stride) { + // S[0] -> RS[0] + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_device_mesh = src_spec.device_mesh; + } else if (repeats != 1 && dst_shape[decomposition_axis_in_dst + 1] % (device_stride * repeat_stride) == 0) { + // S[0] -> RS[0] + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + // Extract [0, 1] from [0, 1, 0, 1]. + std::vector unique_device_mesh_elements( + src_spec.device_mesh.device_mesh_elements.begin(), + src_spec.device_mesh.device_mesh_elements.begin() + repeat_stride); + // Compute new repeats. + // Example of repeats change from 2 to 1: + // [16]-shape tensor [2, 8]-shape tensor + // with 1-D device mesh -> Reshape -> with 1-D device mesh + // [0, 1, 0, 1] (repeats=2) [0, 1] (repeats=1) + const int64_t new_repeat = dst_shape[decomposition_axis_in_dst + 1] / (device_stride * repeat_stride); + dst_device_mesh.device_mesh_shape.push_back(repeat_stride); + dst_device_mesh.device_mesh_elements = RepeatVector(unique_device_mesh_elements, new_repeat); + } else { + throw std::runtime_error("Not supported. Resharding is required."); + } + } + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, dst_device_mesh)); + } else { + // Not supported. Resharding is required. + return std::make_tuple(false, TensorPartitionSpec()); + } + } else { + // Source tensor is sharded on non-decomposed axis. + std::vector dst_axis_specs; + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else { + // R -> RR + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + } + } + + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, src_spec.device_mesh)); + } +} + +// Arguments: +// global_data_shape: logical shape of Reshape's 1st input. +// global_shape_span: logical content of Reshape's 2nd input. +// Returns: +// logical shape of Reshape's output. +inline TensorShape InferDistributedReshapeLogicalOutputShape( + const TensorShape& global_data_shape, + const gsl::span& global_shape_span, + const int64_t allow_zero) { + return onnxruntime::cuda::InferReshapeOutputShape( + global_data_shape, + global_shape_span, + allow_zero); +} + +template +DistributedReshape::DistributedReshape(const OpKernelInfo& info) : DistributedKernel(info) { + allow_zero_ = info.GetAttrOrDefault("allowzero", static_cast(0)); +} + +template +Status DistributedReshape::ComputeInternal(OpKernelContext* context) const { + ORT_ENFORCE(context != nullptr); + auto data_tensor = context->Input(0); + auto shape_tensor = context->Input(1); + const auto& data_sharding_spec = input_shard_specs_.at(0); + const auto& shape_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + if (data_sharding_spec.HasNoShard() && shape_sharding_spec.HasNoShard() && output_sharding_spec.HasNoShard()) { + // Case: all inputs and outputs are not sharded. + const auto target_shape = onnxruntime::cuda::InferReshapeOutputShape( + data_tensor, + shape_tensor, + allow_zero_); + + auto output_tensor = context->Output(0, target_shape); + + // Copy data from input from output. + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + ORT_ENFORCE(shape_sharding_spec.HasNoShard(), + "Shape tensor should not be sharded because it will trigger communication. " + "If sharding shape is needed, please request this feature on Github."); + ORT_ENFORCE(shape_tensor->Shape().NumDimensions() == 1, "Shape must be a 1-D tensor."); + const auto original_data_shape = ComputeOriginShape(data_tensor->Shape(), data_sharding_spec); + const auto original_output_shape = InferDistributedReshapeLogicalOutputShape( + original_data_shape, + shape_tensor->template DataAsSpan(), + allow_zero_); + + // TODO: remove below code after replacing std::vector with TensorShape in other APIs. + std::vector src_shape(original_data_shape.GetDims().begin(), original_data_shape.GetDims().end()); + std::vector dst_shape(original_output_shape.GetDims().begin(), original_output_shape.GetDims().end()); + + // Case: Two axis fusion + bool is_two_axis_fusion = false; + size_t two_axis_fusion_axis_in_dst = 0; + size_t two_axis_fusion_first_fused_axis_in_src = 0; + size_t two_axis_fusion_fused_axis_count = 0; + std::tie( + is_two_axis_fusion, + two_axis_fusion_axis_in_dst, + two_axis_fusion_first_fused_axis_in_src, + two_axis_fusion_fused_axis_count) = IsTwoAxisFusion(src_shape, dst_shape); + + if (is_two_axis_fusion) { + bool is_supported = false; + TensorPartitionSpec native_dst_spec; + std::tie(is_supported, native_dst_spec) = ComputeNativeSpecForTwoAxisFusion( + data_sharding_spec, + src_shape, + dst_shape, + two_axis_fusion_first_fused_axis_in_src, + two_axis_fusion_axis_in_dst); + + if (is_supported && native_dst_spec == output_sharding_spec) { + // In this case, we can apply Reshape with local shape on local tensor without resharding. + // Those local output tensors match the output tensors defined by + // sharding the logical tensor following the native sharding spec. + TensorShape local_shape = ComputeShardShape(original_output_shape, native_dst_spec); + auto output_tensor = context->Output(0, local_shape); + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + // TODO: Reshape outputs from `native_dst_spec` to `output_sharding_spec`. + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); + } + } + + // Case: Two axis decomposition + bool is_two_axis_decomposition = false; + size_t two_axis_decomposition_decomposed_axis_in_src = 0; + size_t two_axis_decomposition_first_factor_axis_in_dst = 0; + size_t two_axis_decomposition_factor_axis_count_in_dst = 0; + std::tie( + is_two_axis_decomposition, + two_axis_decomposition_decomposed_axis_in_src, + two_axis_decomposition_first_factor_axis_in_dst, + two_axis_decomposition_factor_axis_count_in_dst) = IsTwoAxisDecomposition(src_shape, dst_shape); + + if (is_two_axis_decomposition) { + bool is_supported = false; + TensorPartitionSpec native_dst_spec; + std::tie(is_supported, native_dst_spec) = ComputeNativeSpecForTwoAxisDecomposition( + data_sharding_spec, + src_shape, + dst_shape, + two_axis_decomposition_decomposed_axis_in_src, + two_axis_decomposition_first_factor_axis_in_dst); + + if (is_supported && native_dst_spec == output_sharding_spec) { + // In this case, we can apply Reshape with local shape on local tensor without resharding. + // Those local output tensors match the output tensors defined by + // sharding the logical tensor following the native sharding spec. + TensorShape local_shape = ComputeShardShape(original_output_shape, native_dst_spec); + auto output_tensor = context->Output(0, local_shape); + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + // TODO: Reshape outputs from `native_dst_spec` to `output_sharding_spec`. + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); + } + } + } + + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h new file mode 100644 index 000000000000..e251c3cdc38d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/framework/tensor_shape.h" +#include "core/providers/cuda/tensor/reshape.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedReshape final : public DistributedKernel { + public: + explicit DistributedReshape(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t allow_zero_; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.cc b/onnxruntime/contrib_ops/cuda/collective/sharding.cc index dfd5f589355d..b6b509023a1a 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.cc @@ -30,7 +30,7 @@ void GatherTensor( const Tensor* tensor, Tensor* gathered) { const int64_t shard_axis = spec.GetPartitionAxis(); - const int64_t shard_count = spec.GetPartitionCount(shard_axis); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); FuncAllGather( nccl_kernel, @@ -51,7 +51,7 @@ std::unique_ptr GatherTensor( const TensorPartitionSpec& spec, const Tensor* tensor) { const int64_t shard_axis = spec.GetPartitionAxis(); - const int64_t shard_count = spec.GetPartitionCount(shard_axis); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); TensorShape gathered_shape(tensor->Shape()); gathered_shape[shard_axis] *= shard_count; @@ -82,7 +82,7 @@ void ShardTensor( const Tensor* tensor, Tensor* shard_tensor) { const int64_t shard_axis = spec.GetPartitionAxis(); - const int64_t shard_count = spec.GetPartitionCount(shard_axis); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); TensorShape shard_shape = ComputeShardShape( tensor->Shape(), shard_axis, @@ -118,7 +118,7 @@ std::unique_ptr ShardTensor( TensorShape shard_shape = ComputeShardShape( tensor->Shape(), spec.GetPartitionAxis(), - spec.GetPartitionCount(spec.GetPartitionAxis())); + spec.GetUniqueDeviceCount(spec.GetPartitionAxis())); auto shard_buffer = Tensor::Create(tensor->DataType(), shard_shape, alloc); // Shard with pre-allocated buffer. diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc index 220938f3ceae..20c936e1b671 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc @@ -129,7 +129,7 @@ TensorShape ComputeOriginShape(const TensorShape& shard_shape, const TensorParti } TensorShape shape(shard_shape); const int64_t axis = spec.GetPartitionAxis(); - shape[axis] *= spec.GetPartitionCount(axis); + shape[axis] *= spec.GetUniqueDeviceCount(axis); return shape; } @@ -140,7 +140,15 @@ TensorShape ComputeShardShape(const TensorShape& shape, const TensorPartitionSpe return shard_shape; } const int64_t axis = spec.GetPartitionAxis(); - shard_shape[axis] /= spec.GetPartitionCount(axis); + const int64_t unique_device_count = spec.GetUniqueDeviceCount(axis); + ORT_ENFORCE(shard_shape[axis] % unique_device_count == 0, "Number of shards must be divisible by sharded axis' dimension."); + // If a [8, 16]-tensor is shared by device mesh [0, 1, 0, 1] along axis=1 (2nd axis), + // the local tensors on device 0 & 1 have same shape [8, 8 (from 16/2)] instead of + // [8, 4 (from 16/4)]. The reason is that + // - First, the original tensor are split into 4 sub-tensors [8, 4] along the 2nd axis. + // - The 1st and 3rd sub-tensors are concatenated along axis=1 to one tensor on device 0. + // - The 2nd and 4th sub-tensors are concatenated along axis=1 to one tensor on device 1. + shard_shape[axis] /= unique_device_count; return shard_shape; } @@ -202,7 +210,7 @@ bool CanShard(const TensorShape& shape, const TensorPartitionSpec& spec) { if (axis < 0 || gsl::narrow(axis) >= shape.NumDimensions()) { return false; } - if (shape[axis] % spec.GetPartitionCount(axis) != 0) { + if (shape[axis] % spec.GetDeviceCount(axis) != 0) { return false; } return true; diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h index 6bdf5699c268..5185c41e6888 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h @@ -76,6 +76,43 @@ class DeviceMesh { void Print() const { std::cout << ToString() << std::endl; } + + static DeviceMesh Create1D(std::vector device_mesh_elements, size_t repeats = 1) { + DeviceMesh device_mesh; + device_mesh.device_mesh_shape.push_back(device_mesh_elements.size() * repeats); + for (size_t i = 0; i < repeats; ++i) { + device_mesh.device_mesh_elements.insert( + device_mesh.device_mesh_elements.end(), + device_mesh_elements.begin(), + device_mesh_elements.end()); + } + return device_mesh; + } + + // If the two meshes have the same shape and elements, return true. + // Otherwise, return false. + bool operator==(const DeviceMesh& other) const { + if (device_mesh_shape.size() != other.device_mesh_shape.size() || + device_mesh_elements.size() != other.device_mesh_elements.size()) { + return false; + } + + for (size_t i = 0; i < device_mesh_elements.size(); ++i) { + if (device_mesh_elements.at(i) != other.device_mesh_elements.at(i)) { + return false; + } + } + for (size_t i = 0; i < device_mesh_shape.size(); ++i) { + if (device_mesh_shape.at(i) != other.device_mesh_shape.at(i)) { + return false; + } + } + return true; + } + + bool operator!=(const DeviceMesh& other) const { + return !(*this == other); + } }; class AxisPartitionSpec { @@ -114,6 +151,10 @@ class AxisPartitionSpec { return AxisPartitionSpec(Condition::Shard, device_mesh_axis); } + static AxisPartitionSpec CreateCopy(const AxisPartitionSpec& spec) { + return AxisPartitionSpec(spec.cond, spec.device_mesh_axis); + } + // A normal ctor. // TODO(wechi): Consider to hide it and revise the `public` members/functions // exposed to the user. @@ -132,6 +173,14 @@ class AxisPartitionSpec { void Print() const { std::cout << ToString() << std::endl; } + + bool operator==(const AxisPartitionSpec& other) const { + return cond == other.cond && device_mesh_axis == other.device_mesh_axis; + } + + bool operator!=(const AxisPartitionSpec& other) const { + return !(*this == other); + } }; // Return true if `axis` is a valid axis index for a tensor of rank `rank`. @@ -193,6 +242,32 @@ class TensorPartitionSpec { // const TensorPartitionSpec& spec, int64_t new_shard_axis) { // } + // Copy-construct `spec` but with all tensor axes replicated. + // The new spec have the same number of axis specs and the same device mesh. + static TensorPartitionSpec CreateAllReplica( + const size_t rank, const DeviceMesh& device_mesh) { + std::vector axis_specs(rank, AxisPartitionSpec::CreateReplica()); + return TensorPartitionSpec::Create(axis_specs, device_mesh); + } + + static TensorPartitionSpec CreateOneTensorAxisOneDeviceMeshAxisSharding( + const size_t rank, const DeviceMesh& device_mesh, const size_t tensor_axis, const size_t device_mesh_axis) { + std::vector axis_specs(rank, AxisPartitionSpec::CreateReplica()); + axis_specs[tensor_axis] = AxisPartitionSpec::CreateShard(device_mesh_axis); + return TensorPartitionSpec::Create(axis_specs, device_mesh); + } + + static TensorPartitionSpec CreateByDropOneAxis( + const TensorPartitionSpec& TensorPartitionSpec, const size_t axis_to_drop) { + std::vector axis_specs; + for (size_t i = 0; i < TensorPartitionSpec.axis_specs.size(); ++i) { + if (i != axis_to_drop) { + axis_specs.push_back(TensorPartitionSpec.axis_specs[i]); + } + } + return TensorPartitionSpec::Create(axis_specs, TensorPartitionSpec.device_mesh); + } + // Helper to debug and generate error message; e.g., // "TensorPartitionSpec{RS[0], Device Mesh: DeviceMesh{Shape: [4,], Elements: [0,1,2,3,]}}". std::string ToString() const { @@ -303,7 +378,7 @@ class TensorPartitionSpec { // Return the number of shards along the first sharded tensor axis. // This value matches the number of devices along the associated mesh axis. // Return 1 if there is no sharding. - int64_t GetPartitionCount(int64_t axis) const { + int64_t GetDeviceCount(int64_t axis) const { ValidateAxisIndex(axis, Rank()); auto axis_spec = GetAxisSpec(axis); if (axis_spec.cond == AxisPartitionSpec::Condition::Replica) { @@ -312,6 +387,37 @@ class TensorPartitionSpec { return device_mesh.device_mesh_shape.at(axis_spec.device_mesh_axis); } } + + // Similar to GetDeviceCount(), but returns the number of unique devices + // along the first sharded tensor axis. + int64_t GetUniqueDeviceCount(int64_t axis) const { + ValidateAxisIndex(axis, Rank()); + auto axis_spec = GetAxisSpec(axis); + if (axis_spec.cond == AxisPartitionSpec::Condition::Replica) { + return 1; + } else { + std::set device_ids( + device_mesh.device_mesh_elements.begin(), + device_mesh.device_mesh_elements.end()); + return device_ids.size(); + } + } + + bool operator==(const TensorPartitionSpec& other) const { + if (axis_specs.size() != other.axis_specs.size()) { + return false; + } + for (size_t i = 0; i < axis_specs.size(); ++i) { + if (!(axis_specs.at(i) == other.axis_specs.at(i))) { + return false; + } + } + return device_mesh == other.device_mesh; + } + + bool operator!=(const TensorPartitionSpec& other) const { + return !(*this == other); + } }; // Parse "[0, 1, 2, 3]" as std::vector{0, 1, 2, 3}. diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index e762a80cb0e2..29ca8124bfd0 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -165,6 +165,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSlice); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape); #endif template <> @@ -334,6 +338,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 97befe2a5830..8082b8c010e9 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -191,6 +191,51 @@ void RegisterCollectiveOps() { .Output(0, "output", "Sliced data tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types.") .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types"); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReshape) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr( + "allowzero", + "(Optional) By default, when any value in the 'shape' input is equal to zero " + "the corresponding dimension value is copied from the input tensor dynamically. " + "allowzero=1 indicates that if any value in the 'shape' input is set to zero, " + "the zero value is honored, similar to NumPy.", + AttributeProto::INT, + static_cast(0)) + .Input(0, "data", "An input tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "Specified shape for output.", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "reshaped", "Reshaped data.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types."); } } // namespace contrib diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index f8b92eface52..e3106e41e77c 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -176,17 +176,17 @@ class CudaKernel : public OpKernel { return provider_->ComputeStream(); } + inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { + auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); + return gpu_data_transfer->CopyTensorAsync(src, dst, stream); + } + protected: template inline const T* GetConstOnes(size_t count, cudaStream_t stream) const { return provider_->template GetConstOnes(count, stream); } - inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { - auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); - return gpu_data_transfer->CopyTensorAsync(src, dst, stream); - } - inline int GetDeviceId() const { return provider_->GetDeviceId(); } private: diff --git a/onnxruntime/core/providers/cuda/tensor/reshape.cc b/onnxruntime/core/providers/cuda/tensor/reshape.cc index 3c6d900cee9a..ab364c274a32 100644 --- a/onnxruntime/core/providers/cuda/tensor/reshape.cc +++ b/onnxruntime/core/providers/cuda/tensor/reshape.cc @@ -6,6 +6,81 @@ namespace onnxruntime { namespace cuda { +TensorShape InferReshapeOutputShape( + const TensorShape& data_tensor_shape, // Data tensor's shape. + const gsl::span& shape_span, // Shape that data tensor reshape to. + bool allow_zero) { + TensorShapeVector shape_vector(shape_span.begin(), shape_span.end()); + ReshapeHelper helper(data_tensor_shape, shape_vector, allow_zero); + return TensorShape(shape_vector); +} + +TensorShape InferReshapeOutputShape(const Tensor* src, const Tensor* shape, bool allow_zero) { + ORT_ENFORCE(shape != nullptr, "Cannot reshape to a null shape."); + ORT_ENFORCE(shape->Shape().NumDimensions() == 1, "Shape must be an 1-D tensor."); + ORT_ENFORCE(shape->Location().device.Type() == OrtDevice::CPU, "Shape must be on CPU."); + + return InferReshapeOutputShape( + src->Shape(), + shape->template DataAsSpan(), + allow_zero); +} + +Status FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool /*allow_zero*/, + Tensor* Y) { + if (!X) return Status(common::ONNXRUNTIME, common::FAIL, "Missing data tensor to be reshaped."); + if (!shape) return Status(common::ONNXRUNTIME, common::FAIL, "Missing shape tensor for reshaping."); + if (shape->Shape().NumDimensions() != 1) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, FAIL, "The shape tensor for reshaping must be a vector, but got ", shape->Shape(), "."); + } + if (shape->Location().device.Type() != OrtDevice::CPU) { + return Status(common::ONNXRUNTIME, common::FAIL, "Shape tensor must be on CPU."); + } + + const void* src_data = X->DataRaw(); + void* dst_data = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (src_data != dst_data) { + ORT_ENFORCE(ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(cuda_kernel->CopyTensor(*X, *Y, *ctx->GetComputeStream())); + } + + return Status::OK(); +} + +std::unique_ptr FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool allow_zero) { + // TODO(wechi): Study if Tensor can be created as view to existing tensor. + // This feature can refine code for re-sharding and shape broadcasting. + + ORT_ENFORCE(X != nullptr, "Missing data tensor to be reshaped."); + ORT_ENFORCE(shape != nullptr, "Missing shape tensor for reshaping."); + ORT_ENFORCE(shape->Shape().NumDimensions() == 1, "The shape tensor for reshaping must be a vector, but got ", shape->Shape(), "."); + ORT_ENFORCE(shape->Location().device.Type() == OrtDevice::CPU, "Shape tensor must be on CPU."); + + // Calculate output's shape. + auto dst_shape = InferReshapeOutputShape(X, shape, allow_zero); + + // Pre-allocate output. + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK()); + auto Y = Tensor::Create(X->DataType(), dst_shape, alloc); + + // Do reshape. It's equivalent to memcpy. + ORT_ENFORCE(FuncReshape(cuda_kernel, ctx, X, shape, allow_zero, Y.get()).IsOK()); + return Y; +} + ONNX_OPERATOR_KERNEL_EX( Reshape, kOnnxDomain, diff --git a/onnxruntime/core/providers/cuda/tensor/reshape.h b/onnxruntime/core/providers/cuda/tensor/reshape.h index 01e933e65888..8f33265071ed 100644 --- a/onnxruntime/core/providers/cuda/tensor/reshape.h +++ b/onnxruntime/core/providers/cuda/tensor/reshape.h @@ -10,6 +10,39 @@ namespace onnxruntime { namespace cuda { +// Deduce output shape from ONNX Reshape's inputs. +// +// Arguments: +// data_tensor_shape: The shape of the data tensor (i.e., 1st input). +// shape_span: Elements in the shape tensor (i.e., 2nd input). +// +// Returns: +// The output shape of this Reshape. No symbolic values such as "-1" or "0". +TensorShape InferReshapeOutputShape( + const TensorShape& data_tensor_shape, + const gsl::span& shape_span, + bool allow_zero); + +TensorShape InferReshapeOutputShape( + const Tensor* src, + const Tensor* shape, + bool allow_zero); + +Status FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool /*allow_zero*/, + Tensor* Y); + +std::unique_ptr FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool allow_zero); + class Reshape final : public CudaKernel { public: Reshape(const OpKernelInfo& info) : CudaKernel(info), @@ -18,27 +51,11 @@ class Reshape final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override { // Copy the second input tensor into the shape vector - const Tensor* shapeTensor = context->Input(1); - if (shapeTensor == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - if (shapeTensor->Shape().NumDimensions() != 1) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "A shape tensor must be a vector tensor, got ", shapeTensor->Shape().NumDimensions(), " dimensions"); - auto data_span = shapeTensor->template DataAsSpan(); - TensorShapeVector shape(data_span.begin(), data_span.end()); - const Tensor* X = context->Input(0); - if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - const TensorShape& X_shape = X->Shape(); - - ReshapeHelper helper(X_shape, shape, allow_zero_); - - Tensor* Y = context->Output(0, TensorShape(shape)); - const void* source = X->DataRaw(); - void* target = Y->MutableDataRaw(); - // If source and target pointers are not equal (non-inplace operation), we need to copy the data. - if (target != source) { - ORT_ENFORCE(context->GetComputeStream()); - ORT_RETURN_IF_ERROR(CopyTensor(*X, *Y, *context->GetComputeStream())); - } - - return Status::OK(); + const Tensor* data_tensor = context->Input(0); + const Tensor* shape_tensor = context->Input(1); + const auto target_shape = InferReshapeOutputShape(data_tensor, shape_tensor, allow_zero_); + Tensor* output_tensor = context->Output(0, target_shape); + return FuncReshape(this, context, data_tensor, shape_tensor, allow_zero_, output_tensor); } private: diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h index 463c1cf0d2ea..c0b7d4722d3e 100644 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ b/onnxruntime/core/providers/rocm/rocm_kernel.h @@ -173,17 +173,17 @@ class RocmKernel : public OpKernel { return provider_->PerThreadDefaultMiopenHandle(); } + inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { + auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); + return gpu_data_transfer->CopyTensorAsync(src, dst, stream); + } + protected: template inline const T* GetConstOnes(size_t count, hipStream_t stream) const { return provider_->template GetConstOnes(count, stream); } - inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { - auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); - return gpu_data_transfer->CopyTensorAsync(src, dst, stream); - } - inline int GetDeviceId() const { return provider_->GetDeviceId(); } private: diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index a9b55122c680..2acca4a8f22a 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import unittest +from typing import Tuple import numpy as np import onnxscript @@ -18,6 +19,672 @@ def shard_tensor(X, rank, axis, num_shards): return np.split(X, num_shards, axis)[rank] +def shard_tensor_per_device_mesh(X, rank, axis, device_mesh): + if axis is None: + return X + shards = np.split(X, len(device_mesh), axis) + selected_shards = tuple(shard for device_id, shard in zip(device_mesh, shards) if device_id == rank) + return np.concatenate(selected_shards, axis=axis) + + +def translate_device_mesh_to_attrs(device_mesh: np.ndarray): + device_mesh_shape = "[" + ",".join(str(dim) for dim in device_mesh.shape) + "]" + device_mesh_elements = "[" + ",".join(str(elem) for elem in device_mesh.flat) + "]" + return device_mesh_shape, device_mesh_elements + + +def parse_sharding_spec(spec: str): + axis_conditions = [] + sharding_device_axes = [] + token_index = 0 + while True: + token = spec[token_index] + if token == "R": + axis_conditions.append("R") + sharding_device_axes.append(None) + token_index += 1 + elif token == "S": + axis_conditions.append("S") + # Move token pointer to "["" + token_index += 1 + assert spec[token_index] == "[" + number_tokens = "" + while True: + token_index += 1 + token = spec[token_index] + if token == "]": + break + number_tokens += token + assert spec[token_index] == "]" + # Skip "]" and point to next S/R token + token_index += 1 + sharding_device_axes.append(int(number_tokens)) + else: + raise ValueError(f"Invalid spec: {spec}") + if token_index >= len(spec): + break + return axis_conditions, sharding_device_axes + + +def find_shard_axis(axis_conditions, shard_device_axes): + sharded_axis = None + sharded_axis_count = 0 + for i, cond in enumerate(axis_conditions): + if cond == "S": + sharded_axis = i + sharded_axis_count += 1 + assert sharded_axis_count in (0, 1), "Can shard at most one axis per tensor." + if sharded_axis is not None: + assert shard_device_axes[sharded_axis] == 0, "Device mesh must be 1-D, so 0 is the only valid device mesh axis." + return sharded_axis + + +def shard_tensor_per_spec(tensor: np.ndarray, rank: int, spec: str, device_mesh: np.ndarray): + axis_conditions, shard_device_axes = parse_sharding_spec(spec) + sharded_axis = find_shard_axis(axis_conditions, shard_device_axes) + return shard_tensor_per_device_mesh(tensor, rank, sharded_axis, list(device_mesh.flat)) + + +class TestDistributedReshape(unittest.TestCase): + def _check_distributed_reshape( + self, + shape: Tuple[int, ...], + target_shape: Tuple[int, ...], + input_device_meshs: np.ndarray, + input_shard_specs: Tuple[str, ...], + output_device_meshs: np.ndarray, + output_shard_specs: Tuple[str, ...], + ): + assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) + assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) + assert len(input_device_meshs) == len(input_shard_specs) + assert len(output_device_meshs) == len(output_shard_specs) + + input_device_mesh_shapes = [] + input_device_mesh_elements = [] + for device_mesh in input_device_meshs: + device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) + input_device_mesh_shapes.append(device_mesh_shape) + input_device_mesh_elements.append(device_mesh_element) + + output_device_mesh_shapes = [] + output_device_mesh_elements = [] + for device_mesh in output_device_meshs: + device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) + output_device_mesh_shapes.append(device_mesh_shape) + output_device_mesh_elements.append(device_mesh_element) + + @onnxscript.script() + def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64): + return MICROSOFT_OPSET.DistributedReshape( + data_tensor, + shape_tensor, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + rank = comm.Get_rank() + data_tensor = np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) + shape_tensor = np.array( + target_shape, + dtype=np.int64, + ) + + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + assert "S" not in input_shard_specs[1], "Shape should not be sharded." + + expected = np.reshape(data_tensor, shape_tensor) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + + onnx_model = distributed_reshape_instance.to_model_proto( + input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], + output_types=[FLOAT[tuple(local_expected.shape)]], + ) + + # Each MPI process owns a sharded model. + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + # Each MPI process executes its sharded model. + # The result is `local` tensor stored on a specific MPI rank + # instead of `logical` tensor. + result = sess.run( + None, + { + "data_tensor": local_data_tensor, + "shape_tensor": shape_tensor, + }, + ) + + # Compare local tensor and the corresponding logical sub-tensor + # obtained by sharding logical tensor following output's sharding spec. + np.testing.assert_allclose(result[0], local_expected, rtol=1e-5, atol=1e-8) + + def test_reshape_two_axis_fusion_shape_2_3_sr_01_shape_6_s_01(self): + # Two axis fusion. + # S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=( + 2, + 3, + ), + target_shape=(6,), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]R", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]",), + ) + + def test_reshape_two_axis_fusion_shape_2_4_rs_01_shape_8_s_0101(self): + # Two axis fusion. + # RS[0], shape=[2, 4], device_mesh=[0, 1] -> S[0], shape = [8], device_mesh=[0, 1, 0, 1] + self._check_distributed_reshape( + shape=( + 2, + 4, + ), + target_shape=(8,), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("S[0]",), + ) + + def test_reshape_two_axis_fusion_shape_2_3_5_srr_01_shape_2_15_sr_01(self): + # Two axis fusion. + # S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape = [2, 15], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=( + 2, + 3, + 5, + ), + target_shape=( + 2, + 15, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_fusion_shape_2_3_5_rsr_01_shape_2_15_sr_01(self): + # Two axis fusion. + # RS[0]R, shape=[2, 4, 5], device_mesh=[0, 1] -> RS[0], shape = [2, 20], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=( + 2, + 4, + 5, + ), + target_shape=( + 2, + 20, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]R", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_reshape_two_axis_fusion_shape_2_3_6_rrs_01_shape_2_18_rs_010101(self): + # Two axis fusion. + # RRS[0], shape=[2, 3, 6], device_mesh=[0, 1] -> RS[0], shape = [2, 18], device_mesh=[0, 1, 0, 1, 0, 1] + self._check_distributed_reshape( + shape=( + 2, + 3, + 6, + ), + target_shape=( + 2, + 18, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRS[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_shard_specs=("RS[0]",), + ) + # Two axis fusion. + # RRS[0], shape=[2, 3, 8], device_mesh=[0, 1, 0, 1] -> RS[0], shape = [2, 24], device_mesh=[0, 1, 0, 1] * 3 + + # Two axis fusion. + # RS[0]R, shape=[2, 8, 3], device_mesh=[0, 1, 0, 1] -> RS[0], shape = [2, 24], device_mesh=[0, 1, 0, 1] + + def test_reshape_two_axis_decomposition_shape_6_s_01_shape_2_3_sr_01(self): + # Two axis decomposition + # S[0], shape=[6], device_mesh=[0, 1] -> S[0]R, shape=[2, 3], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(6,), + target_shape=( + 2, + 3, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_1_16_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> RS[0], shape=[1, 16], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 1, + 16, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_2_8_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[2, 8], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 2, + 8, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_4_4_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[4, 4], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 4, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_8_2_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[8, 2], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 8, + 2, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_16_1_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[16, 1], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 16, + 1, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_1_16_sr_0101(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> RS[0], shape=[1, 16], device_mesh=[0, 1, 0, 1] + + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 1, + 16, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_2_8_rs_01(self): + # Two axis decomposition + # repeats=2 8 = repeats * [unique IDs] + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> RS[0], shape=[2, 8], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 2, + 8, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_4_4_sr_0101(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> S[0]R, shape=[4, 4], device_mesh=[0, 1, 0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 4, + 4, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_8_2_sr_0101(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> S[0]R, shape=[8, 2], device_mesh=[0, 1, 0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 8, + 2, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_16_1_sr_0101(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> S[0]R, shape=[16, 1], device_mesh=[0, 1, 0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 16, + 1, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_21_4096_s_01_shape_3_7_4096_rrs_01(self): + # Two axis decomposition + # [21, 4096] -> [3, 7, 4096] + # data: (21, 2048), (RS, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RRS, [0, 1]) + self._check_distributed_reshape( + shape=( + 21, + 4096, + ), + target_shape=( + 3, + 7, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RRS[0]",), + ) + + def test_reshape_two_axis_decomposition_shape_3_7_4096_rrs_01_shape_3_7_64_64_rrsr_01(self): + # Two axis decomposition + # [3, 7, 4096] -> [3, 7, 64, 64] + # data: (3, 7, 2048), (RRS, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RRSR, [0, 1]) + + self._check_distributed_reshape( + shape=( + 3, + 7, + 4096, + ), + target_shape=( + 3, + 7, + 64, + 64, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRS[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RRS[0]R",), + ) + + def test_reshape_two_axis_fusion_shape_3_7_4096_rrr_01_shape_21_4906_rr_01(self): + # Two axis fusion + # [3, 7, 4096] -> [21, 4096] + # data: (3, 7, 4096), (RRR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RR, [0, 1]) + self._check_distributed_reshape( + shape=( + 3, + 7, + 4096, + ), + target_shape=( + 21, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RR",), + ) + + def test_reshape_two_axis_fusion_shape_21_4096_rrr_01_shape_3_7_4906_rr_01(self): + # Two axis fusion + # [21, 4096] -> [3, 7, 4096] + # data: (21, 4096), (RR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RRR, [0, 1]) + self._check_distributed_reshape( + shape=( + 21, + 4096, + ), + target_shape=( + 3, + 7, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RRR",), + ) + + def test_reshape_two_axis_fusion_shape_3_64_7_64_rsrr_01_shape_192_7_64_srr_010101(self): + # Two axis fusion + # [3, 64, 7, 64] -> [192, 7, 64] + # data: (3, 32, 7, 64), (RSRR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (SRR, [0, 1, 0, 1, 0, 1]) + + self._check_distributed_reshape( + shape=( + 3, + 64, + 7, + 64, + ), + target_shape=( + 192, + 7, + 64, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]RR", "R"), + output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_shard_specs=("S[0]RR",), + ) + + def test_reshape_two_axis_decomposition_shape_192_7_7_srr_010101_shape_3_64_7_7_rsrr_01(self): + # Two axis decomposition + # [192, 7, 7] -> [3, 64, 7, 7] + # data: (96, 7, 7), (SRR, [0, 1, 0, 1, 0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RSRR, [0.0, 1.0]) + + self._check_distributed_reshape( + shape=( + 192, + 7, + 7, + ), + target_shape=( + 3, + 64, + 7, + 7, + ), + input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]RR",), + ) + + def test_reshape_two_axis_fusion_shape_3_64_7_7_rsrr_01_shape_192_7_7_srr_010101(self): + # Two axis fusion + # [3, 64, 7, 7] -> [192, 7, 7] + # data: (3, 32, 7, 7), (RSRR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (SRR, [0, 1, 0, 1, 0, 1]) + + self._check_distributed_reshape( + shape=( + 3, + 64, + 7, + 7, + ), + target_shape=( + 192, + 7, + 7, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]RR", "R"), + output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_shard_specs=("S[0]RR",), + ) + + def test_reshape_two_axis_decomposition_shape_192_7_64_srr_010101_shape_3_64_7_64_rsrr_01(self): + # Two axis decomposition + # [192, 7, 64] -> [3, 64, 7, 64] + # data: (96, 7, 64), (SRR, [0, 1, 0, 1, 0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RSRR, [0.0, 1.0]) + + self._check_distributed_reshape( + shape=( + 192, + 7, + 64, + ), + target_shape=( + 3, + 64, + 7, + 64, + ), + input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]RR",), + ) + + def test_reshape_two_axis_fusion_shape_3_7_64_64_rrsr_01_shape_3_7_4096_rrs_01(self): + # Two axis fusion + # [3, 7, 64, 64] -> [3, 7, 4096] + # data: (3, 7, 32, 64), (RRSR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RRS, [0, 1]) + + self._check_distributed_reshape( + shape=( + 3, + 7, + 64, + 64, + ), + target_shape=( + 3, + 7, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRS[0]R", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RRS[0]",), + ) + + def test_reshape_two_axis_fusion_shape_3_7_4096_rrs_01_shape_21_4906_rs_01(self): + # Two axis fusion + # [3, 7, 4096] -> [21, 4096] + # data: (3, 7, 2048), (RRS, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RS, [0, 1]) + self._check_distributed_reshape( + shape=( + 3, + 7, + 4096, + ), + target_shape=( + 21, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRS[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + class TestDistributed(unittest.TestCase): def test_matmul_rs_sr_rr(self): # It means 1-D tensor with single element: [2]. From c10b83eb68e64d66a836b63c6976ecfdd58257bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 27 Oct 2023 12:06:38 +0200 Subject: [PATCH 26/36] Update python cryptography version to 41.0.4 (#18056) ### Description Version 41.0.0 currently used has vulnerabilities. ### Motivation and Context See [Vulnerable OpenSSL included in cryptography wheels](https://github.com/advisories/GHSA-v8gr-m533-ghj9) --- .../github/linux/docker/migraphx-ci-pipeline-env.Dockerfile | 2 +- tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile index 8a67692ae598..7fa606b6c294 100644 --- a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile @@ -66,7 +66,7 @@ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86 rm ~/miniconda.sh && conda clean -ya # Conda base patch -RUN pip install cryptography==41.0.0 +RUN pip install cryptography==41.0.4 # Create migraphx-ci environment ENV CONDA_ENVIRONMENT_PATH /opt/miniconda/envs/migraphx-ci diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 32bb99f08812..412bc00d0277 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -68,7 +68,7 @@ RUN conda create -y -n ${CONDA_DEFAULT_ENV} python=3.9 ENV PATH ${CONDA_ENVIRONMENT_PATH}/bin:${PATH} # Conda base patch -RUN pip install cryptography==41.0.0 +RUN pip install cryptography==41.0.4 # Enable rocm-ci environment SHELL ["conda", "run", "-n", "rocm-ci", "/bin/bash", "-c"] From b5f242e9789b2add38578ec922c2b7f8cab254a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 27 Oct 2023 14:33:55 +0200 Subject: [PATCH 27/36] GemmFloat8 as a contrib ops (#16051) ### Description Add support for Gemm with float 8 as a contrib op. --------- Co-authored-by: Randy Shuai Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> Co-authored-by: Scott McKay Co-authored-by: Xavier Dupre --- cmake/onnxruntime_rocm_hipify.cmake | 3 + docs/ContribOperators.md | 66 +++ docs/OperatorKernels.md | 1 + .../contrib_ops/cuda/cuda_contrib_kernels.cc | 2 + .../contrib_ops/cuda/math/gemm_float8.cc | 70 +++ .../contrib_ops/cuda/math/gemm_float8.cu | 402 ++++++++++++++++++ .../contrib_ops/cuda/math/gemm_float8.h | 65 +++ .../core/graph/contrib_ops/contrib_defs.cc | 118 +++++ onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 + .../core/providers/cuda/cuda_common.cc | 85 ++++ onnxruntime/core/providers/cuda/cuda_common.h | 36 ++ .../python/tools/symbolic_shape_infer.py | 4 + onnxruntime/test/onnx/main.cc | 1 + .../python/onnxruntime_test_float8_gemm8.py | 284 +++++++++++++ 14 files changed, 1139 insertions(+) create mode 100644 onnxruntime/contrib_ops/cuda/math/gemm_float8.cc create mode 100644 onnxruntime/contrib_ops/cuda/math/gemm_float8.cu create mode 100644 onnxruntime/contrib_ops/cuda/math/gemm_float8.h create mode 100644 onnxruntime/test/python/onnxruntime_test_float8_gemm8.py diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 4ef0584b0273..ec021a1550d6 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -48,6 +48,9 @@ set(contrib_ops_excluded_files "diffusion/group_norm_impl.cu" "diffusion/group_norm_impl.h" "diffusion/nhwc_conv.cc" + "math/gemm_float8.cc" + "math/gemm_float8.cu" + "math/gemm_float8.h" "quantization/attention_quantization.cc" "quantization/attention_quantization.h" "quantization/attention_quantization_impl.cu" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 1a76c18a6a8e..890403556cc4 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -40,6 +40,7 @@ Do not modify directly.* * com.microsoft.GatherND * com.microsoft.Gelu * com.microsoft.GemmFastGelu + * com.microsoft.GemmFloat8 * com.microsoft.GreedySearch * com.microsoft.GridSample * com.microsoft.GroupNorm @@ -2137,6 +2138,71 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GemmFloat8** + + Generic Gemm for float and float 8. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : string
+
Activation function, RELU or GELU or NONE (default).
+
alpha : float
+
Scalar multiplier for the product of input tensors A * B.
+
beta : float
+
Scalar multiplier for the product of input bias C.
+
dtype : int
+
Output Type. Same definition as attribute 'to' for operator Cast.
+
transA : int
+
Whether A should be transposed. Float 8 only supprted transA=0.
+
transB : int
+
Whether B should be transposed. Float 8 only supprted transB=1.
+
+ +#### Inputs (2 - 6) + +
+
A : TA
+
Input tensor A. The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.
+
B : TB
+
Input tensor B. The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero.
+
C (optional) : TC
+
Input tensor C.
+
scaleA (optional) : TS
+
Scale of tensor A if A is float 8 tensor
+
scaleB (optional) : TS
+
Scale of tensor B if B is float 8 tensor
+
scaleY (optional) : TS
+
Scale of the output tensor if A or B is float 8.
+
+ +#### Outputs + +
+
Y : TR
+
Output tensor of shape (M, N).
+
+ +#### Type Constraints + +
+
TA : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input A.
+
TB : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input B.
+
TC : tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input C.
+
TR : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to result type.
+
TS : tensor(float)
+
Constrain type for all input scales (scaleA, scaleB, scaleY).
+
+ + ### **com.microsoft.GreedySearch** Greedy Search for text generation. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d047096cb8c8..bfb7716dc5ce 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -831,6 +831,7 @@ Do not modify directly.* |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 29ca8124bfd0..e6a216795c10 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -144,6 +144,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GemmFloat8); #ifdef ENABLE_ATEN class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen); @@ -317,6 +318,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef ENABLE_ATEN BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc new file mode 100644 index 000000000000..251850f62136 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/providers/cuda/math/gemm.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/gemm_helper.h" +#include "contrib_ops/cuda/math/gemm_float8.h" + +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL() \ + ONNX_OPERATOR_KERNEL_EX( \ + GemmFloat8, \ + kMSDomain, \ + 1, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("TA", BuildKernelDefConstraints()) \ + .TypeConstraint("TB", BuildKernelDefConstraints()) \ + .TypeConstraint("TR", BuildKernelDefConstraints()) \ + .TypeConstraint("TS", BuildKernelDefConstraints()), \ + GemmFloat8); + +REGISTER_KERNEL() + +GemmFloat8::GemmFloat8(const OpKernelInfo& info) : CudaKernel(info) { + transA_ = info.GetAttrOrDefault("transA", 0); + transB_ = info.GetAttrOrDefault("transB", 0); + dtype_ = info.GetAttrOrDefault("dtype", ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto& device_prop = GetDeviceProp(); + sm_count_ = device_prop.multiProcessorCount; + alpha_ = info.GetAttrOrDefault("alpha", 1); + beta_ = info.GetAttrOrDefault("beta", 0); + +#if (CUDA_VERSION <= 12000) + ORT_ENFORCE(beta_ == 0, "CUDA < 12.0 does not support bias, beta must be 0."); +#endif + + std::string stemp = info.GetAttrOrDefault("activation", "NONE"); + if (stemp == "NONE") { + epilogue_ = CUBLASLT_EPILOGUE_DEFAULT; + } else if (stemp == "RELU") { + epilogue_ = CUBLASLT_EPILOGUE_RELU; + } else if (stemp == "GELU") { + epilogue_ = CUBLASLT_EPILOGUE_GELU; + } else { + ORT_THROW("Unexpected value for activation: '", stemp, "'."); + } +} + +Status GemmFloat8::SetCheck(const TensorShape& a_shape, const TensorShape& b_shape, int& M, int& N, int& K) const { + GemmHelper helper(a_shape, transA_, b_shape, transB_, TensorShape({})); + if (!helper.State().IsOK()) + return helper.State(); + + M = gsl::narrow_cast(helper.M()); + N = gsl::narrow_cast(helper.N()); + K = gsl::narrow_cast(helper.K()); + return helper.State(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu new file mode 100644 index 000000000000..df25342342cd --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -0,0 +1,402 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// The operator calls function 'cublasLtMatmul' +// (https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmul#cublasltmatmul). +// It lets the function checks what configuration is valid or not. If not, the error message +// shows the error message 'CUBLAS_STATUS_NOT_SUPPORTED'. NVIDIA documentation provides +// information on what attribute or type must be modified. +// This operator requires CUDA_VERSION >= 11.8 for float 8 and CUDA_VERSION >= 12.0 +// for beta != 0. + +#include +#include +#include +#include "contrib_ops/cuda/math/gemm_float8.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// It must exist somewhere already. +int32_t TypeSize(int32_t element_type) { + switch (element_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return 4; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return 2; +#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: + return 1; +#endif + default: + ORT_THROW("Unexpected element_type=", element_type, "."); + } +} + +void GemmFloat8::SetParams(const TensorShape& a_shape, const TensorShape& b_shape, + int& M, int& N, int& K, int& lda, int& ldb, int& ldd) const { + int m_idx = transA_ ? 1 : 0; + int k_idx = 1 - m_idx; + int n_idx = transB_ ? 0 : 1; + + M = static_cast(a_shape[m_idx]); + K = static_cast(a_shape[k_idx]); + N = static_cast(b_shape[n_idx]); + lda = static_cast(a_shape[1]); + ldb = static_cast(b_shape[1]); + ldd = static_cast(b_shape[n_idx]); +} + +template +int32_t GetTypeAndShape(const TValue* input, + TensorShape& shape, + bool swap = false) { + shape = input->Shape(); + ORT_ENFORCE(shape.NumDimensions() == 2); + if (swap) { + std::swap(shape[0], shape[1]); + } + return input->GetElementType(); +} + +Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* input_A = nullptr; + const Tensor* input_B = nullptr; + const Tensor* input_C = nullptr; + const Tensor* scale_A = nullptr; + const Tensor* scale_B = nullptr; + const Tensor* scale_Y = nullptr; + bool has_scales = false; + bool has_bias = false; + int n_inputs = ctx->InputCount(); + + input_A = ctx->Input(0); + input_B = ctx->Input(1); + if (n_inputs == 3) { + input_C = ctx->Input(2); + has_bias = true; + } else if (n_inputs > 3) { + ORT_ENFORCE(n_inputs >= 5, "Unexpected number of inputs=", n_inputs, "."); + has_scales = true; + scale_A = ctx->Input(3); + scale_B = ctx->Input(4); + scale_Y = n_inputs < 6 ? nullptr : ctx->Input(5); + ORT_ENFORCE(scale_A->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + ORT_ENFORCE(scale_B->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + ORT_ENFORCE(scale_Y == nullptr || scale_Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + if (ctx->Input(2) != nullptr) { + input_C = ctx->Input(2); + has_bias = true; + ORT_ENFORCE(input_C->GetElementType() == dtype_, "Bias type must be equal to dtype."); + } + } + + auto first_type = input_A->GetElementType(); + bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; + if (!is_float8) + return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, + input_C, scale_A, scale_B, scale_Y); + return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, + input_C, scale_A, scale_B, scale_Y); +} + +Status GemmFloat8::ComputeRowMajor( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + const Tensor* input_A, const Tensor* input_B, + const Tensor* input_C, const Tensor* scale_A, + const Tensor* scale_B, const Tensor* scale_Y) const { + TensorShape shape_A, shape_B, shape_C, shape_Y; + int32_t dtype_A, dtype_B, dtype_C, dtype_Y; + dtype_A = GetTypeAndShape(input_A, shape_A); + dtype_B = GetTypeAndShape(input_B, shape_B); + + int M, N, K, lda, ldb, ldd; + SetParams(shape_A, shape_B, M, N, K, lda, ldb, ldd); + + TensorShape dimensions{M, N}; + Tensor* Y = ctx->Output(0, dimensions); + dtype_Y = GetTypeAndShape(Y, shape_Y); + dtype_C = has_bias ? GetTypeAndShape(input_C, shape_C) + : ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + return ComputeGemm(ctx, n_inputs, has_bias, has_scales, dtype_A, dtype_B, dtype_C, + dtype_Y, shape_A, shape_B, shape_C, shape_Y, transA_, transB_, + input_A->DataRaw(), input_B->DataRaw(), + has_bias ? input_C->DataRaw() : nullptr, + has_scales ? scale_A->DataRaw() : nullptr, + has_scales ? scale_B->DataRaw() : nullptr, + has_scales && scale_Y != nullptr ? scale_Y->DataRaw() : nullptr, + Y->MutableDataRaw(), M, N, K, lda, ldb, ldd, true); +} + +Status GemmFloat8::ComputeColMajor( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + const Tensor* input_A, const Tensor* input_B, + const Tensor* input_C, const Tensor* scale_A, + const Tensor* scale_B, const Tensor* scale_Y) const { + TensorShape shape_A, shape_B, shape_C, shape_Y; + int32_t dtype_A, dtype_B, dtype_C, dtype_Y; + dtype_A = GetTypeAndShape(input_A, shape_A); + dtype_B = GetTypeAndShape(input_B, shape_B); + + int M, N, K, lda, ldb, ldd; + SetParams(shape_A, shape_B, M, N, K, lda, ldb, ldd); + + std::swap(shape_A[0], shape_A[1]); + std::swap(shape_B[0], shape_B[1]); + + TensorShape dimensions{M, N}; + Tensor* Y = ctx->Output(0, dimensions); + dtype_Y = GetTypeAndShape(Y, shape_Y); + dtype_C = has_bias ? GetTypeAndShape(input_C, shape_C, true) + : ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + + return ComputeGemm(ctx, n_inputs, has_bias, has_scales, dtype_B, dtype_A, dtype_C, + dtype_Y, shape_B, shape_A, shape_C, shape_Y, transB_, transA_, + input_B->DataRaw(), input_A->DataRaw(), + has_bias ? input_C->DataRaw() : nullptr, + has_scales ? scale_B->DataRaw() : nullptr, + has_scales ? scale_A->DataRaw() : nullptr, + has_scales && scale_Y != nullptr ? scale_Y->DataRaw() : nullptr, + Y->MutableDataRaw(), N, M, K, ldb, lda, ldd, false); +} + +Status GemmFloat8::ComputeGemm( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + int32_t dtype_A, int32_t dtype_B, + int32_t dtype_C, int32_t dtype_Y, + const TensorShape& shape_A, const TensorShape& shape_B, + const TensorShape& shape_C, const TensorShape& shape_Y, + bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b, + const void* p_input_c, const void* p_scale_a, const void* p_scale_b, + const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, + int ldb, int ldd, bool row_major_compute) const { + cudaStream_t stream = Stream(ctx); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + cublasLtHandle_t cublasLt; + CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&cublasLt)); + + cublasLtMatmulDesc_t operationDesc = nullptr; + cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, + Ddesc = nullptr; + + // Create matrix descriptors. Not setting any extra attributes. + cudaDataType_t a_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_A); + cudaDataType_t b_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_B); + cudaDataType_t d_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_Y); + cudaDataType_t scale_cuda_type = + onnxruntime::cuda::ToCudaDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + cudaDataType_t bias_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_C); + + cublasComputeType_t compute_type; + switch (d_cuda_type) { + case CUDA_R_16F: + switch (a_cuda_type) { + case CUDA_R_8F_E4M3: + case CUDA_R_8F_E5M2: + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + break; + default: + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + break; + } + break; + case CUDA_R_16BF: + compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; + break; + case CUDA_R_32F: + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + break; + default: + ORT_THROW("Unable to determine computeType in operator GemmFloat8."); + } + + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate( + &Adesc, a_cuda_type, trans_A ? K : M, trans_A ? M : K, lda)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate( + &Bdesc, b_cuda_type, trans_B ? N : K, trans_B ? K : N, ldb)); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Ddesc, d_cuda_type, M, N, ldd)); + + if (row_major_compute) { + cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + } + + CUBLAS_RETURN_IF_ERROR( + cublasLtMatmulDescCreate(&operationDesc, compute_type, scale_cuda_type)); + cublasOperation_t ctransa = trans_A ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t ctransb = trans_B ? CUBLAS_OP_T : CUBLAS_OP_N; + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &ctransa, sizeof(ctransa))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb, sizeof(ctransb))); + + if (sm_count_ != 0) { + int math_sm_count = static_cast(sm_count_); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, + sizeof(math_sm_count))); + } + + if (has_scales) { + // gemm float 8 + const int8_t ifast_accumulation_mode = 1; + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, + cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_FAST_ACCUM, + &ifast_accumulation_mode, sizeof(ifast_accumulation_mode))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &p_scale_a, + sizeof(p_scale_a))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &p_scale_b, + sizeof(p_scale_b))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y, + sizeof(p_scale_b))); + + // float 8 +#if CUDA_VERSION >= 11080 + if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN || + dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) { + // For FP8 output, cuBLAS requires C_type to be same as bias_type + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, bias_cuda_type, M, N, ldd)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_cuda_type, + sizeof(bias_cuda_type))); + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } +#else + // An output is still needed but it is not initialized. + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); +#endif + + if (row_major_compute) { + cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Ddesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + } + + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue_, sizeof(epilogue_)); + + // See + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulPreferenceAttributes_t#cublasltmatmulpreferenceattributes-t + // The workspace should be allocated once from OpKernelContext assuming + // only one cuda function is running at a time (which is not necessarily true + // with H100). + size_t workspaceSize = static_cast(1 << 25); // suggested fixed value 32Mb + cublasLtMatmulPreference_t preference = nullptr; + cublasLtMatmulPreferenceCreate(&preference); + cublasLtMatmulPreferenceSetAttribute(preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspaceSize, sizeof(workspaceSize)); + + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulAlgoGetHeuristic#cublasltmatmulalgogetheuristic + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + int returnedResults = 0; + cublasStatus_t cuda_status = cublasLtMatmulAlgoGetHeuristic( + cublasLt, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, + &heuristicResult, &returnedResults); + ORT_ENFORCE( + returnedResults > 0 && cuda_status == CUBLAS_STATUS_SUCCESS, + " Unable to find any suitable algorithm due to ", + onnxruntime::cuda::cublasGetErrorEnum(cuda_status), + ", returnedResults=", returnedResults, + ", alpha=", alpha_, ", beta=", beta_, ", n_inputs=", n_inputs, + ", A_type=", onnxruntime::cuda::CudaDataTypeToString(a_cuda_type), + ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type), + ", C_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type), + ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type), + ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type), + ", epilogue=", epilogue_, ", smCount=", sm_count_, ", transA=", trans_A, + ", transB=", trans_B, + ", fastAccumulationMode=", 1, + ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x", + shape_B[1], ", shape_C=", (shape_C.NumDimensions() > 0 ? shape_C[0] : 0), "x", + (shape_C.NumDimensions() > 1 ? shape_C[1] : 0), ", M=", M, ", N=", N, ", K=", K, + ", lda=", lda, ", ldb=", ldb, ", ldd=", ldd, + ", workspaceSize=", workspaceSize, ", rowMajorCompute=", (row_major_compute ? 1 : 0), + ". Check NVIDIA documentation to see what combination is valid: ", + "https://docs.nvidia.com/cuda/cublas/" + "index.html?highlight=cublasLtMatmulAlgoGetHeuristic#" + "cublasltmatmulalgogetheuristic."); + + void* workspace = nullptr; + if (workspaceSize > 0) { + CUDA_RETURN_IF_ERROR(cudaMalloc(reinterpret_cast(&workspace), workspaceSize)); + } + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmul#cublasltmatmul + const void* bias = has_bias ? p_input_c : p_output_y; + cuda_status = cublasLtMatmul( + cublasLt, operationDesc, static_cast(&alpha_), /* alpha */ + p_input_a, /* A */ + Adesc, p_input_b, /* B */ + Bdesc, static_cast(&beta_), /* beta */ + bias, /* C */ + Cdesc, p_output_y, /* Y */ + Ddesc, &heuristicResult.algo, /* algo */ + workspace, /* workspace */ + workspaceSize, stream); /* stream */ + ORT_ENFORCE( + cuda_status == CUBLAS_STATUS_SUCCESS, + " Unable to run cublasLtMatmul due to ", + onnxruntime::cuda::cublasGetErrorEnum(cuda_status), + ", returnedResults=", returnedResults, ", alpha=", alpha_, + ", n_inputs=", n_inputs, ", A_type=", + onnxruntime::cuda::CudaDataTypeToString(a_cuda_type), + ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type), + ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type), + ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type), + ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type), + ", epilogue=", epilogue_, ", smCount=", sm_count_, ", transA=", trans_A, + ", transB=", trans_B, + ", fastAccumulationMode=", 1, + ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x", + shape_B[1], ", M=", M, ", N=", N, ", K=", K, ", lda=", lda, ", ldb=", ldb, + ", ldd=", ldd, ", workspaceSize=", workspaceSize, + ", rowMajorCompute=", (row_major_compute ? 1 : 0), "."); + + if (workspaceSize > 0) { + CUDA_RETURN_IF_ERROR(cudaFree(workspace)); + } + + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulPreferenceDestroy(preference)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Ddesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Cdesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Bdesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Adesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescDestroy(operationDesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtDestroy(cublasLt)); + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.h b/onnxruntime/contrib_ops/cuda/math/gemm_float8.h new file mode 100644 index 000000000000..e84ccd55b200 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cublas_v2.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Calls https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul. +// D = alpha*(A*B) +class GemmFloat8 final : public onnxruntime::cuda::CudaKernel { + public: + GemmFloat8(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + void SetParams(const TensorShape& shape_a, + const TensorShape& shape_b, + int& M, int& N, int& K, + int& lda, int& ldb, int& ldd) const; + Status SetCheck(const TensorShape& shape_a, + const TensorShape& shape_b, + int& M, int& N, int& K) const; + + Status ComputeRowMajor(OpKernelContext* ctx, int n_inputs, bool has_bias, + bool has_scales, const Tensor* input_A, + const Tensor* input_B, const Tensor* input_C, + const Tensor* scale_A, const Tensor* scale_B, + const Tensor* scale_Y) const; + Status ComputeColMajor(OpKernelContext* ctx, int n_inputs, bool has_bias, + bool has_scales, const Tensor* input_A, + const Tensor* input_B, const Tensor* input_C, + const Tensor* scale_A, const Tensor* scale_B, + const Tensor* scale_Y) const; + + Status ComputeGemm( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + int32_t dtype_A, int32_t dtype_b, + int32_t dtype_c, int32_t dtype_Y, + const TensorShape& shape_A, const TensorShape& shape_B, + const TensorShape& shape_C, const TensorShape& shape_Y, + bool transa, bool transb, const void* p_input_a, const void* p_input_b, + const void* p_input_c, const void* p_scale_a, const void* p_scale_b, + const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, + int ldb, int ldd, bool row_major_compute) const; + + float alpha_; + float beta_; + bool transA_; + bool transB_; + int64_t sm_count_; + int64_t dtype_; + cublasLtEpilogue_t epilogue_; + + // TODO(xadupre): add epilogue (= activation function, Relu or Gelu are available). +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 681a728f823d..e757e39130d3 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2573,6 +2573,124 @@ ONNX_MS_OPERATOR_SET_SCHEMA(CropAndResize, 1, a fixed size = [crop_height, crop_width]. The result is a 4-D tensor [num_boxes, crop_height, crop_width, depth]. The resizing is corner aligned.)DOC")); +#if !defined(DISABLE_FLOAT8_TYPES) +#define GEMM_FLOAT8_TYPES \ + { "tensor(float8e4m3fn)", "tensor(float8e5m2)", "tensor(float16)", "tensor(bfloat16)", "tensor(float)" } +#else +#define GEMM_FLOAT8_TYPES \ + { "tensor(float16)", "tensor(bfloat16)", "tensor(float)" } +#endif + +ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1, + OpSchema() + .SetDoc(R"DOC(Generic Gemm for float and float 8.)DOC") + .Attr( + "transA", + "Whether A should be transposed. Float 8 only supprted transA=0.", + AttributeProto::INT, + static_cast(0)) + .Attr( + "transB", + "Whether B should be transposed. Float 8 only supprted transB=1.", + AttributeProto::INT, + static_cast(0)) + .Attr( + "alpha", + "Scalar multiplier for the product of input tensors A * B.", + AttributeProto::FLOAT, + 1.0f) + .Attr( + "beta", + "Scalar multiplier for the product of input bias C.", + AttributeProto::FLOAT, + 0.0f) + .Attr( + "dtype", + "Output Type. Same definition as attribute 'to' for operator Cast.", + AttributeProto::INT, + static_cast(1)) + .Attr( + "activation", + "Activation function, RELU or GELU or NONE (default).", + AttributeProto::STRING, + OPTIONAL_VALUE) + .Input( + 0, + "A", + "Input tensor A. " + "The shape of A should be (M, K) if transA is 0, " + "or (K, M) if transA is non-zero.", + "TA") + .Input( + 1, + "B", + "Input tensor B. " + "The shape of B should be (K, N) if transB is 0, " + "or (N, K) if transB is non-zero.", + "TB") + .Input( + 2, + "C", + "Input tensor C.", + "TC", + OpSchema::Optional) + .Input( + 3, + "scaleA", + "Scale of tensor A if A is float 8 tensor", + "TS", + OpSchema::Optional) + .Input( + 4, + "scaleB", + "Scale of tensor B if B is float 8 tensor", + "TS", + OpSchema::Optional) + .Input( + 5, + "scaleY", + "Scale of the output tensor if A or B is float 8.", + "TS", + OpSchema::Optional) + .Output(0, "Y", "Output tensor of shape (M, N).", "TR") + .TypeConstraint( + "TA", + GEMM_FLOAT8_TYPES, + "Constrain type to input A.") + .TypeConstraint( + "TB", + GEMM_FLOAT8_TYPES, + "Constrain type to input B.") + .TypeConstraint( + "TC", + {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, + "Constrain type to input C.") + .TypeConstraint( + "TR", + GEMM_FLOAT8_TYPES, + "Constrain type to result type.") + .TypeConstraint("TS", {"tensor(float)"}, + "Constrain type for all input scales (scaleA, scaleB, scaleY).") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT); + if (!hasNInputShapes(ctx, 2)) { + return; + } + auto transAAttr = ctx.getAttribute("transA"); + bool transA = transAAttr ? static_cast(transAAttr->i()) != 0 : false; + auto transBAttr = ctx.getAttribute("transB"); + bool transB = transBAttr ? static_cast(transBAttr->i()) != 0 : false; + auto& first_input_shape = getInputShape(ctx, 0); + auto& second_input_shape = getInputShape(ctx, 1); + if (first_input_shape.dim_size() != 2) { + fail_shape_inference("First input does not have rank 2"); + } + if (second_input_shape.dim_size() != 2) { + fail_shape_inference("Second input does not have rank 2"); + } + updateOutputShape(ctx, 0, {first_input_shape.dim(transA ? 1 : 0), second_input_shape.dim(transB ? 0 : 1)}); + })); + static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int64_t K, int64_t N) { diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index afaa380d6ac7..aa31f3b5a7c6 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -112,6 +112,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WordConvEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFastGelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedSelfAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFloat8); class OpSet_Microsoft_ver1 { public: @@ -218,6 +219,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); } }; } // namespace contrib diff --git a/onnxruntime/core/providers/cuda/cuda_common.cc b/onnxruntime/core/providers/cuda/cuda_common.cc index 57477f167c55..288ca8e97e34 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.cc +++ b/onnxruntime/core/providers/cuda/cuda_common.cc @@ -27,5 +27,90 @@ const HalfGemmOptions* HalfGemmOptions::GetInstance() { return &instance; } +const char* cublasGetErrorEnum(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + default: + return ""; + } +} + +const char* CudaDataTypeToString(cudaDataType_t dt) { + switch (dt) { + case CUDA_R_16F: + return "CUDA_R_16F"; + case CUDA_R_16BF: + return "CUDA_R_16BF"; + case CUDA_R_32F: + return "CUDA_R_32F"; +#if (CUDA_VERSION >= 11080) + case CUDA_R_8F_E4M3: + return "CUDA_R_8F_E4M3"; + case CUDA_R_8F_E5M2: + return "CUDA_R_8F_E5M2"; +#endif + default: + return ""; + } +} + +const char* CublasComputeTypeToString(cublasComputeType_t ct) { + switch (ct) { + case CUBLAS_COMPUTE_16F: + return "CUBLAS_COMPUTE_16F"; + case CUBLAS_COMPUTE_32F: + return "CUBLAS_COMPUTE_32F"; + case CUBLAS_COMPUTE_32F_FAST_16F: + return "CUBLAS_COMPUTE_32F_FAST_16F"; + case CUBLAS_COMPUTE_32F_FAST_16BF: + return "CUBLAS_COMPUTE_32F_FAST_16BF"; + case CUBLAS_COMPUTE_32F_FAST_TF32: + return "CUBLAS_COMPUTE_32F_FAST_TF32"; + case CUBLAS_COMPUTE_64F: + return "CUBLAS_COMPUTE_64F"; + default: + return ""; + } +} + +// It must exist somewhere already. +cudaDataType_t ToCudaDataType(int32_t element_type) { + switch (element_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return CUDA_R_32F; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return CUDA_R_16F; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + return CUDA_R_16BF; +#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: + return CUDA_R_8F_E4M3; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: + return CUDA_R_8F_E5M2; +#endif + default: + ORT_THROW("Unexpected element_type=", element_type, "."); + } +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index fa258961f115..9cd4e721ccab 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -11,6 +11,7 @@ #include "core/providers/shared_library/provider_api.h" #include "core/common/status.h" +#include "core/framework/float8.h" #include "core/framework/float16.h" #include "core/providers/cuda/cuda_pch.h" #include "core/providers/cuda/shared_inc/cuda_call.h" @@ -48,6 +49,33 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef BFloat16 MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + +template <> +class ToCudaType { + public: + typedef Float8E4M3FN MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + +template <> +class ToCudaType { + public: + typedef Float8E5M2 MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { int stride = 1; if (dims.empty() || p.size() < dims.size()) @@ -152,5 +180,13 @@ class HalfGemmOptions { static HalfGemmOptions instance; }; +const char* cublasGetErrorEnum(cublasStatus_t error); + +const char* CudaDataTypeToString(cudaDataType_t dt); + +const char* CublasComputeTypeToString(cublasComputeType_t ct); + +cudaDataType_t ToCudaDataType(int32_t element_type); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 272727a9f537..ef1c46b83946 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -198,6 +198,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GatedRelativePositionBias": self._infer_GatedRelativePositionBias, "Gelu": self._infer_Gelu, "GemmFastGelu": self._infer_GemmFastGelu, + "GemmFloat8": self._infer_GemmFloat8, "GroupNorm": self._infer_GroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, @@ -2317,6 +2318,9 @@ def _infer_QuickGelu(self, node): # noqa: N802 def _infer_GemmFastGelu(self, node): # noqa: N802 self._compute_matmul_shape(node) + def _infer_GemmFloat8(self, node): # noqa: N802 + self._compute_matmul_shape(node) + def _infer_LayerNormalization(self, node): # noqa: N802 self._propagate_shape_and_type(node) if len(node.output) > 1: diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index de5431ca4a46..0526ccca5bb4 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -761,6 +761,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); ORT_TSTR("sce_none_weights_expanded")}; std::unordered_set> all_disabled_tests(std::begin(immutable_broken_tests), std::end(immutable_broken_tests)); + if (enable_cuda) { all_disabled_tests.insert(std::begin(cuda_flaky_tests), std::end(cuda_flaky_tests)); } diff --git a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py new file mode 100644 index 000000000000..784ae8ce70bd --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py @@ -0,0 +1,284 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# pylint: disable=C0116,W0212,R1720,C0103,C0114 +# +# Note: the precision is different on V100, H100 even with the same code. +# The thresholds were adjusted on H100 as the precision seems lower on this machine. + +import itertools +import unittest +import warnings + +import numpy as np +import parameterized +from numpy.testing import assert_allclose +from onnx import TensorProto +from onnx.checker import check_model +from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info +from onnx.numpy_helper import from_array + +from onnxruntime import InferenceSession + + +class TestFloat8Gemm8(unittest.TestCase): + def get_model_gemm( + self, + float_name, + alpha=1.0, + beta=0.0, + transA=0, + transB=0, + domain="", + dtype=TensorProto.FLOAT, + activation="NONE", + ): + proto_type = getattr(TensorProto, float_name) + use_f8 = proto_type in (TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E5M2) + + a = make_tensor_value_info("A", TensorProto.FLOAT, [None, None]) + b = make_tensor_value_info("B", TensorProto.FLOAT, [None, None]) + d = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None]) + + inits = [] + kwargs = {} + node_inputs = ["Af", "Bf"] + inputs = [a, b] + bias = beta != 0 + if bias: + inputs.append(make_tensor_value_info("C", TensorProto.FLOAT, [None, None])) + node_inputs = ["Af", "Bf", "Cf"] + if use_f8: + node_inputs.extends(["one"] * 3) + elif use_f8: + node_inputs.append("") + node_inputs.extend(["one"] * 3) + + if use_f8: + assert domain == "com.microsoft" + inits.append(from_array(np.array([1], dtype=np.float32), name="one")) + kwargs = dict( + domain=domain, + dtype=dtype, + ) + if activation is not None: + kwargs["activation"] = activation + op_name = "GemmFloat8" + elif domain == "com.microsoft": + op_name = "GemmFloat8" + kwargs = dict( + domain=domain, + dtype=dtype, + ) + else: + op_name = "Gemm" + nodes = [ + make_node("Cast", ["A"], ["Af"], to=proto_type), + make_node("Cast", ["B"], ["Bf"], to=proto_type), + make_node("Cast", ["C"], ["Cf"], to=proto_type) if bias else None, + make_node( + op_name, + node_inputs, + ["Yf"], + transA=transA, + transB=transB, + alpha=alpha, + beta=beta, + **kwargs, + ), + make_node("Cast", ["Yf"], ["Y"], to=TensorProto.FLOAT), + ] + nodes = [n for n in nodes if n is not None] + graph = make_graph(nodes, "gemm", inputs, [d], inits) + onnx_model = make_model(graph, opset_imports=[make_opsetid("", 19)], ir_version=9) + if domain != "com.microsoft": + check_model(onnx_model) + return onnx_model + + def common_test_model_gemm(self, float_type, mul=0.33, atol=0, rtol=0, square=True, **kwargs): + if square: + a = (np.arange(256) * 0.01).astype(np.float32).reshape((-1, 16)) + b = (np.arange(256) * -0.01).astype(np.float32).reshape((-1, 16)) + c = (np.arange(256) * 0.03).astype(np.float32).reshape((-1, 16)) + b[:, 0] += 1 + else: + a = (np.arange(256) / 256).astype(np.float32).reshape((32, -1)) + b = (np.arange(512) / 512).astype(np.float32).reshape((32, -1)) + c = (np.arange(128) / 128).astype(np.float32).reshape((8, 16)) + + feeds = {"A": a, "B": b} + + expected = (a.T if kwargs.get("transA", 0) else a) @ (b.T if kwargs.get("transB", 0) else b) + expected *= kwargs.get("alpha", 1.0) + if kwargs.get("beta", 0) != 0: + expected += kwargs["beta"] * c + feeds["C"] = c + + onnx_model = self.get_model_gemm("FLOAT", **kwargs) + + ref = InferenceSession( + onnx_model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + ) + y = ref.run(None, feeds)[0] + if float_type in ("FLOAT", "FLOAT16"): + try: + assert_allclose(expected, y, atol=atol, rtol=rtol) + except Exception as e: + + def check(f): + try: + return f()[:2, :2] + except Exception as e: + return str(e) + + raise AssertionError( + f"Gemm ERROR len(inputs)={len(feeds)}" + f"\na@b=\n{check(lambda:a@b)}" + f"\na.T@b=\n{check(lambda:a.T@b)}" + f"\na@b.T=\n{check(lambda:a@b.T)}" + f"\na.T@b.T=\n{check(lambda:a.T@b.T)}" + f"\n----\nb@a=\n{check(lambda:b@a)}" + f"\nb.T@a=\n{check(lambda:b.T@a)}" + f"\nb@a.T=\n{check(lambda:b@a.T)}" + f"\nb.T@a.T=\n{check(lambda:b.T@a.T)}" + f"\n----\nexpected=\n{expected[:2,:2]}" + f"\n----\ngot=\n{y[:2,:2]}" + f"\nkwargs={kwargs}" + ) from e + + self.assertEqual(expected.shape, y.shape) + self.assertEqual(expected.dtype, y.dtype) + + onnx_model_f8 = self.get_model_gemm(float_type, domain="com.microsoft", **kwargs) + try: + ref8 = InferenceSession( + onnx_model_f8.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + ) + except Exception as e: + if "CUDA < 12.0 does not support bias" in str(e): + return + raise AssertionError(f"Could not load model {onnx_model_f8}") from e + try: + y = ref8.run(None, feeds)[0] + except Exception as e: + if "CUBLAS_STATUS_NOT_SUPPORTED" in str(e): + # Skipping. This machine does not support float8. + warnings.warn("unable to test with float8 on this machine.") + return + raise AssertionError(f"Could not execute model {onnx_model_f8}") from e + try: + assert_allclose(expected, y, atol=atol, rtol=rtol) + except Exception as e: + + def check(f): + try: + return f()[:2, :2] + except Exception as e: + return str(e) + + raise AssertionError( + f"Gemm ERROR len(inputs)={len(feeds)}" + f"\na@b=\n{check(lambda:a@b)}" + f"\na.T@b=\n{check(lambda:a.T@b)}" + f"\na@b.T=\n{check(lambda:a@b.T)}" + f"\na.T@b.T=\n{check(lambda:a.T@b.T)}" + f"\n----\nb@a=\n{check(lambda:b@a)}" + f"\nb.T@a=\n{check(lambda:b.T@a)}" + f"\nb@a.T=\n{check(lambda:b@a.T)}" + f"\nb.T@a.T=\n{check(lambda:b.T@a.T)}" + f"\n----\nexpected=\n{expected[:2,:2]}" + f"\n----\ngot=\n{y[:2,:2]}" + f"\nkwargs={kwargs}" + ) from e + self.assertEqual(expected.shape, y.shape) + self.assertEqual(expected.dtype, y.dtype) + + def test_model_gemm_float(self): + self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3) + + def test_model_gemm_float_default_values(self): + self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation=None) + + def test_model_gemm_float_relu(self): + self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="RELU") + + def test_model_gemm_float_gelu(self): + self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="GELU") + + def test_model_gemm_float_bias(self): + self.common_test_model_gemm("FLOAT", transA=1, beta=1.0, rtol=1e-3) + + def test_model_gemm_float16(self): + self.common_test_model_gemm( + "FLOAT16", + rtol=1e-2, + dtype=TensorProto.FLOAT16, + transB=1, + ) + + def test_model_gemm_float8_e4m3(self): + self.common_test_model_gemm( + "FLOAT8E4M3FN", + rtol=0.5, + dtype=TensorProto.FLOAT, + transA=0, + transB=1, + alpha=10.0, + ) + + @parameterized.parameterized.expand(list(itertools.product([0, 1], [0, 1]))) + def test_combinations_square_matrices(self, transA, transB): + self.common_test_model_gemm("FLOAT", transA=transA, transB=transB, rtol=1e-3) + + @parameterized.parameterized.expand( + [ + ((2, 3), (3, 5), 0, 0), + ((2, 3), (5, 3), 0, 1), + ((2, 3), (5, 2), 1, 1), + ((2, 3), (2, 5), 1, 0), + ] + ) + def test_combinations(self, shapeA, shapeB, transA, transB): + model = make_model( + make_graph( + [ + make_node( + "GemmFloat8", + ["A", "B"], + ["Y"], + transA=transA, + transB=transB, + domain="com.microsoft", + ) + ], + "f8", + [ + make_tensor_value_info("A", TensorProto.FLOAT, [None, None]), + make_tensor_value_info("B", TensorProto.FLOAT, [None, None]), + ], + [make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])], + ) + ) + + sess = InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) + a = np.arange(np.prod(shapeA)).reshape(shapeA).astype(np.float32) + b = np.arange(np.prod(shapeB)).reshape(shapeB).astype(np.float32) + try: + expected = (a.T if transA else a) @ (b.T if transB else b) + except Exception as e: + raise AssertionError( + f"Unable to multiply shapes={shapeA}x{shapeB}, transA={transA}, transB={transB}" + ) from e + try: + got = sess.run(None, {"A": a, "B": b}) + except Exception as e: + raise AssertionError( + f"Unable to run Gemm with shapes={shapeA}x{shapeB}, transA={transA}, transB={transB}" + ) from e + self.assertEqual(expected.shape, got[0].shape) + self.assertEqual(expected.dtype, got[0].dtype) + assert_allclose(expected, got[0]) + + +if __name__ == "__main__": + # TestFloat8Gemm8().test_model_gemm_float() + unittest.main(verbosity=2) From 58f1d15d19006464546c73ac6fbed95ff5c90b0a Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Fri, 27 Oct 2023 23:50:18 +0800 Subject: [PATCH 28/36] Replace Transpose with Replace if they are equivalent (#18096) ### Description Transpose is equivalent to a Reshape if: empty dimensions can change place, not empty dimensions must be in the same order in the permuted tenosr. Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). This pr adds a graph transformer which replaces Transpose with Reshape if they are equivalent. Because Transpose need memory copy while Reshape needn't, this replacement can save overhead for memory copy. --- .../testdata/transform/transpose_graph_gen.py | 41 +++++++++++ .../transpose_to_reshape_invalid.onnx | Bin 0 -> 235 bytes .../transform/transpose_to_reshape_valid.onnx | Bin 0 -> 235 bytes .../core/optimizer/graph_transformer_utils.cc | 2 + .../core/optimizer/transpose_replacement..cc | 68 ++++++++++++++++++ .../core/optimizer/transpose_replacement.h | 38 ++++++++++ .../test/optimizer/graph_transform_test.cc | 41 +++++++++++ 7 files changed, 190 insertions(+) create mode 100644 onnxruntime/test/testdata/transform/transpose_graph_gen.py create mode 100644 onnxruntime/test/testdata/transform/transpose_to_reshape_invalid.onnx create mode 100644 onnxruntime/test/testdata/transform/transpose_to_reshape_valid.onnx create mode 100644 orttraining/orttraining/core/optimizer/transpose_replacement..cc create mode 100644 orttraining/orttraining/core/optimizer/transpose_replacement.h diff --git a/onnxruntime/test/testdata/transform/transpose_graph_gen.py b/onnxruntime/test/testdata/transform/transpose_graph_gen.py new file mode 100644 index 000000000000..14f2994a1925 --- /dev/null +++ b/onnxruntime/test/testdata/transform/transpose_graph_gen.py @@ -0,0 +1,41 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import onnx +from onnx import TensorProto, helper + + +def GenerateModel(model_name, valid): # noqa: N802 + nodes = [ + helper.make_node("Transpose", ["input_0"], ["transposed_input_0"], perm=[2, 1, 3, 0]), + helper.make_node("Add", ["transposed_input_0", "input_1"], ["output"]), + ] + + if valid: + inputs = [ + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [1, 1, 3, 3]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [3, 1, 3, 1]), + ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 1, 3, 1])] + else: + inputs = [ + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [1, 2, 3, 3]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [3, 2, 3, 1]), + ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 2, 3, 1])] + + graph = helper.make_graph( + nodes, + "TransposeAndAdd", # name + inputs, + outputs, + [], + ) + + model = helper.make_model(graph) + onnx.save(model, model_name) + + +GenerateModel("transpose_to_reshape_valid.onnx", True) +GenerateModel("transpose_to_reshape_invalid.onnx", False) diff --git a/onnxruntime/test/testdata/transform/transpose_to_reshape_invalid.onnx b/onnxruntime/test/testdata/transform/transpose_to_reshape_invalid.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a09b13fc184a8fbf2f494fd89fd79483157e0940 GIT binary patch literal 235 zcmdCbAZ1`WNr4M$3oaE-Oaj6Hml!l} literal 0 HcmV?d00001 diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index e5c65b2a96d8..57d76577f1ba 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -63,6 +63,7 @@ #include "orttraining/core/optimizer/scaled_sum_fusion.h" #include "orttraining/core/optimizer/shape_optimizer.h" #include "orttraining/core/optimizer/transformer_layer_recompute.h" +#include "orttraining/core/optimizer/transpose_replacement.h" #include "core/optimizer/compute_optimizer/upstream_gather.h" #include "core/optimizer/compute_optimizer/upstream_reshape.h" #include "core/optimizer/pre_shape_node_elimination.h" @@ -203,6 +204,7 @@ std::vector> GeneratePreTrainingTransformers( std::make_unique(optimizer_utils::GenerateRuleBasedTransformerName(level), compatible_eps); ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique())); + ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique())); } break; case TransformerLevel::Level3: { diff --git a/orttraining/orttraining/core/optimizer/transpose_replacement..cc b/orttraining/orttraining/core/optimizer/transpose_replacement..cc new file mode 100644 index 000000000000..48e9c4d6e6a0 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/transpose_replacement..cc @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/core/optimizer/transpose_replacement.h" + +#include "core/common/logging/logging.h" +#include "core/optimizer/rewrite_rule.h" +#include "core/optimizer/utils.h" +#include "core/graph/graph.h" +#include "core/graph/graph_utils.h" + +namespace onnxruntime { + +Status TransposeReplacement::Apply(Graph& graph, + Node& transpose_node, + RewriteRuleEffect& rule_effect, + const logging::Logger& logger) const { + auto& transpose_inputs = transpose_node.MutableInputDefs(); + auto& transpose_outputs = transpose_node.MutableOutputDefs(); + NodeArg* input = transpose_inputs[0]; + auto input_shape = input->Shape(); + if (!input_shape) { + LOG_DEBUG_INFO(logger, "Exit TransposeReplacement optimization for input shape is None."); + return Status::OK(); + } + auto perm = graph_utils::onnx_repeated_values::RetrieveValues(transpose_node.GetAttributes().at("perm")); + InlinedVector new_shape; + new_shape.reserve(perm.size()); + int64_t last_permuted_axis = 0; + for (int i = 0; i < static_cast(perm.size()); ++i) { + if (!input_shape->dim(static_cast(perm[i])).has_dim_value()) { + LOG_DEBUG_INFO(logger, "Exit TransposeReplacement optimization for not supporting symbolic shape."); + return Status::OK(); + } + new_shape.push_back(input_shape->dim(static_cast(perm[i])).dim_value()); + if (input_shape->dim(static_cast(perm[i])).dim_value() == 1) + continue; + if (perm[i] < last_permuted_axis) { + LOG_DEBUG_INFO(logger, "Exit TransposeReplacement optimization for not supporting shape."); + return Status::OK(); + } + last_permuted_axis = perm[i]; + } + + transpose_inputs.push_back( + optimizer::compute_optimizer::CreateInitializerFromVector(graph, + {static_cast(new_shape.size())}, + new_shape, + graph.GenerateNodeArgName("transpose_reshape_shape"))); + + Node& transpose_reshape_node = graph.AddNode(graph.GenerateNodeName("Transpose_Reshape"), + "Reshape", + "Transpose replaced Reshape", + transpose_inputs, + transpose_outputs, + nullptr, + kOnnxDomain); + transpose_reshape_node.SetExecutionProviderType(transpose_node.GetExecutionProviderType()); + graph_utils::FinalizeNodeFusion(graph, transpose_reshape_node, transpose_node); + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} + +bool TransposeReplacement::SatisfyCondition(const Graph&, const Node&, const logging::Logger&) const { + return true; +} + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/transpose_replacement.h b/orttraining/orttraining/core/optimizer/transpose_replacement.h new file mode 100644 index 000000000000..c38e40233982 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/transpose_replacement.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" +#include "core/optimizer/compute_optimizer/shared_utils.h" + +namespace onnxruntime { + +/** +@Class TransposeReplacement + +Transpose is equivalent to a Reshape if: + empty dimensions (which dim_value=1) can change place, not empty dimensions must be in + the same order in the permuted tenosr. + Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). + +This Rewrite rule replaces Transpose which meets the requirments with Reshape. +Because Transpose need memory copy while Reshape needn't, this replacement can save overhead for memory copy. + +It is attempted to be triggered only on nodes with op type "Transpose". +*/ +class TransposeReplacement : public RewriteRule { + public: + TransposeReplacement() noexcept : RewriteRule("TransposeReplacement") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Transpose"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 94ca87b2ac51..20b9354d8574 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -18,6 +18,7 @@ #include "orttraining/core/optimizer/concat_replacement.h" #include "orttraining/core/optimizer/batchnorm_replacement.h" #include "orttraining/core/optimizer/localized_recompute.h" +#include "orttraining/core/optimizer/transpose_replacement.h" #include "test/optimizer/graph_transform_test_builder.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/util/include/default_providers.h" @@ -551,6 +552,46 @@ TEST_F(GraphTransformationTests, ConcatReplacement) { ASSERT_EQ(op_to_count["com.microsoft.ConcatTraining"], 1); } +TEST_F(GraphTransformationTests, TransposeReplacement) { + { + auto model_uri = MODEL_FOLDER "transpose_to_reshape_valid.onnx"; + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK()); + Graph& graph = p_model->MainGraph(); + + auto rule_transformer_L1 = std::make_unique("TransposeReplacement"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + + ASSERT_EQ(op_to_count["Transpose"], 0); + ASSERT_EQ(op_to_count["Reshape"], 1); + } + + { + auto model_uri = MODEL_FOLDER "transpose_to_reshape_invalid.onnx"; + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK()); + Graph& graph = p_model->MainGraph(); + + auto rule_transformer_L1 = std::make_unique("TransposeReplacement"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + + ASSERT_EQ(op_to_count["Transpose"], 1); + ASSERT_EQ(op_to_count["Reshape"], 0); + } +} + TEST_F(GraphTransformationTests, MegatronMLPPartitionRank0) { auto model_uri = MODEL_FOLDER "model_parallel/mlp_megatron_basic_test.onnx"; std::shared_ptr p_model; From 2eeafc37bca21dc8bf337dda7020b486543162d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Fri, 27 Oct 2023 18:23:19 +0200 Subject: [PATCH 29/36] Enable global TRT timing cache (#17865) I am adding a new `trt_timing_cache_path` option. Internally it is handled as `global_cache_path_` and will be set via a fall through approach: 1. no path provided => workdir 2. `trt_engine_cache_path` provided but no `trt_timing_cache_path` => `trt_engine_cache_path` 3. `trt_timing_cache_path` provided => `trt_timing_cache_path` (if not provided `trt_engine_cache_path` will still be workdir) ### Motivation and Context A TRT timing cache can be reused across multiple models as it only holds kernel timings and it is common that network "patterns" are reused. This can accelerate build times a lot. --------- Co-authored-by: Carson M --- .../tensorrt/tensorrt_provider_options.h | 3 +- .../tensorrt/tensorrt_execution_provider.cc | 29 +- .../tensorrt/tensorrt_execution_provider.h | 5 +- .../tensorrt_execution_provider_info.cc | 195 ++------- .../tensorrt_execution_provider_info.h | 1 + .../tensorrt/tensorrt_provider_factory.cc | 1 + .../core/session/provider_bridge_ort.cc | 1 + .../python/onnxruntime_pybind_state.cc | 9 +- onnxruntime/test/perftest/ort_test_session.cc | 373 +++--------------- .../providers/tensorrt/tensorrt_basic_test.cc | 1 + 10 files changed, 139 insertions(+), 479 deletions(-) diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 8f2b5af87050..680ce1cc5b9a 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -25,13 +25,14 @@ struct OrtTensorRTProviderOptionsV2 { int trt_dla_core{0}; // DLA core number. Default 0 int trt_dump_subgraphs{0}; // dump TRT subgraph. Default 0 = false, nonzero = true int trt_engine_cache_enable{0}; // enable engine caching. Default 0 = false, nonzero = true - const char* trt_engine_cache_path{nullptr}; // specify engine cache path + const char* trt_engine_cache_path{nullptr}; // specify engine cache path, defaults to the working directory int trt_engine_decryption_enable{0}; // enable engine decryption. Default 0 = false, nonzero = true const char* trt_engine_decryption_lib_path{nullptr}; // specify engine decryption library path int trt_force_sequential_engine_build{0}; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true int trt_context_memory_sharing_enable{0}; // enable context memory sharing between subgraphs. Default 0 = false, nonzero = true int trt_layer_norm_fp32_fallback{0}; // force Pow + Reduce ops in layer norm to FP32. Default 0 = false, nonzero = true int trt_timing_cache_enable{0}; // enable TensorRT timing cache. Default 0 = false, nonzero = true + const char* trt_timing_cache_path{nullptr}; // specify timing cache path, if none is provided the trt_engine_cache_path is used int trt_force_timing_cache{0}; // force the TensorRT cache to be used even if device profile does not match. Default 0 = false, nonzero = true int trt_detailed_build_log{0}; // Enable detailed build step logging on TensorRT EP with timing for each engine build. Default 0 = false, nonzero = true int trt_build_heuristics_enable{0}; // Build engine using heuristics to reduce build time. Default 0 = false, nonzero = true diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index ef1f0bf9f8d0..a1fc67ff60b6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -824,6 +824,14 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { cache_path_ = info.engine_cache_path; } + // use a more global cache if given + if (timing_cache_enable_) { + if (!info.timing_cache_path.empty()) { + global_cache_path_ = info.timing_cache_path; + } else { + global_cache_path_ = cache_path_; + } + } engine_decryption_enable_ = info.engine_decryption_enable; if (engine_decryption_enable_) { engine_decryption_lib_path_ = info.engine_decryption_lib_path; @@ -928,6 +936,15 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_ENGINE_CACHE_PATH is deprecated! Please use ORT_TENSORRT_CACHE_PATH to specify engine cache path"; } } + if (timing_cache_enable_) { + std::string timing_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTimingCachePath); + // use a more global cache if given + if (!timing_cache_path.empty()) { + global_cache_path_ = timing_cache_path; + } else { + global_cache_path_ = cache_path_; + } + } const std::string engine_decryption_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDecryptionEnable); if (!engine_decryption_enable_env.empty()) { @@ -1019,6 +1036,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv throw std::runtime_error("Failed to create directory " + cache_path_); } } + if (!global_cache_path_.empty() && !fs::is_directory(global_cache_path_)) { + if (!fs::create_directory(global_cache_path_)) { + throw std::runtime_error("Failed to create directory " + global_cache_path_); + } + } { auto lock = GetApiLock(); runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger())); @@ -1104,6 +1126,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv << ", trt_dump_subgraphs: " << dump_subgraphs_ << ", trt_engine_cache_enable: " << engine_cache_enable_ << ", trt_cache_path: " << cache_path_ + << ", trt_global_cache_path: " << global_cache_path_ << ", trt_engine_decryption_enable: " << engine_decryption_enable_ << ", trt_engine_decryption_lib_path: " << engine_decryption_lib_path_ << ", trt_force_sequential_engine_build: " << force_sequential_engine_build_ @@ -2199,7 +2222,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectornode_name], context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, - force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_, + global_cache_path_, force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_, builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics}; *state = p.release(); return 0; @@ -2460,7 +2483,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector runtime_ = nullptr; OrtMutex tensorrt_mu_; int device_id_; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index cb7a568d0913..3ead33f9131d 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -25,7 +25,7 @@ constexpr const char* kDLAEnable = "trt_dla_enable"; constexpr const char* kDLACore = "trt_dla_core"; constexpr const char* kDumpSubgraphs = "trt_dump_subgraphs"; constexpr const char* kEngineCacheEnable = "trt_engine_cache_enable"; -constexpr const char* kCachePath = "trt_engine_cache_path"; +constexpr const char* kEngineCachePath = "trt_engine_cache_path"; constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable"; constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path"; constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build"; @@ -33,7 +33,8 @@ constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine constexpr const char* kContextMemorySharingEnable = "trt_context_memory_sharing_enable"; constexpr const char* kLayerNormFP32Fallback = "trt_layer_norm_fp32_fallback"; constexpr const char* kTimingCacheEnable = "trt_timing_cache_enable"; -constexpr const char* kForceTimingCacheMatch = "trt_force_timing_cache_match"; +constexpr const char* kTimingCachePath = "trt_timing_cache_path"; +constexpr const char* kForceTimingCacheMatch = "trt_force_timing_cache"; constexpr const char* kDetailedBuildLog = "trt_detailed_build_log"; constexpr const char* kBuildHeuristics = "trt_build_heuristics_enable"; constexpr const char* kSparsityEnable = "trt_sparsity_enable"; @@ -76,13 +77,14 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kDLACore, info.dla_core) .AddAssignmentToReference(tensorrt::provider_option_names::kDumpSubgraphs, info.dump_subgraphs) .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCacheEnable, info.engine_cache_enable) - .AddAssignmentToReference(tensorrt::provider_option_names::kCachePath, info.engine_cache_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePath, info.engine_cache_path) .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path) .AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build) .AddAssignmentToReference(tensorrt::provider_option_names::kContextMemorySharingEnable, info.context_memory_sharing_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kLayerNormFP32Fallback, info.layer_norm_fp32_fallback) .AddAssignmentToReference(tensorrt::provider_option_names::kTimingCacheEnable, info.timing_cache_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kTimingCachePath, info.timing_cache_path) .AddAssignmentToReference(tensorrt::provider_option_names::kForceTimingCacheMatch, info.force_timing_cache) .AddAssignmentToReference(tensorrt::provider_option_names::kDetailedBuildLog, info.detailed_build_log) .AddAssignmentToReference(tensorrt::provider_option_names::kBuildHeuristics, info.build_heuristics_enable) @@ -115,7 +117,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.dla_core)}, {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.dump_subgraphs)}, {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.engine_cache_enable)}, - {tensorrt::provider_option_names::kCachePath, MakeStringWithClassicLocale(info.engine_cache_path)}, + {tensorrt::provider_option_names::kEngineCachePath, MakeStringWithClassicLocale(info.engine_cache_path)}, {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)}, {tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)}, {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)}, @@ -123,6 +125,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.context_memory_sharing_enable)}, {tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.layer_norm_fp32_fallback)}, {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.timing_cache_enable)}, + {tensorrt::provider_option_names::kTimingCachePath, MakeStringWithClassicLocale(info.timing_cache_path)}, {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.force_timing_cache)}, {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.detailed_build_log)}, {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.build_heuristics_enable)}, @@ -142,7 +145,8 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensorRTProviderOptionsV2& info) { auto empty_if_null = [](const char* s) { return s != nullptr ? std::string{s} : std::string{}; }; const std::string kInt8CalibTable_ = empty_if_null(info.trt_int8_calibration_table_name); - const std::string kCachePath_ = empty_if_null(info.trt_engine_cache_path); + const std::string kEngineCachePath_ = empty_if_null(info.trt_engine_cache_path); + const std::string kTimingCachePath_ = empty_if_null(info.trt_timing_cache_path); const std::string kTacticSources_ = empty_if_null(info.trt_tactic_sources); const std::string kDecryptionLibPath_ = empty_if_null(info.trt_engine_decryption_lib_path); const std::string kExtraPluginLibPaths_ = empty_if_null(info.trt_extra_plugin_lib_paths); @@ -164,13 +168,14 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor {tensorrt::provider_option_names::kDLACore, MakeStringWithClassicLocale(info.trt_dla_core)}, {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.trt_dump_subgraphs)}, {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.trt_engine_cache_enable)}, - {tensorrt::provider_option_names::kCachePath, kCachePath_}, + {tensorrt::provider_option_names::kEngineCachePath, kEngineCachePath_}, {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.trt_engine_decryption_enable)}, {tensorrt::provider_option_names::kDecryptionLibPath, kDecryptionLibPath_}, {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.trt_force_sequential_engine_build)}, {tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.trt_context_memory_sharing_enable)}, {tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.trt_layer_norm_fp32_fallback)}, {tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.trt_timing_cache_enable)}, + {tensorrt::provider_option_names::kTimingCachePath, kTimingCachePath_}, {tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.trt_force_timing_cache)}, {tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.trt_detailed_build_log)}, {tensorrt::provider_option_names::kBuildHeuristics, MakeStringWithClassicLocale(info.trt_build_heuristics_enable)}, @@ -204,6 +209,27 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options if (provider_options == nullptr) { return; } + auto copy_string_if_needed = [&](std::string& s_in) { + if (string_copy) { + char* dest = nullptr; + auto str_size = s_in.size(); + if (str_size == 0) { + return (const char*)nullptr; + } else { + dest = new char[str_size + 1]; +#ifdef _MSC_VER + strncpy_s(dest, str_size + 1, s_in.c_str(), str_size); +#else + strncpy(dest, s_in.c_str(), str_size); +#endif + dest[str_size] = '\0'; + return (const char*)dest; + } + } else { + return s_in.c_str(); + } + }; + TensorrtExecutionProviderInfo internal_options = onnxruntime::TensorrtExecutionProviderInfo::FromProviderOptions(options); auto& trt_provider_options_v2 = *reinterpret_cast(provider_options); trt_provider_options_v2.device_id = internal_options.device_id; @@ -220,24 +246,7 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_fp16_enable = internal_options.fp16_enable; trt_provider_options_v2.trt_int8_enable = internal_options.int8_enable; - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.int8_calibration_table_name.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_int8_calibration_table_name = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.int8_calibration_table_name.c_str(), str_size); -#else - strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_int8_calibration_table_name = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_int8_calibration_table_name = internal_options.int8_calibration_table_name.c_str(); - } + trt_provider_options_v2.trt_int8_calibration_table_name = copy_string_if_needed(internal_options.int8_calibration_table_name); trt_provider_options_v2.trt_int8_use_native_calibration_table = internal_options.int8_use_native_calibration_table; trt_provider_options_v2.trt_dla_enable = internal_options.dla_enable; @@ -245,45 +254,12 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_dump_subgraphs = internal_options.dump_subgraphs; trt_provider_options_v2.trt_engine_cache_enable = internal_options.engine_cache_enable; - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.engine_cache_path.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_engine_cache_path = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.engine_cache_path.c_str(), str_size); -#else - strncpy(dest, internal_options.engine_cache_path.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_engine_cache_path = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_engine_cache_path = internal_options.engine_cache_path.c_str(); - } + trt_provider_options_v2.trt_engine_cache_path = copy_string_if_needed(internal_options.engine_cache_path); + trt_provider_options_v2.trt_timing_cache_path = copy_string_if_needed(internal_options.timing_cache_path); trt_provider_options_v2.trt_engine_decryption_enable = internal_options.engine_decryption_enable; - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.engine_decryption_lib_path.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_engine_decryption_lib_path = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.engine_decryption_lib_path.c_str(), str_size); -#else - strncpy(dest, internal_options.engine_decryption_lib_path.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_engine_decryption_lib_path = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_engine_decryption_lib_path = internal_options.engine_decryption_lib_path.c_str(); - } + trt_provider_options_v2.trt_engine_decryption_lib_path = copy_string_if_needed(internal_options.engine_decryption_lib_path); trt_provider_options_v2.trt_force_sequential_engine_build = internal_options.force_sequential_engine_build; trt_provider_options_v2.trt_context_memory_sharing_enable = internal_options.context_memory_sharing_enable; @@ -296,100 +272,11 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_builder_optimization_level = internal_options.builder_optimization_level; trt_provider_options_v2.trt_auxiliary_streams = internal_options.auxiliary_streams; - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.tactic_sources.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_tactic_sources = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.tactic_sources.c_str(), str_size); -#else - strncpy(dest, internal_options.tactic_sources.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_tactic_sources = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_tactic_sources = internal_options.tactic_sources.c_str(); - } - - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.extra_plugin_lib_paths.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_extra_plugin_lib_paths = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.extra_plugin_lib_paths.c_str(), str_size); -#else - strncpy(dest, internal_options.extra_plugin_lib_paths.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_extra_plugin_lib_paths = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_extra_plugin_lib_paths = internal_options.extra_plugin_lib_paths.c_str(); - } - - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.profile_min_shapes.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_profile_min_shapes = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.profile_min_shapes.c_str(), str_size); -#else - strncpy(dest, internal_options.profile_min_shapes.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_profile_min_shapes = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_profile_min_shapes = internal_options.profile_min_shapes.c_str(); - } - - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.profile_max_shapes.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_profile_max_shapes = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.profile_max_shapes.c_str(), str_size); -#else - strncpy(dest, internal_options.profile_max_shapes.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_profile_max_shapes = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_profile_max_shapes = internal_options.profile_max_shapes.c_str(); - } - - if (string_copy) { - char* dest = nullptr; - auto str_size = internal_options.profile_opt_shapes.size(); - if (str_size == 0) { - trt_provider_options_v2.trt_profile_opt_shapes = nullptr; - } else { - dest = new char[str_size + 1]; -#ifdef _MSC_VER - strncpy_s(dest, str_size + 1, internal_options.profile_opt_shapes.c_str(), str_size); -#else - strncpy(dest, internal_options.profile_opt_shapes.c_str(), str_size); -#endif - dest[str_size] = '\0'; - trt_provider_options_v2.trt_profile_opt_shapes = (const char*)dest; - } - } else { - trt_provider_options_v2.trt_profile_opt_shapes = internal_options.profile_opt_shapes.c_str(); - } + trt_provider_options_v2.trt_tactic_sources = copy_string_if_needed(internal_options.tactic_sources); + trt_provider_options_v2.trt_extra_plugin_lib_paths = copy_string_if_needed(internal_options.extra_plugin_lib_paths); + trt_provider_options_v2.trt_profile_min_shapes = copy_string_if_needed(internal_options.profile_min_shapes); + trt_provider_options_v2.trt_profile_max_shapes = copy_string_if_needed(internal_options.profile_max_shapes); + trt_provider_options_v2.trt_profile_opt_shapes = copy_string_if_needed(internal_options.profile_opt_shapes); trt_provider_options_v2.trt_cuda_graph_enable = internal_options.cuda_graph_enable; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index 61a6bf08211b..b16543aa3d7d 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -38,6 +38,7 @@ struct TensorrtExecutionProviderInfo { bool context_memory_sharing_enable{false}; bool layer_norm_fp32_fallback{false}; bool timing_cache_enable{false}; + std::string timing_cache_path{""}; bool force_timing_cache{false}; bool detailed_build_log{false}; bool build_heuristics_enable{false}; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index d7e13df00027..426584553f34 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -103,6 +103,7 @@ struct Tensorrt_Provider : Provider { info.context_memory_sharing_enable = options.trt_context_memory_sharing_enable != 0; info.layer_norm_fp32_fallback = options.trt_layer_norm_fp32_fallback != 0; info.timing_cache_enable = options.trt_timing_cache_enable != 0; + info.timing_cache_path = options.trt_timing_cache_path == nullptr ? "" : options.trt_timing_cache_path; info.force_timing_cache = options.trt_force_timing_cache != 0; info.detailed_build_log = options.trt_detailed_build_log != 0; info.build_heuristics_enable = options.trt_build_heuristics_enable != 0; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index d307f79c372e..9e5988347822 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1931,6 +1931,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor if (ptr != nullptr) { delete[] ptr->trt_int8_calibration_table_name; delete[] ptr->trt_engine_cache_path; + delete[] ptr->trt_timing_cache_path; delete[] ptr->trt_engine_decryption_lib_path; delete[] ptr->trt_tactic_sources; delete[] ptr->trt_extra_plugin_lib_paths; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 90271b545839..7faca3b4681b 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -479,7 +479,7 @@ std::unique_ptr CreateExecutionProviderInstance( // So we need these std::string variables defined here as they will be kept alive for the lifetime of TRT EP and we can still access them from OrtTensorRTProviderOptionsV2 instance. // (The reason is string copy is involved, for example params.trt_engine_cache_path = cache_path.c_str() and those std::string variable is referenced by OrtTensorRTProviderOptionsV2 instance // and TRT EP instance, so it won't be released.) - std::string calibration_table, cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile; + std::string calibration_table, cache_path, timing_cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { OrtTensorRTProviderOptionsV2 params; @@ -623,6 +623,13 @@ std::unique_ptr CreateExecutionProviderInstance( } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_timing_cache_enable' should be 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "trt_timing_cache_path") { + if (!option.second.empty()) { + timing_cache_path = option.second; + params.trt_timing_cache_path = timing_cache_path.c_str(); + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_timing_cache_path' should be a path string i.e. 'cache_folder/'.\n"); + } } else if (option.first == "trt_force_timing_cache") { if (option.second == "True" || option.second == "true") { params.trt_force_timing_cache = true; diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index a7f0b7584a21..e828a7cee5ea 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include "core/session/onnxruntime_session_options_config_keys.h" @@ -100,36 +101,28 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device const auto& api = Ort::GetApi(); OrtCUDAProviderOptionsV2* cuda_options; Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options)); - - const char* cudnn_conv_algo_search = "cudnn_conv_algo_search"; - const char* default_conv = "DEFAULT"; - const char* benchmarking = "EXHAUSTIVE"; - const char* heuristic = "HEURISTIC"; + std::vector option_keys, option_values; + // used to keep all option keys and value strings alive + std::list buffer; + buffer.emplace_back("cudnn_conv_algo_search"); + option_keys.push_back(buffer.back().c_str()); switch (performance_test_config.run_config.cudnn_conv_algo) { case 0: - Ort::ThrowOnError( - api.UpdateCUDAProviderOptions(cuda_options, &cudnn_conv_algo_search, &benchmarking, 1)); + buffer.emplace_back("EXHAUSTIVE"); break; case 1: - Ort::ThrowOnError( - api.UpdateCUDAProviderOptions(cuda_options, &cudnn_conv_algo_search, &heuristic, 1)); + buffer.emplace_back("HEURISTIC"); break; default: - Ort::ThrowOnError( - api.UpdateCUDAProviderOptions(cuda_options, &cudnn_conv_algo_search, &default_conv, 1)); + buffer.emplace_back("DEFAULT"); break; } + option_values.push_back(buffer.back().c_str()); - const char* do_copy_in_default_stream = "do_copy_in_default_stream"; - if (performance_test_config.run_config.do_cuda_copy_in_separate_stream) { - const char* v = "1"; - Ort::ThrowOnError( - api.UpdateCUDAProviderOptions(cuda_options, &do_copy_in_default_stream, &v, 1)); - } else { - const char* v = "0"; - Ort::ThrowOnError( - api.UpdateCUDAProviderOptions(cuda_options, &do_copy_in_default_stream, &v, 1)); - } + buffer.emplace_back("do_copy_in_default_stream"); + option_keys.push_back(buffer.back().c_str()); + buffer.emplace_back(performance_test_config.run_config.do_cuda_copy_in_separate_stream ? "1" : "0"); + option_values.push_back(buffer.back().c_str()); #ifdef _MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -148,51 +141,34 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device "[ERROR] [CUDA] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); } - auto key = token.substr(0, pos); - auto value = token.substr(pos + 1); - auto key_p = key.c_str(); - auto value_p = value.c_str(); - Ort::ThrowOnError( - api.UpdateCUDAProviderOptions(cuda_options, &key_p, &value_p, 1)); + buffer.emplace_back(token.substr(0, pos)); + option_keys.push_back(buffer.back().c_str()); + buffer.emplace_back(token.substr(pos + 1)); + option_values.push_back(buffer.back().c_str()); } + Ort::Status status(api.UpdateCUDAProviderOptions(cuda_options, + option_keys.data(), option_values.data(), option_keys.size())); + if (!status.IsOK()) { + OrtAllocator* allocator; + char* options; + Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); + Ort::ThrowOnError(api.GetCUDAProviderOptionsAsString(cuda_options, allocator, &options)); + ORT_THROW("[ERROR] [CUDA] Configuring the CUDA options failed with message: ", status.GetErrorMessage(), + "\nSupported options are:\n", options); + } session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); #else ORT_THROW("CUDA is not supported in this build\n"); #endif } else if (provider_name == onnxruntime::kTensorrtExecutionProvider) { #ifdef USE_TENSORRT - int device_id = 0; - int trt_max_partition_iterations = 1000; - int trt_min_subgraph_size = 1; - size_t trt_max_workspace_size = 1 << 30; - bool trt_fp16_enable = false; - bool trt_int8_enable = false; - std::string trt_int8_calibration_table_name = ""; - bool trt_int8_use_native_calibration_table = false; - bool trt_dla_enable = false; - int trt_dla_core = 0; - bool trt_dump_subgraphs = false; - bool trt_engine_cache_enable = false; - std::string trt_engine_cache_path = ""; - bool trt_engine_decryption_enable = false; - std::string trt_engine_decryption_lib_path = ""; - bool trt_force_sequential_engine_build = false; - bool trt_context_memory_sharing_enable = false; - bool trt_layer_norm_fp32_fallback = false; - bool trt_timing_cache_enable = false; - bool trt_force_timing_cache = false; - bool trt_detailed_build_log = false; - bool trt_build_heuristics_enable = false; - bool trt_sparsity_enable = false; - int trt_builder_optimization_level = 3; - int trt_auxiliary_streams = -1; - std::string trt_tactic_sources = ""; - std::string trt_extra_plugin_lib_paths = ""; - std::string trt_profile_min_shapes = ""; - std::string trt_profile_max_shapes = ""; - std::string trt_profile_opt_shapes = ""; - bool trt_cuda_graph_enable = false; + const auto& api = Ort::GetApi(); + OrtTensorRTProviderOptionsV2* tensorrt_options; + Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options)); + std::vector option_keys, option_values; + // used to keep all option keys and value strings alive + std::list buffer; #ifdef _MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -207,272 +183,31 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } auto pos = token.find("|"); if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW("[ERROR] [TensorRT] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); + ORT_THROW( + "[ERROR] [TensorRT] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); } - auto key = token.substr(0, pos); - auto value = token.substr(pos + 1); - if (key == "device_id") { - if (!value.empty()) { - device_id = std::stoi(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'device_id' should be a number.\n"); - } - } else if (key == "trt_max_partition_iterations") { - if (!value.empty()) { - trt_max_partition_iterations = std::stoi(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_max_partition_iterations' should be a number.\n"); - } - } else if (key == "trt_min_subgraph_size") { - if (!value.empty()) { - trt_min_subgraph_size = std::stoi(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_min_subgraph_size' should be a number.\n"); - } - } else if (key == "trt_max_workspace_size") { - if (!value.empty()) { - trt_max_workspace_size = std::stoull(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_max_workspace_size' should be a number.\n"); - } - } else if (key == "trt_fp16_enable") { - if (value == "true" || value == "True") { - trt_fp16_enable = true; - } else if (value == "false" || value == "False") { - trt_fp16_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_fp16_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_int8_enable") { - if (value == "true" || value == "True") { - trt_int8_enable = true; - } else if (value == "false" || value == "False") { - trt_int8_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_int8_calibration_table_name") { - if (!value.empty()) { - trt_int8_calibration_table_name = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_calibration_table_name' should be a non-empty string.\n"); - } - } else if (key == "trt_int8_use_native_calibration_table") { - if (value == "true" || value == "True") { - trt_int8_use_native_calibration_table = true; - } else if (value == "false" || value == "False") { - trt_int8_use_native_calibration_table = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_int8_use_native_calibration_table' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_dla_enable") { - if (value == "true" || value == "True") { - trt_dla_enable = true; - } else if (value == "false" || value == "False") { - trt_dla_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dla_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_dla_core") { - if (!value.empty()) { - trt_dla_core = std::stoi(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dla_core' should be a number.\n"); - } - } else if (key == "trt_dump_subgraphs") { - if (value == "true" || value == "True") { - trt_dump_subgraphs = true; - } else if (value == "false" || value == "False") { - trt_dump_subgraphs = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dump_subgraphs' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_engine_cache_enable") { - if (value == "true" || value == "True") { - trt_engine_cache_enable = true; - } else if (value == "false" || value == "False") { - trt_engine_cache_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_engine_cache_path") { - if (!value.empty()) { - trt_engine_cache_path = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_path' should be a non-empty string.\n"); - } - } else if (key == "trt_engine_decryption_enable") { - if (value == "true" || value == "True") { - trt_engine_decryption_enable = true; - } else if (value == "false" || value == "False") { - trt_engine_decryption_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_decryption_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_engine_decryption_lib_path") { - if (!value.empty()) { - trt_engine_decryption_lib_path = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_decryption_lib_path' should be a non-empty string.\n"); - } - } else if (key == "trt_force_sequential_engine_build") { - if (value == "true" || value == "True") { - trt_force_sequential_engine_build = true; - } else if (value == "false" || value == "False") { - trt_force_sequential_engine_build = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_force_sequential_engine_build' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_context_memory_sharing_enable") { - if (value == "true" || value == "True") { - trt_context_memory_sharing_enable = true; - } else if (value == "false" || value == "False") { - trt_context_memory_sharing_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_context_memory_sharing_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_layer_norm_fp32_fallback") { - if (value == "true" || value == "True") { - trt_layer_norm_fp32_fallback = true; - } else if (value == "false" || value == "False") { - trt_layer_norm_fp32_fallback = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_layer_norm_fp32_fallback' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_timing_cache_enable") { - if (value == "true" || value == "True") { - trt_timing_cache_enable = true; - } else if (value == "false" || value == "False") { - trt_timing_cache_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_timing_cache_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_force_timing_cache") { - if (value == "true" || value == "True") { - trt_force_timing_cache = true; - } else if (value == "false" || value == "False") { - trt_force_timing_cache = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_force_timing_cache' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_detailed_build_log") { - if (value == "true" || value == "True") { - trt_detailed_build_log = true; - } else if (value == "false" || value == "False") { - trt_detailed_build_log = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_detailed_build_log' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_build_heuristics_enable") { - if (value == "true" || value == "True") { - trt_build_heuristics_enable = true; - } else if (value == "false" || value == "False") { - trt_build_heuristics_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_build_heuristics_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_sparsity_enable") { - if (value == "true" || value == "True") { - trt_sparsity_enable = true; - } else if (value == "false" || value == "False") { - trt_sparsity_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_sparsity_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else if (key == "trt_builder_optimization_level") { - if (!value.empty()) { - trt_builder_optimization_level = std::stoi(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_builder_optimization_level' should be a number and default to 2.\n"); - } - } else if (key == "trt_auxiliary_streams") { - if (!value.empty()) { - trt_auxiliary_streams = std::stoi(value); - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_auxiliary_streams' should be a number.\n"); - } - } else if (key == "trt_tactic_sources") { - if (!value.empty()) { - trt_tactic_sources = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_tactic_sources' should be a non-empty string.\n"); - } - } else if (key == "trt_extra_plugin_lib_paths") { - if (!value.empty()) { - trt_extra_plugin_lib_paths = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_extra_plugin_lib_paths' should be a non-empty string.\n"); - } - } else if (key == "trt_profile_min_shapes") { - if (!value.empty()) { - trt_profile_min_shapes = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_min_shapes' should be a non-empty string.\n"); - } - } else if (key == "trt_profile_max_shapes") { - if (!value.empty()) { - trt_profile_max_shapes = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_max_shapes' should be a non-empty string.\n"); - } - } else if (key == "trt_profile_opt_shapes") { - if (!value.empty()) { - trt_profile_opt_shapes = value; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_opt_shapes' should be a non-empty string.\n"); - } - } else if (key == "trt_cuda_graph_enable") { - if (value == "true" || value == "True") { - trt_cuda_graph_enable = true; - } else if (value == "false" || value == "False") { - trt_cuda_graph_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be a boolean i.e. true or false. Default value is false.\n"); - } - } else { - ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback', 'trt_timing_cache_enable', 'trt_force_timing_cache', 'trt_detailed_build_log', 'trt_build_heuristics_enable', 'trt_sparsity_enable', 'trt_builder_optimization_level', 'trt_auxiliary_streams', 'trt_tactic_sources', 'trt_extra_plugin_lib_paths', 'trt_profile_min_shapes', 'trt_profile_max_shapes', 'trt_profile_opt_shapes', 'trt_cuda_graph_enable'] \n"); - } + buffer.emplace_back(token.substr(0, pos)); + option_keys.push_back(buffer.back().c_str()); + buffer.emplace_back(token.substr(pos + 1)); + option_values.push_back(buffer.back().c_str()); + } + + Ort::Status status(api.UpdateTensorRTProviderOptions(tensorrt_options, + option_keys.data(), option_values.data(), option_keys.size())); + if (!status.IsOK()) { + OrtAllocator* allocator; + char* options; + Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); + Ort::ThrowOnError(api.GetTensorRTProviderOptionsAsString(tensorrt_options, allocator, &options)); + ORT_THROW("[ERROR] [TensorRT] Configuring the CUDA options failed with message: ", status.GetErrorMessage(), + "\nSupported options are:\n", options); } - OrtTensorRTProviderOptionsV2 tensorrt_options; - tensorrt_options.device_id = device_id; - tensorrt_options.has_user_compute_stream = 0; - tensorrt_options.user_compute_stream = nullptr; - tensorrt_options.trt_max_partition_iterations = trt_max_partition_iterations; - tensorrt_options.trt_min_subgraph_size = trt_min_subgraph_size; - tensorrt_options.trt_max_workspace_size = trt_max_workspace_size; - tensorrt_options.trt_fp16_enable = trt_fp16_enable; - tensorrt_options.trt_int8_enable = trt_int8_enable; - tensorrt_options.trt_int8_calibration_table_name = trt_int8_calibration_table_name.c_str(); - tensorrt_options.trt_int8_use_native_calibration_table = trt_int8_use_native_calibration_table; - tensorrt_options.trt_dla_enable = trt_dla_enable; - tensorrt_options.trt_dla_core = trt_dla_core; - tensorrt_options.trt_dump_subgraphs = trt_dump_subgraphs; - tensorrt_options.trt_engine_cache_enable = trt_engine_cache_enable; - tensorrt_options.trt_engine_cache_path = trt_engine_cache_path.c_str(); - tensorrt_options.trt_engine_decryption_enable = trt_engine_decryption_enable; - tensorrt_options.trt_engine_decryption_lib_path = trt_engine_decryption_lib_path.c_str(); - tensorrt_options.trt_force_sequential_engine_build = trt_force_sequential_engine_build; - tensorrt_options.trt_context_memory_sharing_enable = trt_context_memory_sharing_enable; - tensorrt_options.trt_layer_norm_fp32_fallback = trt_layer_norm_fp32_fallback; - tensorrt_options.trt_timing_cache_enable = trt_timing_cache_enable; - tensorrt_options.trt_force_timing_cache = trt_force_timing_cache; - tensorrt_options.trt_detailed_build_log = trt_detailed_build_log; - tensorrt_options.trt_build_heuristics_enable = trt_build_heuristics_enable; - tensorrt_options.trt_sparsity_enable = trt_sparsity_enable; - tensorrt_options.trt_builder_optimization_level = trt_builder_optimization_level; - tensorrt_options.trt_auxiliary_streams = trt_auxiliary_streams; - tensorrt_options.trt_tactic_sources = trt_tactic_sources.c_str(); - tensorrt_options.trt_extra_plugin_lib_paths = trt_extra_plugin_lib_paths.c_str(); - tensorrt_options.trt_profile_min_shapes = trt_profile_min_shapes.c_str(); - tensorrt_options.trt_profile_max_shapes = trt_profile_max_shapes.c_str(); - tensorrt_options.trt_profile_opt_shapes = trt_profile_opt_shapes.c_str(); - tensorrt_options.trt_cuda_graph_enable = trt_cuda_graph_enable; - - session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options); + + session_options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options); OrtCUDAProviderOptions cuda_options; - cuda_options.device_id = device_id; + cuda_options.device_id = tensorrt_options->device_id; cuda_options.cudnn_conv_algo_search = static_cast(performance_test_config.run_config.cudnn_conv_algo); cuda_options.do_copy_in_default_stream = !performance_test_config.run_config.do_cuda_copy_in_separate_stream; // TODO: Support arena configuration for users of perf test diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index aa96e1533653..d9f917f6d187 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -590,6 +590,7 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { // uint64_t compilation_without_cache_ms, compilation_with_cache_ms; // First session is created with TRT EP with timing cache enabled + // Not specifying a trt_timing_cache_path will result in using the working directory params.trt_timing_cache_enable = 1; { // auto start = chrono::steady_clock::now(); From 28ad3ff799e163bf5c854db03e87086a61e14256 Mon Sep 17 00:00:00 2001 From: sophies927 <107952697+sophies927@users.noreply.github.com> Date: Fri, 27 Oct 2023 10:57:28 -0700 Subject: [PATCH 30/36] Fix stale bot issue (#18064) ### Description Previously used GitHub stale app is now deprecated, so I deleted that file and added a new GitHub Actions workflow to automatically apply the stale label to inactive issues. ### Motivation and Context --- .github/stale.yml | 22 ---------------------- .github/workflows/stale.yml | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 22 deletions(-) delete mode 100644 .github/stale.yml create mode 100644 .github/workflows/stale.yml diff --git a/.github/stale.yml b/.github/stale.yml deleted file mode 100644 index d89f0cdd91e5..000000000000 --- a/.github/stale.yml +++ /dev/null @@ -1,22 +0,0 @@ -# Number of days of inactivity before an issue becomes stale -daysUntilStale: 30 - -# Number of days of inactivity before a stale issue is closed -daysUntilClose: 7 - -# Issues with these labels will never be considered stale -exemptLabels: - - contributions welcome - - feature request - - regression - -# Label to use when marking an issue as stale -staleLabel: stale - -# Comment to post when marking an issue as stale. Set to `false` to disable -markComment: > - This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details. - -# Comment to post when closing a stale issue. Set to `false` to disable -closeComment: > - This issue has been automatically closed due to inactivity. Please reactivate if further support is needed. diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 000000000000..67d8550d4420 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,34 @@ +name: Close stale issues +on: + # Allows you to dictate when you want this workflow to run using cron syntax (times in UTC) + schedule: + - cron: "0 15 * * *" + # Allows you to run this workflow manually from the Actions tab + # workflow_dispatch: + +jobs: + close-stale-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v4.1.1 + with: + # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale + exempt-issue-labels: contributions welcome, feature request, regression + # Number of days without activity before the actions/stale action labels an issue + days-before-issue-stale: 30 + # Number of days without activity before the actions/stale action closes an issue + days-before-issue-close: 7 + # Label you want to apply to issues that have been inactive for the amount of time specified by days-before-issue-stale + stale-issue-label: "stale" + # Comment that you want to add to issues that are labeled by the actions/stale action + stale-issue-message: "This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details." + # Comment that you want to add to issues that are closed by the actions/stale action + close-issue-message: "This issue has been automatically closed due to inactivity. Please reactivate if further support is needed." + # If you never want this action to label PRs, set this value to -1 + days-before-pr-stale: -1 + # If you never want this action to close PRs, set this value to -1 + days-before-pr-close: -1 + repo-token: ${{ secrets.GITHUB_TOKEN }} From d9695dea6df183f4290d6867f8fae0519f542d83 Mon Sep 17 00:00:00 2001 From: zesongw Date: Sat, 28 Oct 2023 04:57:01 +0800 Subject: [PATCH 31/36] [WebNN EP] Remove Conv initializer constraint for GPU (#18129) ### Description WebNN can now handle Conv with filter as input . ### Motivation and Context Support more models with WebNN. --- .../webnn/builders/impl/conv_op_builder.cc | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 1e0af51567ca..af3293dd3d92 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -49,9 +49,8 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, NodeAttrHelper helper(node); const auto group = helper.Get("group", static_cast(1)); const auto& input_defs = node.InputDefs(); - const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - const auto& weight_shape = weight_tensor.dims(); - + std::vector weight_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[1], weight_shape, logger), "Cannot get weight shape"); options.set("strides", emscripten::val::array(strides)); options.set("dilations", emscripten::val::array(dilations)); options.set("groups", group); @@ -278,25 +277,28 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType /* device_type */, + const WebnnDeviceType device_type, const logging::Logger& logger) const { const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); const auto& weight_name = input_defs[1]->Name(); - if (Contains(initializers, weight_name)) { - const auto& tensor = *initializers.at(weight_name); - if (tensor.dims().size() != 4) { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] dimension: " << tensor.dims().size() - << " Only conv 2d is supported."; + // WebNN CPU backend (XNNPACK) requires the filter operand to be a constant. + // https://github.com/google/XNNPACK/blob/master/src/subgraph/convolution-2d.c#L739 + if (device_type == WebnnDeviceType::CPU) { + if (Contains(initializers, weight_name)) { + const auto& tensor = *initializers.at(weight_name); + if (tensor.dims().size() != 4) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "] dimension: " << tensor.dims().size() + << " Only conv 2d is supported."; + return false; + } + } else { + LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; return false; } - } else { - LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; - return false; } - return true; } From 8daabf3f15329e58d28557c9fdfd1254836b6573 Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Fri, 27 Oct 2023 16:09:07 -0700 Subject: [PATCH 32/36] Tune min version supporint custom op ComputeV2 (#18134) Set min version supporting custom_op::ComputeV2 to 16, since the feature has been released since ort 1.16. Co-authored-by: Randy Shuai --- onnxruntime/core/session/custom_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 5f1d5036e831..041250adc3fc 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -28,7 +28,7 @@ static constexpr uint32_t min_ort_version_with_variadic_io_support = 14; #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -static constexpr uint32_t min_ort_version_with_compute_v2_support = 17; +static constexpr uint32_t min_ort_version_with_compute_v2_support = 16; static constexpr uint32_t min_ort_version_with_shape_inference = 17; #endif From 24f9c1afe3a1972cfacd9b4bcd41f2227d369bf0 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Sat, 28 Oct 2023 00:44:02 -0700 Subject: [PATCH 33/36] Distributed Expand (#18126) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements DistributedExpand for llama 2. Representative Examples of DistributedExpand: - [shard on non-expanded axis] `input tensor (shape=[8, 1], spec=S[0]R, device_mesh=[0,1]) -> Expand(target_shape=[8, 2] -> output tensor (shape=[8, 2], spec=S[0]R, device_mesh=[0,1])` - [sharding expanded axis is invalid since it must have dim=1 and axis with dim=1 cannot be sharded] `input tensor (shape=[1, 8], spec=S[0]R, device_mesh=[0,1]) -> Expand(target_shape=[2, 8] -> output tensor (shape=[2, 8], spec=S[0]R, device_mesh=[0,1])` From those examples, we observe a few important behaviors. - The output sharding spec is always the same to the input sharding spec. - Expanding always happen on axis with dimension=1. Otherwise, it will violate the broadcasting rule. - No communication is needed since all computation can happen locally. Let's consider the first example again. If you put the first half tensor (shape: [4, 1]) on device 0 and the second half (shape: [4, 1]) on device 1, then `Expand` it with target shape [4, 2] , these two local tensors (shape: [4, 2]) are exactly the same as the one described by output sharding spec. Algorithm: - Compute logical (i.e., unsharded) shapes of input and output. - Compute sharded output shape from logical output. - Call Expand to broadcast local input to sharded output shape. How to review? - Start with [changes in onnxruntime_test_distributed.py](https://github.com/microsoft/onnxruntime/pull/18126/commits/ea33392f375afd8e95d29bd5b1a403192ed3bebc). Those tests are good examples for using this op. - [Read expand.h/expand.cc](https://github.com/microsoft/onnxruntime/pull/18126/commits/e4c49987f5a09e19527248adcc197b7d4a695636). Theose changes are for exposing functionalities in Expand to DistributedExpand. - Read distributed_expand.h/distributed_expand.cc. It follows the algorithm described above. The commit https://github.com/microsoft/onnxruntime/pull/18126/commits/68ac301bbaff44d08168ac9049161a4d428b3c3d first sketches the definition of DistributedExpand. The next commit https://github.com/microsoft/onnxruntime/pull/18126/commits/0eb9330c3ba836911932444caca7fec0cbdad222 adds real implementation. --- cmake/onnxruntime_providers_cuda.cmake | 1 + cmake/onnxruntime_rocm_hipify.cmake | 1 + .../cuda/collective/distributed_expand.cc | 110 +++++++++++ .../cuda/collective/distributed_expand.h | 35 ++++ .../contrib_ops/cuda/cuda_contrib_kernels.cc | 8 + .../core/graph/contrib_ops/collective_defs.cc | 37 ++++ .../core/providers/cuda/tensor/expand.cc | 80 ++++++++ .../core/providers/cuda/tensor/expand.h | 13 ++ .../python/onnxruntime_test_distributed.py | 185 ++++++++++++++++++ 9 files changed, 470 insertions(+) create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_expand.h diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 02b17ee324f4..043789c36c32 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -39,6 +39,7 @@ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index ec021a1550d6..6ccf063c7129 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -108,6 +108,7 @@ if (NOT onnxruntime_USE_NCCL) list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc") endif() set(provider_excluded_files diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc new file mode 100644 index 000000000000..3cfa3ab95934 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_expand.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/tensor/expand.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +DistributedExpand::DistributedExpand(const OpKernelInfo& info) : DistributedKernel(info) {} + +template +Status DistributedExpand::ComputeInternal(OpKernelContext* context) const { + ORT_ENFORCE(context != nullptr); + // Assumptions. + // - Shape is not sharded. + // Algorithm. + // - Compute logical output shape. + // - Compute local output shape. + // - Expand from local input to local output. + + auto input_tensor = context->Input(0); + auto shape_tensor = context->Input(1); + const auto& input_sharding_spec = input_shard_specs_.at(0); + const auto& shape_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + ORT_ENFORCE(shape_sharding_spec.HasNoShard(), + "It's not worth to shard Shape tensor. " + "If sharding shape is needed, please submit a feature request."); + // Compute logical input shape. + const auto original_input_shape = ComputeOriginShape(input_tensor->Shape(), input_sharding_spec); + + // Compute logical output shape. + // This `shape_tensor` stores the logical output shape. + const auto* p_shape = shape_tensor->Data(); + TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()}; + TensorShape original_output_shape(original_output_dims); + ORT_ENFORCE( + onnxruntime::cuda::ComputeOutputShape( + Node().Name(), + original_input_shape, + original_output_dims, original_output_shape) + .IsOK()); + + // Compute local output shape. + const auto local_output_shape = ComputeShardShape(original_output_shape, output_sharding_spec); + + auto output_tensor = context->Output(0, local_output_shape); + + return FuncExpand( + this, + context, + input_tensor, + shape_tensor, + output_tensor); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h new file mode 100644 index 000000000000..dedb1bdc5aa3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedExpand final : public DistributedKernel { + public: + explicit DistributedExpand(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index e6a216795c10..2618fe4a238b 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -170,6 +170,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand); #endif template <> @@ -344,6 +348,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 8082b8c010e9..070df487a264 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -236,6 +236,43 @@ void RegisterCollectiveOps() { OpSchema::NonDifferentiable) .Output(0, "reshaped", "Reshaped data.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedExpand) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); } } // namespace contrib diff --git a/onnxruntime/core/providers/cuda/tensor/expand.cc b/onnxruntime/core/providers/cuda/tensor/expand.cc index e9634df20584..806ecfa1aab1 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand.cc +++ b/onnxruntime/core/providers/cuda/tensor/expand.cc @@ -142,6 +142,86 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const { input_strides); } +Status FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* /*input_shape_tensor*/, + Tensor* output_tensor) { + TensorShape output_shape = output_tensor->Shape(); + +#ifdef ENABLE_STRIDED_TENSORS + // Strided output. + if (input_data_tensor->DataRaw() == output_tensor->DataRaw()) { + gsl::span input_strides = input_data_tensor->Strides(); + TensorShapeVector output_strides = + ComputeOutputStrides(input_data_tensor->Shape(), input_strides, output_shape); + output_tensor->SetShapeAndStrides(output_shape, output_strides); + return Status::OK(); + } +#endif + + auto output_dims = output_shape.AsShapeVector(); + auto input_dims = input_data_tensor->Shape().AsShapeVector(); + + CalcEffectiveDims(input_dims, output_dims); + int rank = gsl::narrow_cast(output_dims.size()); + + TensorPitches original_input_strides(input_dims); + TensorPitches original_output_strides(output_dims); + + TArray input_strides(rank); + for (auto i = 0; i < rank; i++) { + input_strides[i] = input_dims[i] == 1 ? 0 : original_input_strides[i]; + } + + TArray output_strides(rank); + for (auto i = 0; i < rank; i++) { + output_strides[i] = fast_divmod(static_cast(original_output_strides[i])); + } + + return ExpandImpl( + cuda_kernel->Stream(ctx), + input_data_tensor->DataType()->Size(), + gsl::narrow_cast(output_shape.Size()), + gsl::narrow_cast(input_data_tensor->Shape().Size()), + input_data_tensor->DataRaw(), + output_tensor->MutableDataRaw(), + output_strides, + input_strides); +} + +std::unique_ptr FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* input_shape_tensor) { + // new shape to be expanded to + const auto* p_shape = input_shape_tensor->Data(); + TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()}; + TensorShape output_shape(output_dims); + + ORT_ENFORCE( + ComputeOutputShape( + cuda_kernel->Node().Name(), + input_data_tensor->Shape(), + output_dims, output_shape) + .IsOK()); + + // Pre-allocate output. + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK()); + auto output_tensor = Tensor::Create(input_data_tensor->DataType(), output_shape, alloc); + + // Only assign output values when output tensor is non-empty + // because empty tensor doesn't own any data. + if (output_shape.Size() > 0) { + ORT_ENFORCE(FuncExpand(cuda_kernel, ctx, input_data_tensor, input_shape_tensor, output_tensor.get()).IsOK()); + } + + return output_tensor; +} + #ifdef ENABLE_STRIDED_TENSORS #define CREATE_EXPAND_KERNEL_DEF (*KernelDefBuilder::Create()).MayStridedOutput(0, 0) #else diff --git a/onnxruntime/core/providers/cuda/tensor/expand.h b/onnxruntime/core/providers/cuda/tensor/expand.h index 4cf4c14e6105..a0b12790017f 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand.h +++ b/onnxruntime/core/providers/cuda/tensor/expand.h @@ -20,5 +20,18 @@ Status ComputeOutputShape( const TensorShape& rhs_shape, TensorShape& out_shape); +Status FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* /*input_shape_tensor*/, + Tensor* output_tensor); + +std::unique_ptr FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* input_shape_tensor); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index 2acca4a8f22a..e0fb3979a9f5 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -685,6 +685,191 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrs_01_shape_21_4906_rs_01(self) ) +class TestDistributedExpand(unittest.TestCase): + def _check_distributed_expand( + self, + shape: Tuple[int, ...], + target_shape: Tuple[int, ...], + input_device_meshs: np.ndarray, + input_shard_specs: Tuple[str, ...], + output_device_meshs: np.ndarray, + output_shard_specs: Tuple[str, ...], + ): + assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) + assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) + assert len(input_device_meshs) == len(input_shard_specs) + assert len(output_device_meshs) == len(output_shard_specs) + + input_device_mesh_shapes = [] + input_device_mesh_elements = [] + for device_mesh in input_device_meshs: + device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) + input_device_mesh_shapes.append(device_mesh_shape) + input_device_mesh_elements.append(device_mesh_element) + + output_device_mesh_shapes = [] + output_device_mesh_elements = [] + for device_mesh in output_device_meshs: + device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) + output_device_mesh_shapes.append(device_mesh_shape) + output_device_mesh_elements.append(device_mesh_element) + + @onnxscript.script() + def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64): + return MICROSOFT_OPSET.DistributedExpand( + data_tensor, + shape_tensor, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + rank = comm.Get_rank() + data_tensor = np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) + shape_tensor = np.array( + target_shape, + dtype=np.int64, + ) + + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + assert "S" not in input_shard_specs[1], "Shape should not be sharded." + + expected = data_tensor * np.ones(shape_tensor) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + + onnx_model = distributed_expand_instance.to_model_proto( + input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], + output_types=[FLOAT[tuple(local_expected.shape)]], + ) + + # Each MPI process owns a sharded model. + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + # Each MPI process executes its sharded model. + # The result is `local` tensor stored on a specific MPI rank + # instead of `logical` tensor. + result = sess.run( + None, + { + "data_tensor": local_data_tensor, + "shape_tensor": shape_tensor, + }, + ) + + # Compare local tensor and the corresponding logical sub-tensor + # obtained by sharding logical tensor following output's sharding spec. + np.testing.assert_allclose(result[0], local_expected, rtol=1e-5, atol=1e-8) + + def test_expand_sharded_on_expanded_axis(self): + # data: shape=[8,1], spec=(RR, [0,1]) + # shape: shape=[2], spec=(R, [0,1]), value=[1,4] + # output: shape=[8,4], spec=(RS, [0,1]) + self._check_distributed_expand( + shape=( + 8, + 1, + ), + target_shape=( + 8, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_expand_sharded_on_expanded_axis_with_device_mesh_0101(self): + # data: shape=[8,1], spec=(RR, [0,1]) + # shape: shape=[2], spec=(R, [0,1]), value=[1,4] + # output: shape=[8,4], spec=(RS, [0,1]) + self._check_distributed_expand( + shape=( + 8, + 1, + ), + target_shape=( + 8, + 8, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_expand_replicated_on_expanded_axis(self): + # data: shape=[8,1], spec=(RR, [0,1]) + # shape: shape=[2], spec=(R, [0,1]), value=[1,4] + # output: shape=[8,4], spec=(RR, [0,1]) + self._check_distributed_expand( + shape=( + 8, + 1, + ), + target_shape=( + 1, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RR",), + ) + + def test_expand_with_pass_through_sharding_spec(self): + # data: shape=[8,1], spec=(SR, [0,1]) + # shape: shape=[2], spec=(R, [0,1]), value=[1,4] + # output: shape=[8,4], spec=(SR, [0,1]) + self._check_distributed_expand( + shape=( + 8, + 1, + ), + target_shape=( + 1, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=( + "S[0]R", + "R", + ), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_expand_in_tiny_llama(self): + # data: shape=[2,4,256,4], spec=(RSRR, [0,1]) + # shape: shape=[4], spec=(R, [0,1,2,3]), value=[2,4,256,4] + # output: shape=[2,4,256,4], spec=(RSRR, [0,1]) + self._check_distributed_expand( + shape=( + 2, + 4, + 256, + 4, + ), + target_shape=( + 2, + 4, + 256, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]RR",), + ) + + class TestDistributed(unittest.TestCase): def test_matmul_rs_sr_rr(self): # It means 1-D tensor with single element: [2]. From 0e34100484594508e810ce43ce242112edac6444 Mon Sep 17 00:00:00 2001 From: snadampal <87143774+snadampal@users.noreply.github.com> Date: Sun, 29 Oct 2023 11:43:12 -0500 Subject: [PATCH 34/36] create memory descriptors based on the tensor dimensions (#15848) Arm Compute Library(ACL)backend requires explicit memory format tag iniatilization to decide wether the tensor can be computed with the ACL kernels. Hence, the src, weights and dst memroy descriptor format is set based on the tensor dimensions instead of using the format::any tag. ### Description Arm Compute Library(ACL)backend requires explicit memory format tag iniatilization to decide wether the tensor can be computed with the ACL kernels. Hence, the src, weights and dst memroy descriptor format is set based on the tensor dimensions instead of using the format::any tag. ### Motivation and Context The change enables ACL kernels for DNNL matmul ops on aarch64 platform. --- .../providers/dnnl/subgraph/dnnl_matmul.cc | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc index c3eab9dd8e55..54528011850b 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc @@ -12,6 +12,25 @@ namespace onnxruntime { namespace ort_dnnl { +inline static dnnl::memory::format_tag get_default_format(const dnnl::memory::dims& tensor_dims) { + switch (tensor_dims.size()) { + case 1: + return dnnl::memory::format_tag::a; + case 2: + return dnnl::memory::format_tag::ab; + case 3: + return dnnl::memory::format_tag::abc; + case 4: + return dnnl::memory::format_tag::abcd; + case 5: + return dnnl::memory::format_tag::abcde; + case 6: + return dnnl::memory::format_tag::abcdef; + default: + return dnnl::memory::format_tag::undef; + } +} + DnnlMatMul::DnnlMatMul() {} // This handles ONNX defined "MatMul" as well as two other variations of MatMul @@ -139,14 +158,14 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { if (transA || transBatchA) { src_md = transposedA_md; } else { - src_md = dnnl::memory::desc(src_dims, node.Input(IN_A).Type(), dnnl::memory::format_tag::any); + src_md = dnnl::memory::desc(src_dims, node.Input(IN_A).Type(), get_default_format(src_dims)); } dnnl::memory::desc weights_md; if (transB || transBatchB) { weights_md = transposedB_md; } else { - weights_md = dnnl::memory::desc(weights_dims, node.Input(IN_B).Type(), dnnl::memory::format_tag::any); + weights_md = dnnl::memory::desc(weights_dims, node.Input(IN_B).Type(), get_default_format(weights_dims)); } auto output_shape = src_dims; @@ -241,7 +260,7 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { attr.set_scales_mask(DNNL_ARG_SRC, 0); } - auto dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any); + auto dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), get_default_format(output_shape)); auto matmul_pd = dnnl::matmul::primitive_desc(eng, src_md, weights_md, dst_md, attr); From 8ebdd3bbcafc7e195e41b2aa3881d258d365f608 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Sun, 29 Oct 2023 19:26:12 -0700 Subject: [PATCH 35/36] Fix regression in perf test runner (#18139) --- onnxruntime/test/perftest/ort_test_session.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index e828a7cee5ea..41a1eafebbb5 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -121,7 +121,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device buffer.emplace_back("do_copy_in_default_stream"); option_keys.push_back(buffer.back().c_str()); - buffer.emplace_back(performance_test_config.run_config.do_cuda_copy_in_separate_stream ? "1" : "0"); + buffer.emplace_back(!performance_test_config.run_config.do_cuda_copy_in_separate_stream ? "1" : "0"); option_values.push_back(buffer.back().c_str()); #ifdef _MSC_VER From 436056dcd7c0533b52325fd052ba9fd219aacd66 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Mon, 30 Oct 2023 15:41:07 +0800 Subject: [PATCH 36/36] Revert "Disable dml stage in windows GPU pipeline temporarily. (#18034)" (#18150) This reverts commit 99b8dcaae2ac033b10880bc5bc7b0e89b37fe466. ### Description ### Motivation and Context Restore the dml stage in windows GPU pipeline. Agent issue is solved by adding Feature.DisableGpuDriver in pool properties. --- .../azure-pipelines/win-gpu-ci-pipeline.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml index ae2a4b4cead3..2ba4b7bea371 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -70,6 +70,23 @@ stages: MachinePool: onnxruntime-Win2022-GPU-T4 isTraining: true +- stage: dml + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env.bat + buildArch: x64 + additionalBuildFlags: --enable_pybind --use_dml --enable_wcos --use_winml + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + ORT_EP_NAME: DML + WITH_CACHE: true + MachinePool: onnxruntime-Win2022-GPU-dml-A10 + - stage: kernelDocumentation dependsOn: [] jobs: